From c3b7395028305c94baac6a4fbbdbeb7f46e968bd Mon Sep 17 00:00:00 2001 From: Joao Galego Date: Thu, 28 Mar 2024 08:05:39 +0000 Subject: [PATCH] Added missing input type checking (Cohere Embed); Removed fingerprint from BaseEmbeddingsResponse --- src/api/models/bedrock.py | 6 +++++- src/api/schema.py | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 75f4f3f..48d54cc 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -483,8 +483,12 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel: class CohereEmbeddingsModel(BedrockEmbeddingsModel): def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: + if isinstance(embeddings_request.input, str): + texts = [embeddings_request.input] + elif isinstance(embeddings_request.input, list): + texts = embeddings_request.input args = { - "texts": embeddings_request.input, + "texts": texts, "input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document", "truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE", } diff --git a/src/api/schema.py b/src/api/schema.py index 17969c7..05f3d1f 100644 --- a/src/api/schema.py +++ b/src/api/schema.py @@ -93,7 +93,6 @@ class EmbeddingsRequest(BaseModel): class BaseEmbeddingsResponse(BaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str - system_fingerprint: str = "fp_e97c09dd4e26" class Embedding(BaseModel):