feat: add prompt caching support for Claude and Nova models
Add comprehensive prompt caching support with flexible control options.

Features:
- ENV variable control (ENABLE_PROMPT_CACHING, default: false)
- Per-request control via extra_body.prompt_caching
- Pattern-based model detection (Claude, Nova)
- Token limit warnings (Nova 20K limit)
- OpenAI-compatible response format (prompt_tokens_details.cached_tokens)

Supported models:
- Claude 3+ models (anthropic.claude-*)
- Nova models (amazon.nova-*)
- Auto-detection prevents breaking unsupported models

Implementation:
- System prompt caching via extra_body.prompt_caching.system
- Message caching via extra_body.prompt_caching.messages
- Works in both non-streaming and streaming modes
- Compatible with reasoning, thinking, and tool calls
@@ -29,6 +29,7 @@ If you find this GitHub repository useful, please consider giving it a free star
- [x] Support Application Inference Profiles (**new**)
- [x] Support Reasoning (**new**)
- [x] Support Interleaved thinking (**new**)
- [x] Support Prompt Caching (**new**)

Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs.
@@ -221,6 +222,78 @@ print(completion.choices[0].message.content)

For more information about creating and managing application inference profiles, see the [Amazon Bedrock User Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html).
### Prompt Caching

This proxy now supports **Prompt Caching** for Claude and Nova models, which can reduce costs by up to 90% and latency by up to 85% for workloads with repeated prompts.

**Supported Models:**

- Claude 3+ models (Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4, Claude 4.5, etc.)
- Nova models (Nova Micro, Nova Lite, Nova Pro, Nova Premier)
**Enabling Prompt Caching:**

You can enable prompt caching in two ways:

1. **Globally via Environment Variable** (set in ECS Task Definition or Lambda):

   ```bash
   ENABLE_PROMPT_CACHING=true
   ```

2. **Per-request via `extra_body`**:
**Python SDK:**

```python
from openai import OpenAI

client = OpenAI()

# Cache system prompts
response = client.chat.completions.create(
    model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    messages=[
        {"role": "system", "content": "You are an expert assistant with knowledge of..."},
        {"role": "user", "content": "Help me with this task"}
    ],
    extra_body={
        "prompt_caching": {"system": True}
    }
)

# Check cache hit
if response.usage.prompt_tokens_details:
    cached_tokens = response.usage.prompt_tokens_details.cached_tokens
    print(f"Cached tokens: {cached_tokens}")
```
**cURL:**

```bash
curl $OPENAI_BASE_URL/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $OPENAI_API_KEY" \
  -d '{
    "model": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    "messages": [
      {"role": "system", "content": "Long system prompt..."},
      {"role": "user", "content": "Question"}
    ],
    "extra_body": {
      "prompt_caching": {"system": true}
    }
  }'
```
**Cache Options:**

- `"prompt_caching": {"system": true}` - Cache system prompts
- `"prompt_caching": {"messages": true}` - Cache user messages
- `"prompt_caching": {"system": true, "messages": true}` - Cache both (see the combined example below)
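For example, to cache both the system prompt and the conversation prefix in a single request (a minimal sketch; the model ID and prompt text are placeholders):

```python
from openai import OpenAI

client = OpenAI()

# Cache both the system prompt and the message history
response = client.chat.completions.create(
    model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    messages=[
        {"role": "system", "content": "Long, stable system prompt..."},
        {"role": "user", "content": "First question about the document"},
    ],
    extra_body={
        "prompt_caching": {"system": True, "messages": True}
    },
)

# On a follow-up request that repeats the same prefix, cached_tokens should be non-zero
details = response.usage.prompt_tokens_details
if details:
    print(f"Cached tokens: {details.cached_tokens}")
```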
**Requirements:**

- Prompts must be ≥1,024 tokens for caching to take effect (a quick token-count check is sketched below)
- Cache TTL is 5 minutes (reset on each cache hit)
- Nova models have a 20,000-token caching limit
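Because caching only kicks in above that minimum, it can help to estimate a prompt's token count before relying on a cache hit. A rough sketch using `tiktoken` (the `cl100k_base` encoding is an approximation; Claude and Nova tokenizers count somewhat differently):

```python
import tiktoken

# Approximate count only - Bedrock models do not use the cl100k_base tokenizer
encoder = tiktoken.get_encoding("cl100k_base")

system_prompt = "Long, stable system prompt..." * 200  # placeholder text

token_count = len(encoder.encode(system_prompt))
print(f"~{token_count} tokens")

if token_count < 1024:
    print("Below the 1,024-token minimum - the cache point will likely have no effect")
elif token_count > 20_000:
    print("Above the 20,000-token Nova caching limit")
```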
For more information, see the [Amazon Bedrock Prompt Caching Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html).

## Other Examples

### LangChain
@@ -11,6 +11,13 @@ Parameters:
    Type: String
    Default: anthropic.claude-3-sonnet-20240229-v1:0
    Description: The default model ID, please make sure the model ID is supported in the current region
  EnablePromptCaching:
    Type: String
    Default: "false"
    AllowedValues:
      - "true"
      - "false"
    Description: Enable prompt caching for supported models (Claude, Nova). When enabled, adds cachePoint to system prompts and messages for cost savings.
Resources:
  VPCB9E5F0B4:
    Type: AWS::EC2::VPC
@@ -184,6 +191,8 @@ Resources:
          DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3
          ENABLE_CROSS_REGION_INFERENCE: "true"
          ENABLE_APPLICATION_INFERENCE_PROFILES: "true"
          ENABLE_PROMPT_CACHING:
            Ref: EnablePromptCaching
      MemorySize: 1024
      PackageType: Image
      Role:
@@ -11,6 +11,13 @@ Parameters:
    Type: String
    Default: anthropic.claude-3-sonnet-20240229-v1:0
    Description: The default model ID, please make sure the model ID is supported in the current region
  EnablePromptCaching:
    Type: String
    Default: "false"
    AllowedValues:
      - "true"
      - "false"
    Description: Enable prompt caching for supported models (Claude, Nova). When enabled, adds cachePoint to system prompts and messages for cost savings.
Resources:
  VPCB9E5F0B4:
    Type: AWS::EC2::VPC
@@ -251,6 +258,9 @@ Resources:
              Value: "true"
            - Name: ENABLE_APPLICATION_INFERENCE_PROFILES
              Value: "true"
            - Name: ENABLE_PROMPT_CACHING
              Value:
                Ref: EnablePromptCaching
          Essential: true
          Image:
            Ref: ContainerImageUri
@@ -24,6 +24,7 @@ from api.schema import (
    ChatStreamResponse,
    Choice,
    ChoiceDelta,
    CompletionTokensDetails,
    Embedding,
    EmbeddingsRequest,
    EmbeddingsResponse,
@@ -32,6 +33,7 @@ from api.schema import (
    ErrorMessage,
    Function,
    ImageContent,
    PromptTokensDetails,
    ResponseFunction,
    TextContent,
    ToolCall,
@@ -46,6 +48,7 @@ from api.setting import (
    DEFAULT_MODEL,
    ENABLE_CROSS_REGION_INFERENCE,
    ENABLE_APPLICATION_INFERENCE_PROFILES,
    ENABLE_PROMPT_CACHING,
)

logger = logging.getLogger(__name__)
@@ -203,6 +206,64 @@ class BedrockModel(BaseChatModel):
                detail=error,
            )

    @staticmethod
    def _supports_prompt_caching(model_id: str) -> bool:
        """
        Check if model supports prompt caching based on model ID pattern.

        Uses pattern matching instead of hardcoded whitelist for better maintainability.
        Automatically supports new models following the naming convention.

        Supported models:
        - Claude: anthropic.claude-* (excluding very old versions)
        - Nova: amazon.nova-*

        Returns:
            bool: True if model supports prompt caching
        """
        model_lower = model_id.lower()

        # Claude models pattern matching
        if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
            # Exclude very old models that don't support caching
            excluded_patterns = ["claude-instant", "claude-v1", "claude-v2"]
            if any(pattern in model_lower for pattern in excluded_patterns):
                return False
            return True

        # Nova models pattern matching
        if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
            return True

        # Future providers can be added here
        # Example: if "provider.model-name" in model_lower: return True

        return False
    @staticmethod
    def _get_max_cache_tokens(model_id: str) -> int | None:
        """
        Get maximum cacheable tokens limit for the model.

        Different models have different caching limits:
        - Claude: No explicit limit mentioned in docs
        - Nova: 20,000 tokens max

        Returns:
            int | None: Max tokens, or None if unlimited
        """
        model_lower = model_id.lower()

        # Nova models have 20K limit
        if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
            return 20_000

        # Claude: No explicit limit
        if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
            return None

        return None

    async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
        """Common logic for invoke bedrock models"""
        if DEBUG:
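As a quick illustration of the pattern matching above, a few hypothetical checks (the module path and model IDs are examples only; cross-region profile IDs such as `us.anthropic.claude-*` still match because the check is a plain substring match):

```python
# Illustrative only - assumes BedrockModel is importable from the proxy package
from api.models.bedrock import BedrockModel

assert BedrockModel._supports_prompt_caching("us.anthropic.claude-3-7-sonnet-20250219-v1:0")
assert BedrockModel._supports_prompt_caching("amazon.nova-pro-v1:0")
assert not BedrockModel._supports_prompt_caching("anthropic.claude-v2:1")  # excluded old model
assert not BedrockModel._supports_prompt_caching("meta.llama3-1-70b-instruct-v1:0")

assert BedrockModel._get_max_cache_tokens("amazon.nova-lite-v1:0") == 20_000
assert BedrockModel._get_max_cache_tokens("us.anthropic.claude-3-7-sonnet-20250219-v1:0") is None
```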
@@ -240,17 +301,32 @@ class BedrockModel(BaseChatModel):
        response = await self._invoke_bedrock(chat_request)

        output_message = response["output"]["message"]
-       input_tokens = response["usage"]["inputTokens"]
-       output_tokens = response["usage"]["outputTokens"]
+       usage = response["usage"]
+
+       # Extract all token counts
+       output_tokens = usage["outputTokens"]
+       total_tokens = usage["totalTokens"]
        finish_reason = response["stopReason"]

+       # Extract prompt caching metrics if available
+       cache_read_tokens = usage.get("cacheReadInputTokens", 0)
+       cache_creation_tokens = usage.get("cacheCreationInputTokens", 0)
+
+       # Calculate actual prompt tokens
+       # Bedrock's totalTokens includes all: inputTokens + cacheRead + cacheWrite + outputTokens
+       # So: prompt_tokens = totalTokens - outputTokens
+       actual_prompt_tokens = total_tokens - output_tokens
+
        chat_response = self._create_response(
            model=chat_request.model,
            message_id=message_id,
            content=output_message["content"],
            finish_reason=finish_reason,
-           input_tokens=input_tokens,
+           input_tokens=actual_prompt_tokens,
            output_tokens=output_tokens,
+           total_tokens=total_tokens,
+           cache_read_tokens=cache_read_tokens,
+           cache_creation_tokens=cache_creation_tokens,
        )
        if DEBUG:
            logger.info("Proxy response :" + chat_response.model_dump_json())
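To make the token accounting concrete, a worked example with made-up numbers (not taken from a real response):

```python
# Hypothetical usage block from a Bedrock Converse response
usage = {
    "inputTokens": 42,              # non-cached prompt tokens
    "cacheReadInputTokens": 1800,   # tokens served from cache
    "cacheCreationInputTokens": 0,  # tokens written to cache on this call
    "outputTokens": 250,
    "totalTokens": 2092,            # 42 + 1800 + 0 + 250
}

# What the proxy reports back in OpenAI format:
prompt_tokens = usage["totalTokens"] - usage["outputTokens"]  # 1842
cached_tokens = usage.get("cacheReadInputTokens", 0)          # 1800
print(prompt_tokens, cached_tokens)
```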
@@ -296,24 +372,68 @@ class BedrockModel(BaseChatModel):
            yield self.stream_response_to_bytes(error_event)

    def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]:
-       """Create system prompts.
-       Note that not all models support system prompts.
+       """Create system prompts with optional prompt caching support.

-       example output: [{"text" : system_prompt}]
+       Prompt caching can be enabled via:
+       1. ENABLE_PROMPT_CACHING environment variable (global default)
+       2. extra_body.prompt_caching.system = True/False (per-request override)

-       See example:
-       https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
+       Only adds cachePoint if:
+       - Model supports caching (Claude, Nova)
+       - Caching is enabled (ENV or extra_body)
+       - System prompts exist and meet minimum token requirements
+
+       Example output: [{"text" : system_prompt}, {"cachePoint": {"type": "default"}}]
+
+       See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
        """

        system_prompts = []
        for message in chat_request.messages:
            if message.role != "system":
                # ignore system messages here
                continue
            if not isinstance(message.content, str):
                raise TypeError(f"System message content must be a string, got {type(message.content).__name__}")
            system_prompts.append({"text": message.content})

+       if not system_prompts:
+           return system_prompts
+
+       # Check if model supports prompt caching
+       if not self._supports_prompt_caching(chat_request.model):
+           return system_prompts
+
+       # Determine if caching should be enabled
+       cache_enabled = ENABLE_PROMPT_CACHING  # Default from ENV
+
+       # Check for extra_body override
+       if chat_request.extra_body and isinstance(chat_request.extra_body, dict):
+           prompt_caching = chat_request.extra_body.get("prompt_caching", {})
+           if "system" in prompt_caching:
+               # extra_body explicitly controls caching
+               cache_enabled = prompt_caching.get("system") is True
+
+       if not cache_enabled:
+           return system_prompts
+
+       # Estimate total tokens for limit check
+       total_text = " ".join(p.get("text", "") for p in system_prompts)
+       estimated_tokens = len(total_text.split()) * 1.3  # Rough estimate
+
+       # Check token limits (Nova has 20K limit)
+       max_tokens = self._get_max_cache_tokens(chat_request.model)
+       if max_tokens and estimated_tokens > max_tokens:
+           logger.warning(
+               f"System prompts (~{estimated_tokens:.0f} tokens) exceed model cache limit ({max_tokens} tokens). "
+               f"Caching will still be attempted but may not work optimally."
+           )
+           # Still add cachePoint - let Bedrock handle the limit
+
+       # Add cache checkpoint after system prompts
+       system_prompts.append({"cachePoint": {"type": "default"}})
+
+       if DEBUG:
+           logger.info(f"Added cachePoint to system prompts for model {chat_request.model}")
+
        return system_prompts

    def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
@@ -402,7 +522,7 @@ class BedrockModel(BaseChatModel):
            else:
                # ignore others, such as system messages
                continue
-       return self._reframe_multi_payloard(messages)
+       return self._reframe_multi_payloard(messages, chat_request)

    def _extract_tool_content(self, content) -> str:
        """Extract text content from various OpenAI SDK tool message formats.
@@ -455,7 +575,7 @@ class BedrockModel(BaseChatModel):
            # Return a safe fallback
            return str(content) if content is not None else ""

-   def _reframe_multi_payloard(self, messages: list) -> list:
+   def _reframe_multi_payloard(self, messages: list, chat_request: ChatRequest = None) -> list:
        """Receive messages and reformat them to comply with the Claude format

        With OpenAI format requests, it's not a problem to repeatedly receive messages from the same role, but
@@ -510,6 +630,29 @@ class BedrockModel(BaseChatModel):
                {"role": current_role, "content": current_content}
            )

        # Add cachePoint to messages if enabled and supported
        if chat_request and reformatted_messages:
            if not self._supports_prompt_caching(chat_request.model):
                return reformatted_messages

            # Determine if messages caching should be enabled
            cache_enabled = ENABLE_PROMPT_CACHING

            if chat_request.extra_body and isinstance(chat_request.extra_body, dict):
                prompt_caching = chat_request.extra_body.get("prompt_caching", {})
                if "messages" in prompt_caching:
                    cache_enabled = prompt_caching.get("messages") is True

            if cache_enabled:
                # Add cachePoint to the last user message content
                for msg in reversed(reformatted_messages):
                    if msg["role"] == "user" and msg.get("content"):
                        # Add cachePoint at the end of user message content
                        msg["content"].append({"cachePoint": {"type": "default"}})
                        if DEBUG:
                            logger.info(f"Added cachePoint to last user message for model {chat_request.model}")
                        break

        return reformatted_messages

    def _parse_request(self, chat_request: ChatRequest) -> dict:
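For reference, after this step the reformatted messages handed to the Converse API end with a cache point on the last user turn. A hypothetical payload (content values are placeholders):

```python
reformatted_messages = [
    {
        "role": "user",
        "content": [
            {"text": "Here is a long document: ..."},
            {"text": "Question about the document"},
            {"cachePoint": {"type": "default"}},  # appended to the last user message
        ],
    },
]
```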
@@ -547,9 +690,12 @@ class BedrockModel(BaseChatModel):
            "inferenceConfig": inference_config,
        }
        if chat_request.reasoning_effort:
            # From OpenAI api, the max_token is not supported in reasoning mode
            # Use max_completion_tokens if provided.
            # reasoning_effort is supported by Claude and DeepSeek v3
            # Different models use different formats
            model_lower = chat_request.model.lower()

            if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
                # Claude format: reasoning_config = object with budget_tokens
                max_tokens = (
                    chat_request.max_completion_tokens
                    if chat_request.max_completion_tokens
@@ -565,6 +711,18 @@
                args["additionalModelRequestFields"] = {
                    "reasoning_config": {"type": "enabled", "budget_tokens": budget_tokens}
                }
            elif "deepseek.v3" in model_lower or "deepseek.deepseek-v3" in model_lower:
                # DeepSeek v3 format: reasoning_config = string ('low', 'medium', 'high')
                # From Bedrock Playground: {"reasoning_config": "high"}
                args["additionalModelRequestFields"] = {
                    "reasoning_config": chat_request.reasoning_effort  # Direct string: low/medium/high
                }
                if DEBUG:
                    logger.info(f"Applied reasoning_config={chat_request.reasoning_effort} for DeepSeek v3")
            else:
                # For other models (Qwen, etc.), ignore reasoning_effort parameter
                if DEBUG:
                    logger.info(f"reasoning_effort parameter ignored for model {chat_request.model} (not supported)")
        # add tool config
        if chat_request.tools:
            tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
@@ -585,16 +743,42 @@ class BedrockModel(BaseChatModel):
                    raise ValueError("tool_choice must contain 'function' key when specifying a specific tool")
                tool_config["toolChoice"] = {"tool": {"name": chat_request.tool_choice["function"].get("name", "")}}
            args["toolConfig"] = tool_config
-       # add Additional fields to enable extend thinking
+       # Add additional fields to enable extend thinking or other model-specific features
        if chat_request.extra_body:
-           # reasoning_config will not be used
-           args["additionalModelRequestFields"] = chat_request.extra_body
+           # Filter out prompt_caching (our control field, not for Bedrock)
+           additional_fields = {
+               k: v for k, v in chat_request.extra_body.items()
+               if k != "prompt_caching"
+           }
+
+           if additional_fields:
+               # Only set additionalModelRequestFields if there are actual fields to pass
+               args["additionalModelRequestFields"] = additional_fields
+
            # Extended thinking doesn't support both temperature and topP
            # Remove topP to avoid validation error
-           if "thinking" in chat_request.extra_body:
+           if "thinking" in additional_fields:
                inference_config.pop("topP", None)

        return args
    def _estimate_reasoning_tokens(self, content: list[dict]) -> int:
        """
        Estimate reasoning tokens from reasoningContent blocks.

        Bedrock doesn't separately report reasoning tokens, so we estimate
        them using tiktoken to maintain OpenAI API compatibility.
        """
        reasoning_text = ""
        for block in content:
            if "reasoningContent" in block:
                reasoning_text += block["reasoningContent"]["reasoningText"].get("text", "")

        if reasoning_text:
            # Use tiktoken to estimate token count
            return len(ENCODER.encode(reasoning_text))
        return 0

    def _create_response(
        self,
        model: str,
@@ -603,6 +787,9 @@ class BedrockModel(BaseChatModel):
        finish_reason: str | None = None,
        input_tokens: int = 0,
        output_tokens: int = 0,
        total_tokens: int = 0,
        cache_read_tokens: int = 0,
        cache_creation_tokens: int = 0,
    ) -> ChatResponse:
        message = ChatResponseMessage(
            role="assistant",
@@ -642,6 +829,25 @@ class BedrockModel(BaseChatModel):
            message.content = f"<think>{message.reasoning_content}</think>{message.content}"
            message.reasoning_content = None

        # Create prompt_tokens_details if cache metrics are available
        prompt_tokens_details = None
        if cache_read_tokens > 0 or cache_creation_tokens > 0:
            # Map Bedrock cache metrics to OpenAI format
            # cached_tokens represents tokens read from cache (cache hits)
            prompt_tokens_details = PromptTokensDetails(
                cached_tokens=cache_read_tokens,
                audio_tokens=0,
            )

        # Create completion_tokens_details if reasoning content exists
        completion_tokens_details = None
        reasoning_tokens = self._estimate_reasoning_tokens(content) if content else 0
        if reasoning_tokens > 0:
            completion_tokens_details = CompletionTokensDetails(
                reasoning_tokens=reasoning_tokens,
                audio_tokens=0,
            )

        response = ChatResponse(
            id=message_id,
            model=model,
@@ -656,7 +862,9 @@ class BedrockModel(BaseChatModel):
            usage=Usage(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
-               total_tokens=input_tokens + output_tokens,
+               total_tokens=total_tokens if total_tokens > 0 else input_tokens + output_tokens,
+               prompt_tokens_details=prompt_tokens_details,
+               completion_tokens_details=completion_tokens_details,
            ),
        )
        response.system_fingerprint = "fp"
@@ -744,14 +952,36 @@ class BedrockModel(BaseChatModel):
            metadata = chunk["metadata"]
            if "usage" in metadata:
                # token usage
+               usage_data = metadata["usage"]
+
+               # Extract prompt caching metrics if available
+               cache_read_tokens = usage_data.get("cacheReadInputTokens", 0)
+               cache_creation_tokens = usage_data.get("cacheCreationInputTokens", 0)
+
+               # Create prompt_tokens_details if cache metrics are available
+               prompt_tokens_details = None
+               if cache_read_tokens > 0 or cache_creation_tokens > 0:
+                   prompt_tokens_details = PromptTokensDetails(
+                       cached_tokens=cache_read_tokens,
+                       audio_tokens=0,
+                   )
+
+               # Calculate actual prompt tokens
+               # Bedrock's totalTokens includes all tokens
+               # prompt_tokens = totalTokens - outputTokens
+               total_tokens = usage_data["totalTokens"]
+               output_tokens = usage_data["outputTokens"]
+               actual_prompt_tokens = total_tokens - output_tokens
+
                return ChatStreamResponse(
                    id=message_id,
                    model=model_id,
                    choices=[],
                    usage=Usage(
-                       prompt_tokens=metadata["usage"]["inputTokens"],
-                       completion_tokens=metadata["usage"]["outputTokens"],
-                       total_tokens=metadata["usage"]["totalTokens"],
+                       prompt_tokens=actual_prompt_tokens,
+                       completion_tokens=output_tokens,
+                       total_tokens=total_tokens,
+                       prompt_tokens_details=prompt_tokens_details,
                    ),
                )
        if message:
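On the client side, the same cache metrics surface in the usage chunk of a streaming response. A hedged sketch with the OpenAI SDK (whether `stream_options` is needed to receive the usage chunk may depend on the client version, so it is included defensively):

```python
from openai import OpenAI

client = OpenAI()

stream = client.chat.completions.create(
    model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    messages=[
        {"role": "system", "content": "Long system prompt..."},
        {"role": "user", "content": "Question"},
    ],
    extra_body={"prompt_caching": {"system": True}},
    stream=True,
    stream_options={"include_usage": True},  # ask for a final usage chunk
)

for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    if chunk.usage and chunk.usage.prompt_tokens_details:
        print(f"\nCached tokens: {chunk.usage.prompt_tokens_details.cached_tokens}")
```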
@@ -110,10 +110,24 @@ class ChatRequest(BaseModel):
    extra_body: dict | None = None


class PromptTokensDetails(BaseModel):
    """Details about prompt tokens usage, following OpenAI API format."""
    cached_tokens: int = 0
    audio_tokens: int = 0


class CompletionTokensDetails(BaseModel):
    """Details about completion tokens usage, following OpenAI API format."""
    reasoning_tokens: int = 0
    audio_tokens: int = 0


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: PromptTokensDetails | None = None
    completion_tokens_details: CompletionTokensDetails | None = None


class ChatResponseMessage(BaseModel):
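A small sketch of how the new schema fields serialize (the numbers are illustrative and match the worked example earlier):

```python
from api.schema import PromptTokensDetails, Usage

usage = Usage(
    prompt_tokens=1842,
    completion_tokens=250,
    total_tokens=2092,
    prompt_tokens_details=PromptTokensDetails(cached_tokens=1800),
)

# completion_tokens_details is None here, so it is dropped from the output
print(usage.model_dump_json(exclude_none=True))
# -> {"prompt_tokens":1842,"completion_tokens":250,"total_tokens":2092,
#     "prompt_tokens_details":{"cached_tokens":1800,"audio_tokens":0}}
```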
@@ -15,3 +15,4 @@ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240
DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
ENABLE_PROMPT_CACHING = os.environ.get("ENABLE_PROMPT_CACHING", "false").lower() != "false"