diff --git a/src/api/models/base.py b/src/api/models/base.py
index 3c492b4..9dc74dd 100644
--- a/src/api/models/base.py
+++ b/src/api/models/base.py
@@ -29,11 +29,17 @@ class BaseChatModel(ABC):
"""Handle a basic chat completion requests with stream response."""
pass
- def _generate_message_id(self) -> str:
+ @staticmethod
+ def generate_message_id() -> str:
return "chatcmpl-" + str(uuid.uuid4())[:8]
- def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
- return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
+ @staticmethod
+ def stream_response_to_bytes(
+ response: ChatStreamResponse | None = None
+ ) -> bytes:
+ if response:
+ return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
+ return "data: [DONE]\n\n".encode("utf-8")
class BaseEmbeddingsModel(ABC):
@@ -46,6 +52,3 @@ class BaseEmbeddingsModel(ABC):
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
"""Handle a basic embeddings request."""
pass
-
- def _generate_message_id(self) -> str:
- return "embeddings-" + str(uuid.uuid4())[:8]
diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index db33cc9..11f4a90 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -2,7 +2,7 @@ import base64
import json
import logging
import re
-from abc import ABC
+from abc import ABC, abstractmethod
from typing import AsyncIterable, Iterable, Literal
import boto3
@@ -54,6 +54,8 @@ SUPPORTED_BEDROCK_MODELS = {
"mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
"mistral.mistral-large-2402-v1:0": "Mistral Large",
+ "cohere.command-r-v1:0": "Command R",
+ "cohere.command-r-plus-v1:0": "Command R+",
}
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
@@ -72,28 +74,157 @@ class BedrockModel(BaseChatModel, ABC):
accept = "application/json"
content_type = "application/json"
- def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
- body = json.dumps(args)
+ # Default field name to get the response message
+ text_field_name = "text"
+
+ # Default field name to get the response finish reason
+ finish_reason_field_name = "finish_reason"
+
+ @abstractmethod
+ def compose_request_body(self, chat_request: ChatRequest) -> str:
+ """Since the request body to Bedrock varies,
+ each model should implement this to compose the request body.
+
+ :param chat_request:
+ :return: request body as a string
+ """
+ raise NotImplementedError()
+
+ def chat(self, chat_request: ChatRequest) -> ChatResponse:
+ """Default implementation for Chat API."""
+ if DEBUG:
+ logger.info("Raw request: " + chat_request.model_dump_json())
+ request_body = self.compose_request_body(chat_request)
+
+ if DEBUG:
+ logger.info("Bedrock request: " + request_body)
+
+ response = self.invoke_model(
+ request_body=request_body,
+ model_id=chat_request.model,
+ )
+ message_id = self.generate_message_id()
+ return self.parse_response(chat_request, response, message_id)
+
+ def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+ """Default implementation for Chat Stream API"""
+ if DEBUG:
+ logger.info("Raw request: " + chat_request.model_dump_json())
+ request_body = self.compose_request_body(chat_request)
+ response = self.invoke_model(
+ request_body=request_body,
+ model_id=chat_request.model,
+ with_stream=True,
+ )
+
+ message_id = self.generate_message_id()
+ for stream_response in self.parse_stream_response(
+ chat_request, response, message_id
+ ):
+ if stream_response.choices:
+ yield self.stream_response_to_bytes(stream_response)
+ elif (
+ chat_request.stream_options
+ and chat_request.stream_options.include_usage
+ ):
+            # Emit an empty-choices chunk carrying Usage, as per the OpenAI doc below:
+ # if you set stream_options: {"include_usage": true}.
+ # an additional chunk will be streamed before the data: [DONE] message.
+ # The usage field on this chunk shows the token usage statistics for the entire request,
+ # and the choices field will always be an empty array.
+ # All other chunks will also include a usage field, but with a null value.
+ yield self.stream_response_to_bytes(stream_response)
+        # Return a [DONE] message at the end.
+ yield self.stream_response_to_bytes()
+
+ def get_message_text(self, response_body: dict) -> str | None:
+ """Default func to get the response message.
+
+ Ideally, only the field name should be changed."""
+ return response_body.get(self.text_field_name)
+
+ def get_message_finish_reason(self, response_body: dict) -> str | None:
+        """Default func to get the response finish reason.
+
+ Ideally, only the field name should be changed."""
+ return response_body.get(self.finish_reason_field_name)
+
+ def get_message_usage(self, response_body: dict) -> tuple[int, int]:
+        """Default func to get the token usage (input, output).
+
+        Can be overridden in the concrete model class for complex cases."""
+ input_tokens = int(response_body.get("prompt_token_count", "0"))
+ output_tokens = int(response_body.get("generation_token_count", "0"))
+ return input_tokens, output_tokens
+
+ def parse_response(
+ self, chat_request: ChatRequest, service_response: dict, message_id: str
+ ) -> ChatResponse:
+ response_body = json.loads(service_response.get("body").read())
+ if DEBUG:
+ logger.info("Bedrock response body: " + str(response_body))
+
+ input_tokens, output_tokens = self.get_message_usage(response_body)
+ return self.create_response(
+ model=chat_request.model,
+ message_id=message_id,
+ message=self.get_message_text(response_body),
+ finish_reason=self.get_message_finish_reason(response_body),
+ input_tokens=input_tokens,
+ output_tokens=output_tokens,
+ )
+
+ def parse_stream_response(
+ self, chat_request: ChatRequest, service_response: dict, message_id: str
+ ) -> Iterable[ChatStreamResponse]:
+
+ chunk_id = 0
+ for event in service_response.get("body"):
+ if DEBUG:
+ logger.info("Bedrock response chunk: " + str(event))
+ chunk = json.loads(event["chunk"]["bytes"])
+ chunk_id += 1
+
+ response = self.create_response_stream(
+ model=chat_request.model,
+ message_id=message_id,
+ chunk_message=self.get_message_text(chunk),
+ finish_reason=self.get_message_finish_reason(chunk),
+ )
+ yield response
+ # Get the usage for streaming response anyway.
+ if "amazon-bedrock-invocationMetrics" in chunk:
+ yield self.create_response_stream(
+ model=chat_request.model,
+ message_id=message_id,
+ input_tokens=chunk["amazon-bedrock-invocationMetrics"][
+ "inputTokenCount"
+ ],
+ output_tokens=chunk["amazon-bedrock-invocationMetrics"][
+ "outputTokenCount"
+ ],
+ )
+
+ def invoke_model(self, request_body: str, model_id: str, with_stream: bool = False):
if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id)
- logger.info("Bedrock request body: " + body)
+ logger.info("Bedrock request body: " + request_body)
try:
if with_stream:
return bedrock_runtime.invoke_model_with_response_stream(
- body=body,
+ body=request_body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
return bedrock_runtime.invoke_model(
- body=body,
+ body=request_body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
except bedrock_runtime.exceptions.ValidationException as e:
- print("Validation Exception")
- print(e)
+ logger.error("Validation Error: " + str(e))
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(e)
@@ -101,8 +232,7 @@ class BedrockModel(BaseChatModel, ABC):
@staticmethod
def merge_message(messages: list[dict]) -> list[dict]:
- """Merge the request messages with the same role as previous message
- """
+ """Merge the request messages with the same role as previous message"""
merged_messages = []
prev_role = None
merged_content = ""
@@ -110,10 +240,11 @@ class BedrockModel(BaseChatModel, ABC):
for message in messages:
role = message["role"]
content = message["content"]
-
if role != prev_role or isinstance(content, list):
if prev_role:
- merged_messages.append({"role": prev_role, "content": merged_content})
+ merged_messages.append(
+ {"role": prev_role, "content": merged_content}
+ )
if isinstance(content, str):
merged_content = content
prev_role = role
@@ -132,7 +263,7 @@ class BedrockModel(BaseChatModel, ABC):
return merged_messages
@staticmethod
- def _create_response(
+ def create_response(
model: str,
message_id: str,
message: str | None = None,
@@ -144,15 +275,17 @@ class BedrockModel(BaseChatModel, ABC):
response = ChatResponse(
id=message_id,
model=model,
- choices=[Choice(
- index=0,
- message=ChatResponseMessage(
- role="assistant",
- tool_calls=tools,
- content=message,
- ),
- finish_reason="tool_calls" if tools else finish_reason,
- )],
+ choices=[
+ Choice(
+ index=0,
+ message=ChatResponseMessage(
+ role="assistant",
+ tool_calls=tools,
+ content=message,
+ ),
+ finish_reason="tool_calls" if tools else finish_reason,
+ )
+ ],
usage=Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
@@ -164,26 +297,42 @@ class BedrockModel(BaseChatModel, ABC):
return response
@staticmethod
- def _create_response_stream(
+ def create_response_stream(
model: str,
message_id: str,
chunk_message: str | None = None,
finish_reason: str | None = None,
tools: list[ToolCall] | None = None,
+ input_tokens: int = 0,
+ output_tokens: int = 0,
) -> ChatStreamResponse:
- response = ChatStreamResponse(
- id=message_id,
- model=model,
- choices=[ChoiceDelta(
- index=0,
- delta=ChatResponseMessage(
- role="assistant",
- tool_calls=tools,
- content=chunk_message,
+ if chunk_message or finish_reason or tools:
+ response = ChatStreamResponse(
+ id=message_id,
+ model=model,
+ choices=[
+ ChoiceDelta(
+ index=0,
+ delta=ChatResponseMessage(
+ role="assistant",
+ tool_calls=tools,
+ content=chunk_message,
+ ),
+ finish_reason=finish_reason,
+ )
+ ],
+ )
+ else:
+ response = ChatStreamResponse(
+ id=message_id,
+ model=model,
+ choices=[],
+ usage=Usage(
+ prompt_tokens=input_tokens,
+ completion_tokens=output_tokens,
+ total_tokens=input_tokens + output_tokens,
),
- finish_reason="tool_calls" if tools else finish_reason,
- )],
- )
+ )
if DEBUG:
logger.info("Proxy response :" + response.model_dump_json())
return response
@@ -201,7 +350,7 @@ Please think if you need to use a tool or not for user's question, you must:
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
3. If no tools is needed, respond with normal text."""
- def _parse_args(self, chat_request: ChatRequest) -> dict:
+ def compose_request_body(self, chat_request: ChatRequest) -> str:
args = {
"anthropic_version": self.anthropic_version,
"max_tokens": chat_request.max_tokens,
@@ -239,8 +388,8 @@ Please think if you need to use a tool or not for user's question, you must:
{
"role": "user",
"content": "[Tool result with matching id `{}` of `{}`] ".format(
- message.tool_call_id,
- message.content),
+ message.tool_call_id, message.content
+ ),
}
)
else:
@@ -253,74 +402,86 @@ Please think if you need to use a tool or not for user's question, you must:
[tool.function.model_dump() for tool in chat_request.tools]
)
system_prompt += self.tool_prompt.format(tools=tools_str)
- converted_messages.append({
- 'role': 'assistant',
- 'content': ''
- })
- args["stop_sequences"] = ['']
+ converted_messages.append({"role": "assistant", "content": ""})
+ args["stop_sequences"] = [""]
args["messages"] = self.merge_message(converted_messages)
if system_prompt:
if DEBUG:
logger.info("System Prompt: " + system_prompt)
args["system"] = system_prompt
- return args
+ return json.dumps(args)
- def chat(self, chat_request: ChatRequest) -> ChatResponse:
- if DEBUG:
- logger.info("Raw request: " + chat_request.model_dump_json())
- response = self._invoke_model(
- args=self._parse_args(chat_request), model_id=chat_request.model
- )
- response_body = json.loads(response.get("body").read())
+ def parse_response(
+ self, chat_request: ChatRequest, service_response: dict, message_id: str
+ ) -> ChatResponse:
+ response_body = json.loads(service_response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
message = response_body["content"][0]["text"]
-
+ finish_reason = response_body["stop_reason"]
tools = None
if chat_request.tools:
if message.startswith("Y"):
tools = self._parse_tool_message(message)
message = None
+ finish_reason = "tool_calls"
elif message.startswith("N"):
message = message[8:].lstrip("\n")
- return self._create_response(
+ return self.create_response(
model=chat_request.model,
- message_id=response_body["id"],
+ message_id=message_id,
message=message,
tools=tools,
- finish_reason=response_body["stop_reason"],
+ finish_reason=finish_reason,
input_tokens=response_body["usage"]["input_tokens"],
output_tokens=response_body["usage"]["output_tokens"],
)
- def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
- if DEBUG:
- logger.info("Raw request: " + chat_request.model_dump_json())
- response = self._invoke_model(
- args=self._parse_args(chat_request),
- model_id=chat_request.model,
- with_stream=True,
- )
- msg_id = ""
+ def parse_stream_response(
+ self, chat_request: ChatRequest, service_response: dict, message_id: str
+ ) -> Iterable[ChatStreamResponse]:
+
chunk_id = 0
tool_message = ""
first_token = True
index = 0
- for event in response.get("body"):
+ for event in service_response.get("body"):
if DEBUG:
logger.info("Bedrock response chunk: " + str(event))
chunk = json.loads(event["chunk"]["bytes"])
chunk_id += 1
- if chunk["type"] == "message_start":
- msg_id = chunk["message"]["id"]
- continue
+ if chunk["type"] == "message_stop":
+ # Get the usage for streaming response anyway.
+ if "amazon-bedrock-invocationMetrics" in chunk:
+ yield self.create_response_stream(
+ model=chat_request.model,
+ message_id=message_id,
+ input_tokens=chunk["amazon-bedrock-invocationMetrics"][
+ "inputTokenCount"
+ ],
+ output_tokens=chunk["amazon-bedrock-invocationMetrics"][
+ "outputTokenCount"
+ ],
+ )
+ break
- if chunk["type"] == "message_delta":
+ elif chunk["type"] == "message_delta":
chunk_message = ""
finish_reason = chunk["delta"]["stop_reason"]
+ # Send tool message first if any.
+ if chat_request.tools and tool_message:
+ tools = self._parse_tool_message(tool_message)
+ finish_reason = "tool_calls"
+ response = self.create_response_stream(
+ model=chat_request.model,
+ message_id=message_id,
+ tools=tools,
+ )
+ yield response
+
elif chunk["type"] == "content_block_delta":
chunk_message = chunk["delta"]["text"]
finish_reason = None
@@ -343,22 +504,14 @@ Please think if you need to use a tool or not for user's question, you must:
first_token = False
else:
continue
- response = self._create_response_stream(
+ response = self.create_response_stream(
model=chat_request.model,
- message_id=msg_id,
+ message_id=message_id,
chunk_message=chunk_message,
finish_reason=finish_reason,
)
- yield self._stream_response_to_bytes(response)
- if chat_request.tools and tool_message:
- tools = self._parse_tool_message(tool_message)
- response = self._create_response_stream(
- model=chat_request.model,
- message_id=msg_id,
- tools=tools,
- )
- yield self._stream_response_to_bytes(response)
+ yield response
def _parse_tool_message(self, tool_message: str) -> list[ToolCall]:
if DEBUG:
@@ -367,9 +520,7 @@ Please think if you need to use a tool or not for user's question, you must:
tool_messages = tool_message[tool_message.rindex("") + len(""):]
function = json.loads(tool_messages.replace("\n", " "))
args = json.dumps(function.get("arguments", {}))
- function = ResponseFunction(
- name=function["name"], arguments=args
- )
+ function = ResponseFunction(name=function["name"], arguments=args)
return [
ToolCall(
@@ -400,7 +551,7 @@ Please think if you need to use a tool or not for user's question, you must:
# Check if the request was successful
if response.status_code == 200:
- content_type = response.headers.get('Content-Type')
+ content_type = response.headers.get("Content-Type")
if not content_type.startswith("image"):
content_type = "image/jpeg"
# Get the image content
@@ -437,6 +588,8 @@ Please think if you need to use a tool or not for user's question, you must:
class LlamaModel(BedrockModel):
+ text_field_name = "generation"
+ finish_reason_field_name = "stop_reason"
@staticmethod
def create_llama3_prompt(chat_request: ChatRequest) -> str:
@@ -460,7 +613,9 @@ class LlamaModel(BedrockModel):
prompt_lines = []
for msg in chat_request.messages:
- prompt_lines.append(f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>")
+ prompt_lines.append(
+ f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>"
+ )
prompt_lines.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n")
prompt = bos_token + "".join(prompt_lines)
if DEBUG:
@@ -512,59 +667,24 @@ class LlamaModel(BedrockModel):
logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
return prompt
- def _parse_args(self, chat_request: ChatRequest) -> dict:
+ def compose_request_body(self, chat_request: ChatRequest) -> str:
if chat_request.model.startswith("meta.llama2"):
prompt = self.create_llama2_prompt(chat_request)
else:
prompt = self.create_llama3_prompt(chat_request)
- # Currently, there is no way to set stop sequence for Llama 3 models.
- return {
+ args = {
"prompt": prompt,
"max_gen_len": chat_request.max_tokens,
"temperature": chat_request.temperature,
"top_p": chat_request.top_p,
}
-
- def chat(self, chat_request: ChatRequest) -> ChatResponse:
- response = self._invoke_model(
- args=self._parse_args(chat_request), model_id=chat_request.model
- )
- response_body = json.loads(response.get("body").read())
- if DEBUG:
- logger.info("Bedrock response body: " + str(response_body))
- message_id = self._generate_message_id()
-
- return self._create_response(
- model=chat_request.model,
- message=response_body["generation"],
- message_id=message_id,
- input_tokens=response_body["prompt_token_count"],
- output_tokens=response_body["generation_token_count"],
- )
-
- def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
- response = self._invoke_model(
- args=self._parse_args(chat_request),
- model_id=chat_request.model,
- with_stream=True,
- )
- msg_id = ""
- chunk_id = 0
- for event in response.get("body"):
- if DEBUG:
- logger.info("Bedrock response chunk: " + str(event))
- chunk = json.loads(event["chunk"]["bytes"])
- chunk_id += 1
- response = self._create_response_stream(
- model=chat_request.model,
- message_id=msg_id,
- chunk_message=chunk["generation"],
- finish_reason=chunk["stop_reason"],
- )
- yield self._stream_response_to_bytes(response)
+ return json.dumps(args)
class MistralModel(BedrockModel):
+ text_field_name = "text"
+ finish_reason_field_name = "stop_reason"
+
def _convert_prompt(self, chat_request: ChatRequest) -> str:
"""Create a prompt message follow below example:
@@ -609,51 +729,54 @@ class MistralModel(BedrockModel):
logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
return prompt
- def _parse_args(self, chat_request: ChatRequest) -> dict:
+ def compose_request_body(self, chat_request: ChatRequest) -> str:
prompt = self._convert_prompt(chat_request)
- return {
+ args = {
"prompt": prompt,
"max_tokens": chat_request.max_tokens,
"temperature": chat_request.temperature,
"top_p": chat_request.top_p,
}
+ return json.dumps(args)
- def chat(self, chat_request: ChatRequest) -> ChatResponse:
+ def get_message_text(self, response_body: dict) -> str | None:
+ return super().get_message_text(response_body["outputs"][0])
- response = self._invoke_model(
- args=self._parse_args(chat_request), model_id=chat_request.model
- )
- response_body = json.loads(response.get("body").read())
- if DEBUG:
- logger.info("Bedrock response body: " + str(response_body))
- message_id = self._generate_message_id()
+ def get_message_finish_reason(self, response_body: dict) -> str | None:
+ return super().get_message_finish_reason(response_body["outputs"][0])
- return self._create_response(
- model=chat_request.model,
- message=response_body["outputs"][0]["text"],
- message_id=message_id,
- )
+ def get_message_usage(self, response_body: dict) -> tuple[int, int]:
+ # Mistral/Mixtral does not provide info about usage
+ return 0, 0
- def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
- response = self._invoke_model(
- args=self._parse_args(chat_request),
- model_id=chat_request.model,
- with_stream=True,
- )
- msg_id = ""
- chunk_id = 0
- for event in response.get("body"):
- if DEBUG:
- logger.info("Bedrock response chunk: " + str(event))
- chunk = json.loads(event["chunk"]["bytes"])
- chunk_id += 1
- response = self._create_response_stream(
- model=chat_request.model,
- message_id=msg_id,
- chunk_message=chunk["outputs"][0]["text"],
- finish_reason=chunk["outputs"][0]["stop_reason"],
+
+class CohereCommandModel(BedrockModel):
+
+ def _parse_message(self, message) -> dict:
+ if message.role not in ["user", "assistant"]:
+ raise HTTPException(
+ status_code=400, detail="Only user or assistant message is supported"
)
- yield self._stream_response_to_bytes(response)
+ return {
+ "role": "USER" if message.role == "user" else "CHATBOT",
+ "message": message.content,
+ }
+
+ def compose_request_body(self, chat_request: ChatRequest) -> str:
+ messages = chat_request.messages
+ if messages[-1].role != "user":
+ raise HTTPException(
+ status_code=400, detail="Last message should be a valid user message"
+ )
+ chat_history = [self._parse_message(msg) for msg in messages[:-1]]
+ args = {
+ "message": messages[-1].content,
+ "chat_history": chat_history,
+ "max_tokens": chat_request.max_tokens,
+ "temperature": chat_request.temperature,
+ "p": chat_request.top_p,
+ }
+ return json.dumps(args)
class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
@@ -673,8 +796,7 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
contentType=self.content_type,
)
except bedrock_runtime.exceptions.ValidationException as e:
- print("Validation Exception")
- print(e)
+ logger.error("Validation Error: " + str(e))
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(e)
@@ -705,7 +827,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
total_tokens=input_tokens + output_tokens,
),
)
-
if DEBUG:
logger.info("Proxy response :" + response.model_dump_json())
return response
@@ -799,24 +920,22 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
def get_model(model_id: str) -> BedrockModel:
- model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
if DEBUG:
- logger.info("model name is " + model_name)
- # Not using start_with here in case of complex scenarios.
- # The downside is to change this everytime for a new model supported.
- match model_name:
- case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
- return ClaudeModel()
- case "Llama 2 Chat 13B" | "Llama 2 Chat 70B" | "Llama 3 8B Instruct" | "Llama 3 70B Instruct":
- return LlamaModel()
- case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
- return MistralModel()
- case _:
- logger.error("Unsupported model id " + model_id)
- raise HTTPException(
- status_code=400,
- detail="Unsupported model id " + model_id,
- )
+ logger.info("model id is " + model_id)
+ if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
+ logger.error("Unsupported model id " + model_id)
+ raise HTTPException(
+ status_code=400,
+ detail="Unsupported model id " + model_id,
+ )
+ if model_id.startswith("anthropic.claude"):
+ return ClaudeModel()
+ elif model_id.startswith("meta.llama"):
+ return LlamaModel()
+ elif model_id.startswith("mistral.mistral") or model_id.startswith("mistral.mixtral"):
+ return MistralModel()
+ elif model_id.startswith("cohere.command-r"):
+ return CohereCommandModel()
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
diff --git a/src/api/schema.py b/src/api/schema.py
index fc4e786..da133d5 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -79,12 +79,17 @@ class Tool(BaseModel):
function: Function
+class StreamOptions(BaseModel):
+ include_usage: bool = True
+
+
class ChatRequest(BaseModel):
messages: list[SystemMessage | UserMessage | AssistantMessage | ToolMessage]
model: str
frequency_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used
presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used
stream: bool | None = False
+ stream_options: StreamOptions | None = None
temperature: float | None = Field(default=1.0, le=2.0, ge=0.0)
top_p: float | None = Field(default=1.0, le=1.0, ge=0.0)
user: str | None = None # Not used
@@ -138,6 +143,7 @@ class ChatResponse(BaseChatResponse):
class ChatStreamResponse(BaseChatResponse):
choices: list[ChoiceDelta]
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+ usage: Usage | None = None
class EmbeddingsRequest(BaseModel):
diff --git a/src/api/setting.py b/src/api/setting.py
index bf056d8..56e99c4 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -26,6 +26,8 @@ List of Amazon Bedrock models currently supported:
- mistral.mistral-7b-instruct-v0:2
- mistral.mixtral-8x7b-instruct-v0:1
- mistral.mistral-large-2402-v1:0
+- cohere.command-r-v1:0
+- cohere.command-r-plus-v1:0
# Embeddings
- cohere.embed-multilingual-v3