feat: add prompt caching support for Claude and Nova models

Add comprehensive prompt caching support with flexible control options:

Features:
- ENV variable control (ENABLE_PROMPT_CACHING, default: false)
- Per-request control via extra_body.prompt_caching (see the client sketch below)
- Pattern-based model detection (Claude, Nova)
- Token limit warnings (Nova 20K limit)
- OpenAI-compatible response format (prompt_tokens_details.cached_tokens)

Supported models:
- Claude 3+ models (anthropic.claude-*)
- Nova models (amazon.nova-*)
- Auto-detection prevents breaking unsupported models

Implementation:
- System prompt caching via extra_body.prompt_caching.system
- Message caching via extra_body.prompt_caching.messages
- Supported in both streaming and non-streaming modes
- Compatible with reasoning, thinking, and tool calls
Author: Kane Zhu
Date: 2025-10-11 14:08:22 +08:00
parent 7756532b4c
commit b4800c54a0
6 changed files with 376 additions and 39 deletions
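For reference, a minimal client-side sketch of the per-request control path (the endpoint URL, auth header, and model ID below are placeholders; the extra_body shape and the cached_tokens field follow the changes in this commit):

import requests  # illustrative; any HTTP client works

# Placeholder endpoint and key for a gateway deployment.
url = "http://localhost:8000/api/v1/chat/completions"
headers = {"Authorization": "Bearer bedrock"}

payload = {
    "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",  # placeholder model ID
    "messages": [
        {"role": "system", "content": "<long, reusable system instructions>"},
        {"role": "user", "content": "Summarize the quarterly report."},
    ],
    # Maps to ChatRequest.extra_body; overrides ENABLE_PROMPT_CACHING for this request.
    "extra_body": {"prompt_caching": {"system": True, "messages": True}},
}

usage = requests.post(url, json=payload, headers=headers, timeout=60).json()["usage"]
# Cache hits are reported in the OpenAI-compatible usage details.
print((usage.get("prompt_tokens_details") or {}).get("cached_tokens", 0))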

View File

@@ -24,6 +24,7 @@ from api.schema import (
ChatStreamResponse,
Choice,
ChoiceDelta,
CompletionTokensDetails,
Embedding,
EmbeddingsRequest,
EmbeddingsResponse,
@@ -32,6 +33,7 @@ from api.schema import (
ErrorMessage,
Function,
ImageContent,
PromptTokensDetails,
ResponseFunction,
TextContent,
ToolCall,
@@ -46,6 +48,7 @@ from api.setting import (
DEFAULT_MODEL,
ENABLE_CROSS_REGION_INFERENCE,
ENABLE_APPLICATION_INFERENCE_PROFILES,
ENABLE_PROMPT_CACHING,
)
logger = logging.getLogger(__name__)
@@ -203,6 +206,64 @@ class BedrockModel(BaseChatModel):
detail=error,
)
@staticmethod
def _supports_prompt_caching(model_id: str) -> bool:
"""
Check if model supports prompt caching based on model ID pattern.
Uses pattern matching instead of a hardcoded whitelist for better maintainability.
Automatically supports new models following the naming convention.
Supported models:
- Claude: anthropic.claude-* (excluding very old versions)
- Nova: amazon.nova-*
Returns:
bool: True if model supports prompt caching
"""
model_lower = model_id.lower()
# Claude models pattern matching
if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
# Exclude very old models that don't support caching
excluded_patterns = ["claude-instant", "claude-v1", "claude-v2"]
if any(pattern in model_lower for pattern in excluded_patterns):
return False
return True
# Nova models pattern matching
if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
return True
# Future providers can be added here
# Example: if "provider.model-name" in model_lower: return True
return False
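# Illustrative pattern checks (model IDs shown are examples, not an exhaustive list):
#   "us.anthropic.claude-3-5-sonnet-20241022-v2:0" -> True  (cross-region Claude profile)
#   "anthropic.claude-v2:1"                        -> False (excluded legacy version)
#   "amazon.nova-pro-v1:0"                         -> True  (Nova)
#   "meta.llama3-70b-instruct-v1:0"                -> False (no matching pattern)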
@staticmethod
def _get_max_cache_tokens(model_id: str) -> int | None:
"""
Get maximum cacheable tokens limit for the model.
Different models have different caching limits:
- Claude: No explicit limit mentioned in docs
- Nova: 20,000 tokens max
Returns:
int | None: Max tokens, or None if unlimited
"""
model_lower = model_id.lower()
# Nova models have 20K limit
if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
return 20_000
# Claude: No explicit limit
if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
return None
return None
async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
"""Common logic for invoke bedrock models"""
if DEBUG:
@@ -240,17 +301,32 @@ class BedrockModel(BaseChatModel):
response = await self._invoke_bedrock(chat_request)
output_message = response["output"]["message"]
usage = response["usage"]
# Extract all token counts
output_tokens = usage["outputTokens"]
total_tokens = usage["totalTokens"]
finish_reason = response["stopReason"]
# Extract prompt caching metrics if available
cache_read_tokens = usage.get("cacheReadInputTokens", 0)
cache_creation_tokens = usage.get("cacheCreationInputTokens", 0)
# Calculate actual prompt tokens
# Bedrock's totalTokens includes all: inputTokens + cacheRead + cacheWrite + outputTokens
# So: prompt_tokens = totalTokens - outputTokens
actual_prompt_tokens = total_tokens - output_tokens
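# Worked example (hypothetical numbers): inputTokens=120, cacheReadInputTokens=900,
# cacheCreationInputTokens=0, outputTokens=200 -> totalTokens=1220,
# so actual_prompt_tokens = 1220 - 200 = 1020 (cached tokens stay counted in the prompt).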
chat_response = self._create_response(
model=chat_request.model,
message_id=message_id,
content=output_message["content"],
finish_reason=finish_reason,
input_tokens=actual_prompt_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
)
if DEBUG:
logger.info("Proxy response :" + chat_response.model_dump_json())
@@ -296,24 +372,68 @@ class BedrockModel(BaseChatModel):
yield self.stream_response_to_bytes(error_event)
def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]:
"""Create system prompts.
Note that not all models support system prompts.
"""Create system prompts with optional prompt caching support.
example output: [{"text" : system_prompt}]
Prompt caching can be enabled via:
1. ENABLE_PROMPT_CACHING environment variable (global default)
2. extra_body.prompt_caching.system = True/False (per-request override)
See example:
https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
Only adds cachePoint if:
- Model supports caching (Claude, Nova)
- Caching is enabled (ENV or extra_body)
- System prompts exist and meet minimum token requirements
Example output: [{"text" : system_prompt}, {"cachePoint": {"type": "default"}}]
See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
"""
system_prompts = []
for message in chat_request.messages:
if message.role != "system":
# ignore system messages here
continue
if not isinstance(message.content, str):
raise TypeError(f"System message content must be a string, got {type(message.content).__name__}")
system_prompts.append({"text": message.content})
if not system_prompts:
return system_prompts
# Check if model supports prompt caching
if not self._supports_prompt_caching(chat_request.model):
return system_prompts
# Determine if caching should be enabled
cache_enabled = ENABLE_PROMPT_CACHING # Default from ENV
# Check for extra_body override
if chat_request.extra_body and isinstance(chat_request.extra_body, dict):
prompt_caching = chat_request.extra_body.get("prompt_caching", {})
if "system" in prompt_caching:
# extra_body explicitly controls caching
cache_enabled = prompt_caching.get("system") is True
if not cache_enabled:
return system_prompts
# Estimate total tokens for limit check
total_text = " ".join(p.get("text", "") for p in system_prompts)
estimated_tokens = len(total_text.split()) * 1.3 # Rough estimate
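# e.g. ~1,000 whitespace-separated words -> ~1,300 estimated tokens under this rough heuristic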
# Check token limits (Nova has 20K limit)
max_tokens = self._get_max_cache_tokens(chat_request.model)
if max_tokens and estimated_tokens > max_tokens:
logger.warning(
f"System prompts (~{estimated_tokens:.0f} tokens) exceed model cache limit ({max_tokens} tokens). "
f"Caching will still be attempted but may not work optimally."
)
# Still add cachePoint - let Bedrock handle the limit
# Add cache checkpoint after system prompts
system_prompts.append({"cachePoint": {"type": "default"}})
if DEBUG:
logger.info(f"Added cachePoint to system prompts for model {chat_request.model}")
return system_prompts
def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
@@ -402,7 +522,7 @@ class BedrockModel(BaseChatModel):
else:
# ignore others, such as system messages
continue
return self._reframe_multi_payloard(messages, chat_request)
def _extract_tool_content(self, content) -> str:
"""Extract text content from various OpenAI SDK tool message formats.
@@ -455,7 +575,7 @@ class BedrockModel(BaseChatModel):
# Return a safe fallback
return str(content) if content is not None else ""
def _reframe_multi_payloard(self, messages: list, chat_request: ChatRequest = None) -> list:
"""Receive messages and reformat them to comply with the Claude format
With OpenAI format requests, it's not a problem to repeatedly receive messages from the same role, but
@@ -510,6 +630,29 @@ class BedrockModel(BaseChatModel):
{"role": current_role, "content": current_content}
)
# Add cachePoint to messages if enabled and supported
if chat_request and reformatted_messages:
if not self._supports_prompt_caching(chat_request.model):
return reformatted_messages
# Determine if messages caching should be enabled
cache_enabled = ENABLE_PROMPT_CACHING
if chat_request.extra_body and isinstance(chat_request.extra_body, dict):
prompt_caching = chat_request.extra_body.get("prompt_caching", {})
if "messages" in prompt_caching:
cache_enabled = prompt_caching.get("messages") is True
if cache_enabled:
# Add cachePoint to the last user message content
for msg in reversed(reformatted_messages):
if msg["role"] == "user" and msg.get("content"):
# Add cachePoint at the end of user message content
msg["content"].append({"cachePoint": {"type": "default"}})
if DEBUG:
logger.info(f"Added cachePoint to last user message for model {chat_request.model}")
break
return reformatted_messages
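# Illustrative result with messages caching enabled (text abbreviated):
#   [..., {"role": "user", "content": [{"text": "..."}, {"cachePoint": {"type": "default"}}]}]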
def _parse_request(self, chat_request: ChatRequest) -> dict:
@@ -547,24 +690,39 @@ class BedrockModel(BaseChatModel):
"inferenceConfig": inference_config,
}
if chat_request.reasoning_effort:
# reasoning_effort is supported by Claude and DeepSeek v3
# Different models use different formats
model_lower = chat_request.model.lower()
if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
# Claude format: reasoning_config = object with budget_tokens
max_tokens = (
chat_request.max_completion_tokens
if chat_request.max_completion_tokens
else chat_request.max_tokens
)
budget_tokens = self._calc_budget_tokens(
max_tokens, chat_request.reasoning_effort
)
inference_config["maxTokens"] = max_tokens
# unset topP - Not supported
inference_config.pop("topP", None)
args["additionalModelRequestFields"] = {
"reasoning_config": {"type": "enabled", "budget_tokens": budget_tokens}
}
elif "deepseek.v3" in model_lower or "deepseek.deepseek-v3" in model_lower:
# DeepSeek v3 format: reasoning_config = string ('low', 'medium', 'high')
# From Bedrock Playground: {"reasoning_config": "high"}
args["additionalModelRequestFields"] = {
"reasoning_config": chat_request.reasoning_effort # Direct string: low/medium/high
}
if DEBUG:
logger.info(f"Applied reasoning_config={chat_request.reasoning_effort} for DeepSeek v3")
else:
# For other models (Qwen, etc.), ignore reasoning_effort parameter
if DEBUG:
logger.info(f"reasoning_effort parameter ignored for model {chat_request.model} (not supported)")
# add tool config
if chat_request.tools:
tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
@@ -585,16 +743,42 @@ class BedrockModel(BaseChatModel):
raise ValueError("tool_choice must contain 'function' key when specifying a specific tool")
tool_config["toolChoice"] = {"tool": {"name": chat_request.tool_choice["function"].get("name", "")}}
args["toolConfig"] = tool_config
# Add additional fields to enable extended thinking or other model-specific features
if chat_request.extra_body:
# Filter out prompt_caching (our control field, not for Bedrock)
additional_fields = {
k: v for k, v in chat_request.extra_body.items()
if k != "prompt_caching"
}
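# e.g. extra_body={"prompt_caching": {...}, "thinking": {"type": "enabled", "budget_tokens": 1024}}
# forwards only the "thinking" block to Bedrock (values illustrative)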
if additional_fields:
# Only set additionalModelRequestFields if there are actual fields to pass
args["additionalModelRequestFields"] = additional_fields
# Extended thinking doesn't support both temperature and topP
# Remove topP to avoid validation error
if "thinking" in additional_fields:
inference_config.pop("topP", None)
return args
def _estimate_reasoning_tokens(self, content: list[dict]) -> int:
"""
Estimate reasoning tokens from reasoningContent blocks.
Bedrock doesn't separately report reasoning tokens, so we estimate
them using tiktoken to maintain OpenAI API compatibility.
"""
reasoning_text = ""
for block in content:
if "reasoningContent" in block:
reasoning_text += block["reasoningContent"]["reasoningText"].get("text", "")
if reasoning_text:
# Use tiktoken to estimate token count
return len(ENCODER.encode(reasoning_text))
return 0
def _create_response(
self,
model: str,
@@ -603,6 +787,9 @@ class BedrockModel(BaseChatModel):
finish_reason: str | None = None,
input_tokens: int = 0,
output_tokens: int = 0,
total_tokens: int = 0,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
) -> ChatResponse:
message = ChatResponseMessage(
role="assistant",
@@ -642,6 +829,25 @@ class BedrockModel(BaseChatModel):
message.content = f"<think>{message.reasoning_content}</think>{message.content}"
message.reasoning_content = None
# Create prompt_tokens_details if cache metrics are available
prompt_tokens_details = None
if cache_read_tokens > 0 or cache_creation_tokens > 0:
# Map Bedrock cache metrics to OpenAI format
# cached_tokens represents tokens read from cache (cache hits)
prompt_tokens_details = PromptTokensDetails(
cached_tokens=cache_read_tokens,
audio_tokens=0,
)
# Create completion_tokens_details if reasoning content exists
completion_tokens_details = None
reasoning_tokens = self._estimate_reasoning_tokens(content) if content else 0
if reasoning_tokens > 0:
completion_tokens_details = CompletionTokensDetails(
reasoning_tokens=reasoning_tokens,
audio_tokens=0,
)
response = ChatResponse(
id=message_id,
model=model,
@@ -656,7 +862,9 @@ class BedrockModel(BaseChatModel):
usage=Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=total_tokens if total_tokens > 0 else input_tokens + output_tokens,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
),
)
response.system_fingerprint = "fp"
@@ -744,14 +952,36 @@ class BedrockModel(BaseChatModel):
metadata = chunk["metadata"]
if "usage" in metadata:
# token usage
usage_data = metadata["usage"]
# Extract prompt caching metrics if available
cache_read_tokens = usage_data.get("cacheReadInputTokens", 0)
cache_creation_tokens = usage_data.get("cacheCreationInputTokens", 0)
# Create prompt_tokens_details if cache metrics are available
prompt_tokens_details = None
if cache_read_tokens > 0 or cache_creation_tokens > 0:
prompt_tokens_details = PromptTokensDetails(
cached_tokens=cache_read_tokens,
audio_tokens=0,
)
# Calculate actual prompt tokens
# Bedrock's totalTokens includes all tokens
# prompt_tokens = totalTokens - outputTokens
total_tokens = usage_data["totalTokens"]
output_tokens = usage_data["outputTokens"]
actual_prompt_tokens = total_tokens - output_tokens
return ChatStreamResponse(
id=message_id,
model=model_id,
choices=[],
usage=Usage(
prompt_tokens=actual_prompt_tokens,
completion_tokens=output_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
),
)
if message:

View File

@@ -110,10 +110,24 @@ class ChatRequest(BaseModel):
extra_body: dict | None = None
class PromptTokensDetails(BaseModel):
"""Details about prompt tokens usage, following OpenAI API format."""
cached_tokens: int = 0
audio_tokens: int = 0
class CompletionTokensDetails(BaseModel):
"""Details about completion tokens usage, following OpenAI API format."""
reasoning_tokens: int = 0
audio_tokens: int = 0
class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: PromptTokensDetails | None = None
completion_tokens_details: CompletionTokensDetails | None = None
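# Illustrative serialized usage for a cache hit (values made up, consistent with the proxy changes):
#   {"prompt_tokens": 1020, "completion_tokens": 200, "total_tokens": 1220,
#    "prompt_tokens_details": {"cached_tokens": 900, "audio_tokens": 0},
#    "completion_tokens_details": null}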
class ChatResponseMessage(BaseModel):

View File

@@ -15,3 +15,4 @@ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240
DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
ENABLE_PROMPT_CACHING = os.environ.get("ENABLE_PROMPT_CACHING", "false").lower() != "false"