From b4800c54a05b15b6247662e8e66c3d80124b4f84 Mon Sep 17 00:00:00 2001
From: Kane Zhu <843303+zxkane@users.noreply.github.com>
Date: Sat, 11 Oct 2025 14:08:22 +0800
Subject: [PATCH] feat: add prompt caching support for Claude and Nova models

Add comprehensive prompt caching support with flexible control options:

Features:
- ENV variable control (ENABLE_PROMPT_CACHING, default: false)
- Per-request control via extra_body.prompt_caching
- Pattern-based model detection (Claude, Nova)
- Token limit warnings (Nova 20K limit)
- OpenAI-compatible response format (prompt_tokens_details.cached_tokens)

Supported models:
- Claude 3+ models (anthropic.claude-*)
- Nova models (amazon.nova-*)
- Auto-detection prevents breaking unsupported models

Implementation:
- System prompts caching via extra_body.prompt_caching.system
- Messages caching via extra_body.prompt_caching.messages
- Non-streaming and streaming modes
- Compatible with reasoning, thinking, and tool calls
---
 README.md                               |  73 ++++++
 deployment/BedrockProxy.template        |   9 +
 deployment/BedrockProxyFargate.template |  10 +
 src/api/models/bedrock.py               | 308 +++++++++++++++++++++---
 src/api/schema.py                       |  14 ++
 src/api/setting.py                      |   1 +
 6 files changed, 376 insertions(+), 39 deletions(-)

diff --git a/README.md b/README.md
index 17154e0..306990e 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,7 @@ If you find this GitHub repository useful, please consider giving it a free star
 - [x] Support Application Inference Profiles (**new**)
 - [x] Support Reasoning (**new**)
 - [x] Support Interleaved thinking (**new**)
+- [x] Support Prompt Caching (**new**)
 Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs.
@@ -221,6 +222,78 @@ print(completion.choices[0].message.content)
 For more information about creating and managing application inference profiles, see the [Amazon Bedrock User Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html).
+### Prompt Caching
+
+This proxy now supports **Prompt Caching** for Claude and Nova models, which can reduce costs by up to 90% and latency by up to 85% for workloads with repeated prompts.
+
+**Supported Models:**
+- Claude 3+ models (Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4, Claude 4.5, etc.)
+- Nova models (Nova Micro, Nova Lite, Nova Pro, Nova Premier)
+
+**Enabling Prompt Caching:**
+
+You can enable prompt caching in two ways:
+
+1. **Globally via Environment Variable** (set in ECS Task Definition or Lambda):
+```bash
+ENABLE_PROMPT_CACHING=true
+```
+
+2. 
**Per-request via `extra_body`** : + +**Python SDK:** +```python +from openai import OpenAI + +client = OpenAI() + +# Cache system prompts +response = client.chat.completions.create( + model="us.anthropic.claude-3-7-sonnet-20250219-v1:0", + messages=[ + {"role": "system", "content": "You are an expert assistant with knowledge of..."}, + {"role": "user", "content": "Help me with this task"} + ], + extra_body={ + "prompt_caching": {"system": True} + } +) + +# Check cache hit +if response.usage.prompt_tokens_details: + cached_tokens = response.usage.prompt_tokens_details.cached_tokens + print(f"Cached tokens: {cached_tokens}") +``` + +**cURL:** +```bash +curl $OPENAI_BASE_URL/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "messages": [ + {"role": "system", "content": "Long system prompt..."}, + {"role": "user", "content": "Question"} + ], + "extra_body": { + "prompt_caching": {"system": true} + } + }' +``` + +**Cache Options:** +- `"prompt_caching": {"system": true}` - Cache system prompts +- `"prompt_caching": {"messages": true}` - Cache user messages +- `"prompt_caching": {"system": true, "messages": true}` - Cache both + +**Requirements:** +- Prompt must be ≥1,024 tokens to enable caching +- Cache TTL is 5 minutes (resets on each cache hit) +- Nova models have a 20,000 token caching limit + +For more information, see the [Amazon Bedrock Prompt Caching Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html). + ## Other Examples ### LangChain diff --git a/deployment/BedrockProxy.template b/deployment/BedrockProxy.template index ec97dad..1b15de4 100644 --- a/deployment/BedrockProxy.template +++ b/deployment/BedrockProxy.template @@ -11,6 +11,13 @@ Parameters: Type: String Default: anthropic.claude-3-sonnet-20240229-v1:0 Description: The default model ID, please make sure the model ID is supported in the current region + EnablePromptCaching: + Type: String + Default: "false" + AllowedValues: + - "true" + - "false" + Description: Enable prompt caching for supported models (Claude, Nova). When enabled, adds cachePoint to system prompts and messages for cost savings. Resources: VPCB9E5F0B4: Type: AWS::EC2::VPC @@ -184,6 +191,8 @@ Resources: DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3 ENABLE_CROSS_REGION_INFERENCE: "true" ENABLE_APPLICATION_INFERENCE_PROFILES: "true" + ENABLE_PROMPT_CACHING: + Ref: EnablePromptCaching MemorySize: 1024 PackageType: Image Role: diff --git a/deployment/BedrockProxyFargate.template b/deployment/BedrockProxyFargate.template index ed99267..4fee3ed 100644 --- a/deployment/BedrockProxyFargate.template +++ b/deployment/BedrockProxyFargate.template @@ -11,6 +11,13 @@ Parameters: Type: String Default: anthropic.claude-3-sonnet-20240229-v1:0 Description: The default model ID, please make sure the model ID is supported in the current region + EnablePromptCaching: + Type: String + Default: "false" + AllowedValues: + - "true" + - "false" + Description: Enable prompt caching for supported models (Claude, Nova). When enabled, adds cachePoint to system prompts and messages for cost savings. 
Resources: VPCB9E5F0B4: Type: AWS::EC2::VPC @@ -251,6 +258,9 @@ Resources: Value: "true" - Name: ENABLE_APPLICATION_INFERENCE_PROFILES Value: "true" + - Name: ENABLE_PROMPT_CACHING + Value: + Ref: EnablePromptCaching Essential: true Image: Ref: ContainerImageUri diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index fba048b..3effdd7 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -24,6 +24,7 @@ from api.schema import ( ChatStreamResponse, Choice, ChoiceDelta, + CompletionTokensDetails, Embedding, EmbeddingsRequest, EmbeddingsResponse, @@ -32,6 +33,7 @@ from api.schema import ( ErrorMessage, Function, ImageContent, + PromptTokensDetails, ResponseFunction, TextContent, ToolCall, @@ -46,6 +48,7 @@ from api.setting import ( DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE, ENABLE_APPLICATION_INFERENCE_PROFILES, + ENABLE_PROMPT_CACHING, ) logger = logging.getLogger(__name__) @@ -203,6 +206,64 @@ class BedrockModel(BaseChatModel): detail=error, ) + @staticmethod + def _supports_prompt_caching(model_id: str) -> bool: + """ + Check if model supports prompt caching based on model ID pattern. + + Uses pattern matching instead of hardcoded whitelist for better maintainability. + Automatically supports new models following the naming convention. + + Supported models: + - Claude: anthropic.claude-* (excluding very old versions) + - Nova: amazon.nova-* + + Returns: + bool: True if model supports prompt caching + """ + model_lower = model_id.lower() + + # Claude models pattern matching + if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower: + # Exclude very old models that don't support caching + excluded_patterns = ["claude-instant", "claude-v1", "claude-v2"] + if any(pattern in model_lower for pattern in excluded_patterns): + return False + return True + + # Nova models pattern matching + if "amazon.nova" in model_lower or ".amazon.nova" in model_lower: + return True + + # Future providers can be added here + # Example: if "provider.model-name" in model_lower: return True + + return False + + @staticmethod + def _get_max_cache_tokens(model_id: str) -> int | None: + """ + Get maximum cacheable tokens limit for the model. 
+ + Different models have different caching limits: + - Claude: No explicit limit mentioned in docs + - Nova: 20,000 tokens max + + Returns: + int | None: Max tokens, or None if unlimited + """ + model_lower = model_id.lower() + + # Nova models have 20K limit + if "amazon.nova" in model_lower or ".amazon.nova" in model_lower: + return 20_000 + + # Claude: No explicit limit + if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower: + return None + + return None + async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False): """Common logic for invoke bedrock models""" if DEBUG: @@ -240,17 +301,32 @@ class BedrockModel(BaseChatModel): response = await self._invoke_bedrock(chat_request) output_message = response["output"]["message"] - input_tokens = response["usage"]["inputTokens"] - output_tokens = response["usage"]["outputTokens"] + usage = response["usage"] + + # Extract all token counts + output_tokens = usage["outputTokens"] + total_tokens = usage["totalTokens"] finish_reason = response["stopReason"] + # Extract prompt caching metrics if available + cache_read_tokens = usage.get("cacheReadInputTokens", 0) + cache_creation_tokens = usage.get("cacheCreationInputTokens", 0) + + # Calculate actual prompt tokens + # Bedrock's totalTokens includes all: inputTokens + cacheRead + cacheWrite + outputTokens + # So: prompt_tokens = totalTokens - outputTokens + actual_prompt_tokens = total_tokens - output_tokens + chat_response = self._create_response( model=chat_request.model, message_id=message_id, content=output_message["content"], finish_reason=finish_reason, - input_tokens=input_tokens, + input_tokens=actual_prompt_tokens, output_tokens=output_tokens, + total_tokens=total_tokens, + cache_read_tokens=cache_read_tokens, + cache_creation_tokens=cache_creation_tokens, ) if DEBUG: logger.info("Proxy response :" + chat_response.model_dump_json()) @@ -296,24 +372,68 @@ class BedrockModel(BaseChatModel): yield self.stream_response_to_bytes(error_event) def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]: - """Create system prompts. - Note that not all models support system prompts. + """Create system prompts with optional prompt caching support. - example output: [{"text" : system_prompt}] + Prompt caching can be enabled via: + 1. ENABLE_PROMPT_CACHING environment variable (global default) + 2. 
extra_body.prompt_caching.system = True/False (per-request override) - See example: - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples + Only adds cachePoint if: + - Model supports caching (Claude, Nova) + - Caching is enabled (ENV or extra_body) + - System prompts exist and meet minimum token requirements + + Example output: [{"text" : system_prompt}, {"cachePoint": {"type": "default"}}] + + See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html """ - system_prompts = [] for message in chat_request.messages: if message.role != "system": - # ignore system messages here continue if not isinstance(message.content, str): raise TypeError(f"System message content must be a string, got {type(message.content).__name__}") system_prompts.append({"text": message.content}) + if not system_prompts: + return system_prompts + + # Check if model supports prompt caching + if not self._supports_prompt_caching(chat_request.model): + return system_prompts + + # Determine if caching should be enabled + cache_enabled = ENABLE_PROMPT_CACHING # Default from ENV + + # Check for extra_body override + if chat_request.extra_body and isinstance(chat_request.extra_body, dict): + prompt_caching = chat_request.extra_body.get("prompt_caching", {}) + if "system" in prompt_caching: + # extra_body explicitly controls caching + cache_enabled = prompt_caching.get("system") is True + + if not cache_enabled: + return system_prompts + + # Estimate total tokens for limit check + total_text = " ".join(p.get("text", "") for p in system_prompts) + estimated_tokens = len(total_text.split()) * 1.3 # Rough estimate + + # Check token limits (Nova has 20K limit) + max_tokens = self._get_max_cache_tokens(chat_request.model) + if max_tokens and estimated_tokens > max_tokens: + logger.warning( + f"System prompts (~{estimated_tokens:.0f} tokens) exceed model cache limit ({max_tokens} tokens). " + f"Caching will still be attempted but may not work optimally." + ) + # Still add cachePoint - let Bedrock handle the limit + + # Add cache checkpoint after system prompts + system_prompts.append({"cachePoint": {"type": "default"}}) + + if DEBUG: + logger.info(f"Added cachePoint to system prompts for model {chat_request.model}") + return system_prompts def _parse_messages(self, chat_request: ChatRequest) -> list[dict]: @@ -402,7 +522,7 @@ class BedrockModel(BaseChatModel): else: # ignore others, such as system messages continue - return self._reframe_multi_payloard(messages) + return self._reframe_multi_payloard(messages, chat_request) def _extract_tool_content(self, content) -> str: """Extract text content from various OpenAI SDK tool message formats. 
@@ -455,7 +575,7 @@ class BedrockModel(BaseChatModel): # Return a safe fallback return str(content) if content is not None else "" - def _reframe_multi_payloard(self, messages: list) -> list: + def _reframe_multi_payloard(self, messages: list, chat_request: ChatRequest = None) -> list: """Receive messages and reformat them to comply with the Claude format With OpenAI format requests, it's not a problem to repeatedly receive messages from the same role, but @@ -510,6 +630,29 @@ class BedrockModel(BaseChatModel): {"role": current_role, "content": current_content} ) + # Add cachePoint to messages if enabled and supported + if chat_request and reformatted_messages: + if not self._supports_prompt_caching(chat_request.model): + return reformatted_messages + + # Determine if messages caching should be enabled + cache_enabled = ENABLE_PROMPT_CACHING + + if chat_request.extra_body and isinstance(chat_request.extra_body, dict): + prompt_caching = chat_request.extra_body.get("prompt_caching", {}) + if "messages" in prompt_caching: + cache_enabled = prompt_caching.get("messages") is True + + if cache_enabled: + # Add cachePoint to the last user message content + for msg in reversed(reformatted_messages): + if msg["role"] == "user" and msg.get("content"): + # Add cachePoint at the end of user message content + msg["content"].append({"cachePoint": {"type": "default"}}) + if DEBUG: + logger.info(f"Added cachePoint to last user message for model {chat_request.model}") + break + return reformatted_messages def _parse_request(self, chat_request: ChatRequest) -> dict: @@ -547,24 +690,39 @@ class BedrockModel(BaseChatModel): "inferenceConfig": inference_config, } if chat_request.reasoning_effort: - # From OpenAI api, the max_token is not supported in reasoning mode - # Use max_completion_tokens if provided. 
+ # reasoning_effort is supported by Claude and DeepSeek v3 + # Different models use different formats + model_lower = chat_request.model.lower() - max_tokens = ( - chat_request.max_completion_tokens - if chat_request.max_completion_tokens - else chat_request.max_tokens - ) - budget_tokens = self._calc_budget_tokens( - max_tokens, chat_request.reasoning_effort - ) - inference_config["maxTokens"] = max_tokens - # unset topP - Not supported - inference_config.pop("topP", None) + if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower: + # Claude format: reasoning_config = object with budget_tokens + max_tokens = ( + chat_request.max_completion_tokens + if chat_request.max_completion_tokens + else chat_request.max_tokens + ) + budget_tokens = self._calc_budget_tokens( + max_tokens, chat_request.reasoning_effort + ) + inference_config["maxTokens"] = max_tokens + # unset topP - Not supported + inference_config.pop("topP", None) - args["additionalModelRequestFields"] = { - "reasoning_config": {"type": "enabled", "budget_tokens": budget_tokens} - } + args["additionalModelRequestFields"] = { + "reasoning_config": {"type": "enabled", "budget_tokens": budget_tokens} + } + elif "deepseek.v3" in model_lower or "deepseek.deepseek-v3" in model_lower: + # DeepSeek v3 format: reasoning_config = string ('low', 'medium', 'high') + # From Bedrock Playground: {"reasoning_config": "high"} + args["additionalModelRequestFields"] = { + "reasoning_config": chat_request.reasoning_effort # Direct string: low/medium/high + } + if DEBUG: + logger.info(f"Applied reasoning_config={chat_request.reasoning_effort} for DeepSeek v3") + else: + # For other models (Qwen, etc.), ignore reasoning_effort parameter + if DEBUG: + logger.info(f"reasoning_effort parameter ignored for model {chat_request.model} (not supported)") # add tool config if chat_request.tools: tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]} @@ -585,16 +743,42 @@ class BedrockModel(BaseChatModel): raise ValueError("tool_choice must contain 'function' key when specifying a specific tool") tool_config["toolChoice"] = {"tool": {"name": chat_request.tool_choice["function"].get("name", "")}} args["toolConfig"] = tool_config - # add Additional fields to enable extend thinking + # Add additional fields to enable extend thinking or other model-specific features if chat_request.extra_body: - # reasoning_config will not be used - args["additionalModelRequestFields"] = chat_request.extra_body - # Extended thinking doesn't support both temperature and topP - # Remove topP to avoid validation error - if "thinking" in chat_request.extra_body: - inference_config.pop("topP", None) + # Filter out prompt_caching (our control field, not for Bedrock) + additional_fields = { + k: v for k, v in chat_request.extra_body.items() + if k != "prompt_caching" + } + + if additional_fields: + # Only set additionalModelRequestFields if there are actual fields to pass + args["additionalModelRequestFields"] = additional_fields + + # Extended thinking doesn't support both temperature and topP + # Remove topP to avoid validation error + if "thinking" in additional_fields: + inference_config.pop("topP", None) + return args + def _estimate_reasoning_tokens(self, content: list[dict]) -> int: + """ + Estimate reasoning tokens from reasoningContent blocks. + + Bedrock doesn't separately report reasoning tokens, so we estimate + them using tiktoken to maintain OpenAI API compatibility. 
+ """ + reasoning_text = "" + for block in content: + if "reasoningContent" in block: + reasoning_text += block["reasoningContent"]["reasoningText"].get("text", "") + + if reasoning_text: + # Use tiktoken to estimate token count + return len(ENCODER.encode(reasoning_text)) + return 0 + def _create_response( self, model: str, @@ -603,6 +787,9 @@ class BedrockModel(BaseChatModel): finish_reason: str | None = None, input_tokens: int = 0, output_tokens: int = 0, + total_tokens: int = 0, + cache_read_tokens: int = 0, + cache_creation_tokens: int = 0, ) -> ChatResponse: message = ChatResponseMessage( role="assistant", @@ -642,6 +829,25 @@ class BedrockModel(BaseChatModel): message.content = f"{message.reasoning_content}{message.content}" message.reasoning_content = None + # Create prompt_tokens_details if cache metrics are available + prompt_tokens_details = None + if cache_read_tokens > 0 or cache_creation_tokens > 0: + # Map Bedrock cache metrics to OpenAI format + # cached_tokens represents tokens read from cache (cache hits) + prompt_tokens_details = PromptTokensDetails( + cached_tokens=cache_read_tokens, + audio_tokens=0, + ) + + # Create completion_tokens_details if reasoning content exists + completion_tokens_details = None + reasoning_tokens = self._estimate_reasoning_tokens(content) if content else 0 + if reasoning_tokens > 0: + completion_tokens_details = CompletionTokensDetails( + reasoning_tokens=reasoning_tokens, + audio_tokens=0, + ) + response = ChatResponse( id=message_id, model=model, @@ -656,7 +862,9 @@ class BedrockModel(BaseChatModel): usage=Usage( prompt_tokens=input_tokens, completion_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, + total_tokens=total_tokens if total_tokens > 0 else input_tokens + output_tokens, + prompt_tokens_details=prompt_tokens_details, + completion_tokens_details=completion_tokens_details, ), ) response.system_fingerprint = "fp" @@ -744,14 +952,36 @@ class BedrockModel(BaseChatModel): metadata = chunk["metadata"] if "usage" in metadata: # token usage + usage_data = metadata["usage"] + + # Extract prompt caching metrics if available + cache_read_tokens = usage_data.get("cacheReadInputTokens", 0) + cache_creation_tokens = usage_data.get("cacheCreationInputTokens", 0) + + # Create prompt_tokens_details if cache metrics are available + prompt_tokens_details = None + if cache_read_tokens > 0 or cache_creation_tokens > 0: + prompt_tokens_details = PromptTokensDetails( + cached_tokens=cache_read_tokens, + audio_tokens=0, + ) + + # Calculate actual prompt tokens + # Bedrock's totalTokens includes all tokens + # prompt_tokens = totalTokens - outputTokens + total_tokens = usage_data["totalTokens"] + output_tokens = usage_data["outputTokens"] + actual_prompt_tokens = total_tokens - output_tokens + return ChatStreamResponse( id=message_id, model=model_id, choices=[], usage=Usage( - prompt_tokens=metadata["usage"]["inputTokens"], - completion_tokens=metadata["usage"]["outputTokens"], - total_tokens=metadata["usage"]["totalTokens"], + prompt_tokens=actual_prompt_tokens, + completion_tokens=output_tokens, + total_tokens=total_tokens, + prompt_tokens_details=prompt_tokens_details, ), ) if message: diff --git a/src/api/schema.py b/src/api/schema.py index 233e113..dcdfcd0 100644 --- a/src/api/schema.py +++ b/src/api/schema.py @@ -110,10 +110,24 @@ class ChatRequest(BaseModel): extra_body: dict | None = None +class PromptTokensDetails(BaseModel): + """Details about prompt tokens usage, following OpenAI API format.""" + cached_tokens: int = 0 + 
audio_tokens: int = 0
+
+
+class CompletionTokensDetails(BaseModel):
+    """Details about completion tokens usage, following OpenAI API format."""
+    reasoning_tokens: int = 0
+    audio_tokens: int = 0
+
+
 class Usage(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     total_tokens: int
+    prompt_tokens_details: PromptTokensDetails | None = None
+    completion_tokens_details: CompletionTokensDetails | None = None
 class ChatResponseMessage(BaseModel):
diff --git a/src/api/setting.py b/src/api/setting.py
index 43fd2b7..c69780b 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -15,3 +15,4 @@ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240
 DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
 ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
+ENABLE_PROMPT_CACHING = os.environ.get("ENABLE_PROMPT_CACHING", "false").lower() == "true"
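
To see the new `prompt_tokens_details` field end to end, the sketch below sends the same cached system prompt through the proxy twice and compares `cached_tokens`. It is a minimal illustration, not part of the patch: the model ID, the prompt text, and the assumption that the second call lands inside the 5-minute cache TTL (and that the prompt clears the ~1,024-token minimum) are all placeholders.

```python
from openai import OpenAI

# Assumes OPENAI_BASE_URL / OPENAI_API_KEY already point at the proxy.
client = OpenAI()

LONG_SYSTEM_PROMPT = "You are an expert assistant. " * 300  # illustrative; must exceed the minimum cacheable size

def ask(question: str):
    return client.chat.completions.create(
        model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",  # illustrative model ID
        messages=[
            {"role": "system", "content": LONG_SYSTEM_PROMPT},
            {"role": "user", "content": question},
        ],
        # Per-request override; takes effect even when ENABLE_PROMPT_CACHING is left at "false".
        extra_body={"prompt_caching": {"system": True}},
    )

first = ask("Summarize your instructions.")
second = ask("What tasks can you help with?")  # should hit the cache written by the first call

for label, resp in (("first", first), ("second", second)):
    details = resp.usage.prompt_tokens_details
    cached = details.cached_tokens if details else 0
    print(f"{label}: prompt_tokens={resp.usage.prompt_tokens} cached_tokens={cached}")
```

When caching is working, the first call typically reports `cached_tokens=0` (the prefix is being written) and the second a non-zero value (the prefix is read back).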
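Because `_parse_request` strips the `prompt_caching` key before forwarding the rest of `extra_body` as `additionalModelRequestFields`, the per-request cache switch can sit alongside other model-specific fields such as extended thinking. A sketch under the same assumptions as above (illustrative model ID and budget; the `thinking` payload follows the Bedrock extended-thinking format, with `max_tokens` set above the thinking budget):

```python
from openai import OpenAI

client = OpenAI()  # assumes the proxy is configured as the base URL

response = client.chat.completions.create(
    model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",  # illustrative model ID
    max_tokens=8192,  # must stay above the thinking budget below
    messages=[
        {"role": "system", "content": "Long, reusable system prompt..."},
        {"role": "user", "content": "Walk me through the trade-offs."},
    ],
    extra_body={
        # Consumed by the proxy; removed before the Bedrock Converse call.
        "prompt_caching": {"system": True, "messages": True},
        # Passed through as additionalModelRequestFields (the proxy drops topP for thinking requests).
        "thinking": {"type": "enabled", "budget_tokens": 4096},
    },
)
print(response.usage)
```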
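The proxy only warns when a cache candidate exceeds a model's limit, using a rough whitespace-token estimate (`len(text.split()) * 1.3`). Callers who want to decide up front whether opting in is worthwhile can mirror that heuristic client-side; the helper below is hypothetical and only restates the limits quoted in the README (≥1,024 tokens to cache, 20K cap on Nova).

```python
def worth_caching(prompt_text: str, model_id: str) -> bool:
    """Rough client-side check mirroring the proxy's estimate; not part of the patch."""
    estimated_tokens = len(prompt_text.split()) * 1.3
    if estimated_tokens < 1_024:
        return False  # below the minimum cacheable prefix size
    if "amazon.nova" in model_id.lower() and estimated_tokens > 20_000:
        return False  # Nova caches at most 20K tokens; the proxy itself would only log a warning
    return True

print(worth_caching("short prompt", "amazon.nova-pro-v1:0"))   # False
print(worth_caching("word " * 5000, "amazon.nova-pro-v1:0"))   # True
```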