feat: add prompt caching support for Claude and Nova models
Add comprehensive prompt caching support with flexible control options.

Features:
- ENV variable control (ENABLE_PROMPT_CACHING, default: false)
- Per-request control via extra_body.prompt_caching
- Pattern-based model detection (Claude, Nova)
- Token limit warnings (Nova 20K limit)
- OpenAI-compatible response format (prompt_tokens_details.cached_tokens)

Supported models:
- Claude 3+ models (anthropic.claude-*)
- Nova models (amazon.nova-*)
- Auto-detection prevents breaking unsupported models

Implementation:
- System prompt caching via extra_body.prompt_caching.system
- Message caching via extra_body.prompt_caching.messages
- Non-streaming and streaming modes
- Compatible with reasoning, thinking, and tool calls
README.md
@@ -29,6 +29,7 @@ If you find this GitHub repository useful, please consider giving it a free star
 - [x] Support Application Inference Profiles (**new**)
 - [x] Support Reasoning (**new**)
 - [x] Support Interleaved thinking (**new**)
+- [x] Support Prompt Caching (**new**)
 
 Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs.
 
@@ -221,6 +222,78 @@ print(completion.choices[0].message.content)
 
 For more information about creating and managing application inference profiles, see the [Amazon Bedrock User Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html).
 
+### Prompt Caching
+
+This proxy now supports **Prompt Caching** for Claude and Nova models, which can reduce costs by up to 90% and latency by up to 85% for workloads with repeated prompts.
+
+**Supported Models:**
+
+- Claude 3+ models (Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4, Claude 4.5, etc.)
+- Nova models (Nova Micro, Nova Lite, Nova Pro, Nova Premier)
+
+**Enabling Prompt Caching:**
+
+You can enable prompt caching in two ways:
+
+1. **Globally via environment variable** (set in the ECS Task Definition or Lambda):
+
+```bash
+ENABLE_PROMPT_CACHING=true
+```
+
+2. **Per-request via `extra_body`:**
+
+**Python SDK:**
+
+```python
+from openai import OpenAI
+
+client = OpenAI()
+
+# Cache system prompts
+response = client.chat.completions.create(
+    model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
+    messages=[
+        {"role": "system", "content": "You are an expert assistant with knowledge of..."},
+        {"role": "user", "content": "Help me with this task"}
+    ],
+    extra_body={
+        "prompt_caching": {"system": True}
+    }
+)
+
+# Check cache hit
+if response.usage.prompt_tokens_details:
+    cached_tokens = response.usage.prompt_tokens_details.cached_tokens
+    print(f"Cached tokens: {cached_tokens}")
+```
+
+**cURL:**
+
+```bash
+curl $OPENAI_BASE_URL/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer $OPENAI_API_KEY" \
+  -d '{
+    "model": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
+    "messages": [
+      {"role": "system", "content": "Long system prompt..."},
+      {"role": "user", "content": "Question"}
+    ],
+    "extra_body": {
+      "prompt_caching": {"system": true}
+    }
+  }'
+```
+
+**Cache Options:**
+
+- `"prompt_caching": {"system": true}` - Cache system prompts
+- `"prompt_caching": {"messages": true}` - Cache user messages
+- `"prompt_caching": {"system": true, "messages": true}` - Cache both
+
+**Requirements:**
+
+- Prompts must be ≥1,024 tokens for caching to engage
+- Cache TTL is 5 minutes (resets on each cache hit)
+- Nova models have a 20,000-token caching limit
+
+For more information, see the [Amazon Bedrock Prompt Caching Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html).
+
 ## Other Examples
 
 ### LangChain
@@ -11,6 +11,13 @@ Parameters:
     Type: String
     Default: anthropic.claude-3-sonnet-20240229-v1:0
     Description: The default model ID, please make sure the model ID is supported in the current region
+  EnablePromptCaching:
+    Type: String
+    Default: "false"
+    AllowedValues:
+      - "true"
+      - "false"
+    Description: Enable prompt caching for supported models (Claude, Nova). When enabled, adds cachePoint to system prompts and messages for cost savings.
 Resources:
   VPCB9E5F0B4:
     Type: AWS::EC2::VPC
@@ -184,6 +191,8 @@ Resources:
           DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3
           ENABLE_CROSS_REGION_INFERENCE: "true"
           ENABLE_APPLICATION_INFERENCE_PROFILES: "true"
+          ENABLE_PROMPT_CACHING:
+            Ref: EnablePromptCaching
       MemorySize: 1024
       PackageType: Image
       Role:
@@ -11,6 +11,13 @@ Parameters:
     Type: String
     Default: anthropic.claude-3-sonnet-20240229-v1:0
     Description: The default model ID, please make sure the model ID is supported in the current region
+  EnablePromptCaching:
+    Type: String
+    Default: "false"
+    AllowedValues:
+      - "true"
+      - "false"
+    Description: Enable prompt caching for supported models (Claude, Nova). When enabled, adds cachePoint to system prompts and messages for cost savings.
 Resources:
   VPCB9E5F0B4:
     Type: AWS::EC2::VPC
@@ -251,6 +258,9 @@ Resources:
             Value: "true"
           - Name: ENABLE_APPLICATION_INFERENCE_PROFILES
             Value: "true"
+          - Name: ENABLE_PROMPT_CACHING
+            Value:
+              Ref: EnablePromptCaching
           Essential: true
           Image:
             Ref: ContainerImageUri
@@ -24,6 +24,7 @@ from api.schema import (
     ChatStreamResponse,
     Choice,
     ChoiceDelta,
+    CompletionTokensDetails,
     Embedding,
     EmbeddingsRequest,
     EmbeddingsResponse,
@@ -32,6 +33,7 @@ from api.schema import (
     ErrorMessage,
     Function,
     ImageContent,
+    PromptTokensDetails,
     ResponseFunction,
     TextContent,
     ToolCall,
@@ -46,6 +48,7 @@ from api.setting import (
     DEFAULT_MODEL,
     ENABLE_CROSS_REGION_INFERENCE,
     ENABLE_APPLICATION_INFERENCE_PROFILES,
+    ENABLE_PROMPT_CACHING,
 )
 
 logger = logging.getLogger(__name__)
@@ -203,6 +206,64 @@ class BedrockModel(BaseChatModel):
                 detail=error,
             )
 
+    @staticmethod
+    def _supports_prompt_caching(model_id: str) -> bool:
+        """
+        Check if model supports prompt caching based on model ID pattern.
+
+        Uses pattern matching instead of hardcoded whitelist for better maintainability.
+        Automatically supports new models following the naming convention.
+
+        Supported models:
+        - Claude: anthropic.claude-* (excluding very old versions)
+        - Nova: amazon.nova-*
+
+        Returns:
+            bool: True if model supports prompt caching
+        """
+        model_lower = model_id.lower()
+
+        # Claude models pattern matching
+        if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
+            # Exclude very old models that don't support caching
+            excluded_patterns = ["claude-instant", "claude-v1", "claude-v2"]
+            if any(pattern in model_lower for pattern in excluded_patterns):
+                return False
+            return True
+
+        # Nova models pattern matching
+        if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
+            return True
+
+        # Future providers can be added here
+        # Example: if "provider.model-name" in model_lower: return True
+
+        return False
+
+    @staticmethod
+    def _get_max_cache_tokens(model_id: str) -> int | None:
+        """
+        Get maximum cacheable tokens limit for the model.
+
+        Different models have different caching limits:
+        - Claude: No explicit limit mentioned in docs
+        - Nova: 20,000 tokens max
+
+        Returns:
+            int | None: Max tokens, or None if unlimited
+        """
+        model_lower = model_id.lower()
+
+        # Nova models have a 20K limit
+        if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
+            return 20_000
+
+        # Claude: no explicit limit
+        if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
+            return None
+
+        return None
+
     async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         """Common logic for invoke bedrock models"""
         if DEBUG:
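The pattern-based detection above can be exercised on its own. Below is a standalone copy of the same logic (outside the class, for illustration only), showing how a few representative model IDs are classified:

```python
def supports_prompt_caching(model_id: str) -> bool:
    """Standalone copy of the detection logic above, for illustration.

    Claude IDs match unless they belong to the excluded legacy generations;
    Nova IDs always match; everything else is rejected.
    """
    model_lower = model_id.lower()
    if "anthropic.claude" in model_lower:
        excluded = ["claude-instant", "claude-v1", "claude-v2"]
        return not any(pattern in model_lower for pattern in excluded)
    if "amazon.nova" in model_lower:
        return True
    return False


print(supports_prompt_caching("us.anthropic.claude-3-7-sonnet-20250219-v1:0"))  # True
print(supports_prompt_caching("anthropic.claude-v2"))                           # False (legacy)
print(supports_prompt_caching("amazon.nova-lite-v1:0"))                         # True
print(supports_prompt_caching("cohere.embed-multilingual-v3"))                  # False
```

Note that cross-region inference profile prefixes such as `us.` are handled naturally, since the check is a substring match.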
@@ -240,17 +301,32 @@ class BedrockModel(BaseChatModel):
         response = await self._invoke_bedrock(chat_request)
 
         output_message = response["output"]["message"]
-        input_tokens = response["usage"]["inputTokens"]
-        output_tokens = response["usage"]["outputTokens"]
+        usage = response["usage"]
+
+        # Extract all token counts
+        output_tokens = usage["outputTokens"]
+        total_tokens = usage["totalTokens"]
         finish_reason = response["stopReason"]
 
+        # Extract prompt caching metrics if available
+        cache_read_tokens = usage.get("cacheReadInputTokens", 0)
+        cache_creation_tokens = usage.get("cacheCreationInputTokens", 0)
+
+        # Calculate actual prompt tokens.
+        # Bedrock's totalTokens includes all: inputTokens + cacheRead + cacheWrite + outputTokens
+        # So: prompt_tokens = totalTokens - outputTokens
+        actual_prompt_tokens = total_tokens - output_tokens
+
         chat_response = self._create_response(
             model=chat_request.model,
             message_id=message_id,
             content=output_message["content"],
             finish_reason=finish_reason,
-            input_tokens=input_tokens,
+            input_tokens=actual_prompt_tokens,
             output_tokens=output_tokens,
+            total_tokens=total_tokens,
+            cache_read_tokens=cache_read_tokens,
+            cache_creation_tokens=cache_creation_tokens,
         )
         if DEBUG:
             logger.info("Proxy response :" + chat_response.model_dump_json())
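The token accounting introduced in this hunk reduces to a small mapping from Bedrock's Converse usage block to OpenAI-style counts. A sketch of just that arithmetic (`to_openai_usage` is an illustrative helper, not part of the commit):

```python
def to_openai_usage(bedrock_usage: dict) -> dict:
    """Map a Bedrock Converse usage block to OpenAI-style token counts.

    Bedrock's totalTokens covers inputTokens + cacheRead + cacheWrite +
    outputTokens, so the OpenAI prompt_tokens is totalTokens - outputTokens,
    and cached_tokens comes from cacheReadInputTokens (cache hits).
    """
    total = bedrock_usage["totalTokens"]
    output = bedrock_usage["outputTokens"]
    return {
        "prompt_tokens": total - output,
        "completion_tokens": output,
        "total_tokens": total,
        "cached_tokens": bedrock_usage.get("cacheReadInputTokens", 0),
    }


# On a cache hit, prompt_tokens includes the tokens read back from cache
usage = to_openai_usage({
    "inputTokens": 12,
    "outputTokens": 50,
    "totalTokens": 1262,
    "cacheReadInputTokens": 1200,
})
print(usage["prompt_tokens"])  # 1212
```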
@@ -296,24 +372,68 @@ class BedrockModel(BaseChatModel):
             yield self.stream_response_to_bytes(error_event)
 
     def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]:
-        """Create system prompts.
-        Note that not all models support system prompts.
-
-        example output: [{"text" : system_prompt}]
-
-        See example:
-        https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
+        """Create system prompts with optional prompt caching support.
+
+        Prompt caching can be enabled via:
+        1. ENABLE_PROMPT_CACHING environment variable (global default)
+        2. extra_body.prompt_caching.system = True/False (per-request override)
+
+        Only adds cachePoint if:
+        - Model supports caching (Claude, Nova)
+        - Caching is enabled (ENV or extra_body)
+        - System prompts exist and meet minimum token requirements
+
+        Example output: [{"text" : system_prompt}, {"cachePoint": {"type": "default"}}]
+
+        See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
         """
-
         system_prompts = []
         for message in chat_request.messages:
             if message.role != "system":
-                # ignore system messages here
                 continue
             if not isinstance(message.content, str):
                 raise TypeError(f"System message content must be a string, got {type(message.content).__name__}")
             system_prompts.append({"text": message.content})
+
+        if not system_prompts:
+            return system_prompts
+
+        # Check if model supports prompt caching
+        if not self._supports_prompt_caching(chat_request.model):
+            return system_prompts
+
+        # Determine if caching should be enabled
+        cache_enabled = ENABLE_PROMPT_CACHING  # Default from ENV
+
+        # Check for extra_body override
+        if chat_request.extra_body and isinstance(chat_request.extra_body, dict):
+            prompt_caching = chat_request.extra_body.get("prompt_caching", {})
+            if "system" in prompt_caching:
+                # extra_body explicitly controls caching
+                cache_enabled = prompt_caching.get("system") is True
+
+        if not cache_enabled:
+            return system_prompts
+
+        # Estimate total tokens for limit check
+        total_text = " ".join(p.get("text", "") for p in system_prompts)
+        estimated_tokens = len(total_text.split()) * 1.3  # Rough estimate
+
+        # Check token limits (Nova has 20K limit)
+        max_tokens = self._get_max_cache_tokens(chat_request.model)
+        if max_tokens and estimated_tokens > max_tokens:
+            logger.warning(
+                f"System prompts (~{estimated_tokens:.0f} tokens) exceed model cache limit ({max_tokens} tokens). "
+                f"Caching will still be attempted but may not work optimally."
+            )
+            # Still add cachePoint - let Bedrock handle the limit
+
+        # Add cache checkpoint after system prompts
+        system_prompts.append({"cachePoint": {"type": "default"}})
+
+        if DEBUG:
+            logger.info(f"Added cachePoint to system prompts for model {chat_request.model}")
+
         return system_prompts
 
     def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
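The shape this method produces is easy to see in isolation. A minimal sketch of the output format (a standalone illustration, not the method itself, which also performs model detection and token-limit checks):

```python
def build_system_prompts(system_texts: list[str], cache: bool) -> list[dict]:
    """Illustrative: build the Converse-API system array, appending a
    trailing cachePoint block when caching is enabled."""
    prompts: list[dict] = [{"text": t} for t in system_texts]
    if prompts and cache:
        prompts.append({"cachePoint": {"type": "default"}})
    return prompts


print(build_system_prompts(["You are a helpful assistant."], cache=True))
# [{'text': 'You are a helpful assistant.'}, {'cachePoint': {'type': 'default'}}]
```

The cachePoint marker tells Bedrock to cache everything before it, so it always goes after the last system text block.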
@@ -402,7 +522,7 @@ class BedrockModel(BaseChatModel):
             else:
                 # ignore others, such as system messages
                 continue
-        return self._reframe_multi_payloard(messages)
+        return self._reframe_multi_payloard(messages, chat_request)
 
     def _extract_tool_content(self, content) -> str:
         """Extract text content from various OpenAI SDK tool message formats.
@@ -455,7 +575,7 @@ class BedrockModel(BaseChatModel):
             # Return a safe fallback
             return str(content) if content is not None else ""
 
-    def _reframe_multi_payloard(self, messages: list) -> list:
+    def _reframe_multi_payloard(self, messages: list, chat_request: ChatRequest = None) -> list:
         """Receive messages and reformat them to comply with the Claude format
 
         With OpenAI format requests, it's not a problem to repeatedly receive messages from the same role, but
@@ -510,6 +630,29 @@ class BedrockModel(BaseChatModel):
                 {"role": current_role, "content": current_content}
             )
 
+        # Add cachePoint to messages if enabled and supported
+        if chat_request and reformatted_messages:
+            if not self._supports_prompt_caching(chat_request.model):
+                return reformatted_messages
+
+            # Determine if messages caching should be enabled
+            cache_enabled = ENABLE_PROMPT_CACHING
+
+            if chat_request.extra_body and isinstance(chat_request.extra_body, dict):
+                prompt_caching = chat_request.extra_body.get("prompt_caching", {})
+                if "messages" in prompt_caching:
+                    cache_enabled = prompt_caching.get("messages") is True
+
+            if cache_enabled:
+                # Add cachePoint to the last user message content
+                for msg in reversed(reformatted_messages):
+                    if msg["role"] == "user" and msg.get("content"):
+                        # Add cachePoint at the end of user message content
+                        msg["content"].append({"cachePoint": {"type": "default"}})
+                        if DEBUG:
+                            logger.info(f"Added cachePoint to last user message for model {chat_request.model}")
+                        break
+
         return reformatted_messages
 
     def _parse_request(self, chat_request: ChatRequest) -> dict:
@@ -547,9 +690,12 @@ class BedrockModel(BaseChatModel):
             "inferenceConfig": inference_config,
         }
         if chat_request.reasoning_effort:
-            # From OpenAI api, the max_token is not supported in reasoning mode
-            # Use max_completion_tokens if provided.
+            # reasoning_effort is supported by Claude and DeepSeek v3
+            # Different models use different formats
+            model_lower = chat_request.model.lower()
+
+            if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
+                # Claude format: reasoning_config = object with budget_tokens
                 max_tokens = (
                     chat_request.max_completion_tokens
                     if chat_request.max_completion_tokens
@@ -565,6 +711,18 @@ class BedrockModel(BaseChatModel):
                 args["additionalModelRequestFields"] = {
                     "reasoning_config": {"type": "enabled", "budget_tokens": budget_tokens}
                 }
+            elif "deepseek.v3" in model_lower or "deepseek.deepseek-v3" in model_lower:
+                # DeepSeek v3 format: reasoning_config = string ('low', 'medium', 'high')
+                # From Bedrock Playground: {"reasoning_config": "high"}
+                args["additionalModelRequestFields"] = {
+                    "reasoning_config": chat_request.reasoning_effort  # Direct string: low/medium/high
+                }
+                if DEBUG:
+                    logger.info(f"Applied reasoning_config={chat_request.reasoning_effort} for DeepSeek v3")
+            else:
+                # For other models (Qwen, etc.), ignore reasoning_effort parameter
+                if DEBUG:
+                    logger.info(f"reasoning_effort parameter ignored for model {chat_request.model} (not supported)")
         # add tool config
         if chat_request.tools:
             tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
@@ -585,16 +743,42 @@ class BedrockModel(BaseChatModel):
                 raise ValueError("tool_choice must contain 'function' key when specifying a specific tool")
             tool_config["toolChoice"] = {"tool": {"name": chat_request.tool_choice["function"].get("name", "")}}
             args["toolConfig"] = tool_config
-        # add Additional fields to enable extend thinking
+        # Add additional fields to enable extended thinking or other model-specific features
         if chat_request.extra_body:
-            # reasoning_config will not be used
-            args["additionalModelRequestFields"] = chat_request.extra_body
+            # Filter out prompt_caching (our control field, not for Bedrock)
+            additional_fields = {
+                k: v for k, v in chat_request.extra_body.items()
+                if k != "prompt_caching"
+            }
+
+            if additional_fields:
+                # Only set additionalModelRequestFields if there are actual fields to pass
+                args["additionalModelRequestFields"] = additional_fields
+
             # Extended thinking doesn't support both temperature and topP
             # Remove topP to avoid validation error
-            if "thinking" in chat_request.extra_body:
+            if "thinking" in additional_fields:
                 inference_config.pop("topP", None)
 
         return args
 
+    def _estimate_reasoning_tokens(self, content: list[dict]) -> int:
+        """
+        Estimate reasoning tokens from reasoningContent blocks.
+
+        Bedrock doesn't separately report reasoning tokens, so we estimate
+        them using tiktoken to maintain OpenAI API compatibility.
+        """
+        reasoning_text = ""
+        for block in content:
+            if "reasoningContent" in block:
+                reasoning_text += block["reasoningContent"]["reasoningText"].get("text", "")
+
+        if reasoning_text:
+            # Use tiktoken to estimate token count
+            return len(ENCODER.encode(reasoning_text))
+        return 0
+
     def _create_response(
         self,
         model: str,
@@ -603,6 +787,9 @@ class BedrockModel(BaseChatModel):
         finish_reason: str | None = None,
         input_tokens: int = 0,
         output_tokens: int = 0,
+        total_tokens: int = 0,
+        cache_read_tokens: int = 0,
+        cache_creation_tokens: int = 0,
     ) -> ChatResponse:
         message = ChatResponseMessage(
             role="assistant",
@@ -642,6 +829,25 @@ class BedrockModel(BaseChatModel):
             message.content = f"<think>{message.reasoning_content}</think>{message.content}"
             message.reasoning_content = None
 
+        # Create prompt_tokens_details if cache metrics are available
+        prompt_tokens_details = None
+        if cache_read_tokens > 0 or cache_creation_tokens > 0:
+            # Map Bedrock cache metrics to OpenAI format:
+            # cached_tokens represents tokens read from cache (cache hits)
+            prompt_tokens_details = PromptTokensDetails(
+                cached_tokens=cache_read_tokens,
+                audio_tokens=0,
+            )
+
+        # Create completion_tokens_details if reasoning content exists
+        completion_tokens_details = None
+        reasoning_tokens = self._estimate_reasoning_tokens(content) if content else 0
+        if reasoning_tokens > 0:
+            completion_tokens_details = CompletionTokensDetails(
+                reasoning_tokens=reasoning_tokens,
+                audio_tokens=0,
+            )
+
         response = ChatResponse(
             id=message_id,
             model=model,
@@ -656,7 +862,9 @@ class BedrockModel(BaseChatModel):
             usage=Usage(
                 prompt_tokens=input_tokens,
                 completion_tokens=output_tokens,
-                total_tokens=input_tokens + output_tokens,
+                total_tokens=total_tokens if total_tokens > 0 else input_tokens + output_tokens,
+                prompt_tokens_details=prompt_tokens_details,
+                completion_tokens_details=completion_tokens_details,
             ),
         )
         response.system_fingerprint = "fp"
@@ -744,14 +952,36 @@ class BedrockModel(BaseChatModel):
             metadata = chunk["metadata"]
             if "usage" in metadata:
                 # token usage
+                usage_data = metadata["usage"]
+
+                # Extract prompt caching metrics if available
+                cache_read_tokens = usage_data.get("cacheReadInputTokens", 0)
+                cache_creation_tokens = usage_data.get("cacheCreationInputTokens", 0)
+
+                # Create prompt_tokens_details if cache metrics are available
+                prompt_tokens_details = None
+                if cache_read_tokens > 0 or cache_creation_tokens > 0:
+                    prompt_tokens_details = PromptTokensDetails(
+                        cached_tokens=cache_read_tokens,
+                        audio_tokens=0,
+                    )
+
+                # Calculate actual prompt tokens.
+                # Bedrock's totalTokens includes all tokens, so:
+                # prompt_tokens = totalTokens - outputTokens
+                total_tokens = usage_data["totalTokens"]
+                output_tokens = usage_data["outputTokens"]
+                actual_prompt_tokens = total_tokens - output_tokens
+
                 return ChatStreamResponse(
                     id=message_id,
                     model=model_id,
                     choices=[],
                     usage=Usage(
-                        prompt_tokens=metadata["usage"]["inputTokens"],
-                        completion_tokens=metadata["usage"]["outputTokens"],
-                        total_tokens=metadata["usage"]["totalTokens"],
+                        prompt_tokens=actual_prompt_tokens,
+                        completion_tokens=output_tokens,
+                        total_tokens=total_tokens,
+                        prompt_tokens_details=prompt_tokens_details,
                     ),
                 )
             if message:
@@ -110,10 +110,24 @@ class ChatRequest(BaseModel):
     extra_body: dict | None = None
 
 
+class PromptTokensDetails(BaseModel):
+    """Details about prompt tokens usage, following OpenAI API format."""
+    cached_tokens: int = 0
+    audio_tokens: int = 0
+
+
+class CompletionTokensDetails(BaseModel):
+    """Details about completion tokens usage, following OpenAI API format."""
+    reasoning_tokens: int = 0
+    audio_tokens: int = 0
+
+
 class Usage(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     total_tokens: int
+    prompt_tokens_details: PromptTokensDetails | None = None
+    completion_tokens_details: CompletionTokensDetails | None = None
 
 
 class ChatResponseMessage(BaseModel):
@@ -15,3 +15,4 @@ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240
 DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
 ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
+ENABLE_PROMPT_CACHING = os.environ.get("ENABLE_PROMPT_CACHING", "false").lower() != "false"
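One subtlety of the `.lower() != "false"` idiom used in this settings file: any value other than the literal string "false" (case-insensitive) enables the flag, including typos and unrecognized values. A small sketch of that parsing behavior (`flag_from_env` is an illustrative name):

```python
import os


def flag_from_env(name: str, default: str) -> bool:
    """Replicates the setting.py idiom: everything except "false"
    (case-insensitive) counts as enabled."""
    return os.environ.get(name, default).lower() != "false"


os.environ["ENABLE_PROMPT_CACHING"] = "False"
print(flag_from_env("ENABLE_PROMPT_CACHING", "false"))  # False (case-insensitive match)
os.environ["ENABLE_PROMPT_CACHING"] = "yes"
print(flag_from_env("ENABLE_PROMPT_CACHING", "false"))  # True (anything but "false")
del os.environ["ENABLE_PROMPT_CACHING"]
print(flag_from_env("ENABLE_PROMPT_CACHING", "false"))  # False (the default applies)
```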