diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 4b7579b..65f8e90 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -24,6 +24,7 @@ from api.schema import ( Choice, ChoiceDelta, CompletionTokensDetails, + DeveloperMessage, Embedding, EmbeddingsRequest, EmbeddingsResponse, @@ -455,7 +456,7 @@ class BedrockModel(BaseChatModel): """ system_prompts = [] for message in chat_request.messages: - if message.role != "system": + if message.role not in ("system", "developer"): continue if not isinstance(message.content, str): raise TypeError(f"System message content must be a string, got {type(message.content).__name__}") diff --git a/src/api/schema.py b/src/api/schema.py index ffcbab9..ca271c5 100644 --- a/src/api/schema.py +++ b/src/api/schema.py @@ -75,6 +75,12 @@ class ToolMessage(BaseModel): tool_call_id: str +class DeveloperMessage(BaseModel): + name: str | None = None + role: Literal["developer"] = "developer" + content: str + + class Function(BaseModel): name: str description: str | None = None @@ -91,7 +97,7 @@ class StreamOptions(BaseModel): class ChatRequest(BaseModel): - messages: list[SystemMessage | UserMessage | AssistantMessage | ToolMessage] + messages: list[SystemMessage | UserMessage | AssistantMessage | ToolMessage | DeveloperMessage] model: str = DEFAULT_MODEL frequency_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used