diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 75f4f3f..48d54cc 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -483,8 +483,12 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel: class CohereEmbeddingsModel(BedrockEmbeddingsModel): def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: + if isinstance(embeddings_request.input, str): + texts = [embeddings_request.input] + elif isinstance(embeddings_request.input, list): + texts = embeddings_request.input args = { - "texts": embeddings_request.input, + "texts": texts, "input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document", "truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE", } diff --git a/src/api/schema.py b/src/api/schema.py index 17969c7..05f3d1f 100644 --- a/src/api/schema.py +++ b/src/api/schema.py @@ -93,7 +93,6 @@ class EmbeddingsRequest(BaseModel): class BaseEmbeddingsResponse(BaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str - system_fingerprint: str = "fp_e97c09dd4e26" class Embedding(BaseModel):