Update embedding API
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncIterable
|
||||
from typing import AsyncIterable, Iterable
|
||||
|
||||
import boto3
|
||||
import tiktoken
|
||||
|
||||
from api.models.base import BaseChatModel, BaseEmbeddingsModel
|
||||
from api.schema import (
|
||||
@@ -45,10 +46,13 @@ SUPPORTED_BEDROCK_MODELS = {
|
||||
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
|
||||
"cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
|
||||
"cohere.embed-english-v3": "Cohere Embed English",
|
||||
"amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
|
||||
"amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
|
||||
# Disable Titan embedding.
|
||||
# "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
|
||||
# "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
|
||||
}
|
||||
|
||||
ENCODER = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
|
||||
class BedrockModel(BaseChatModel):
|
||||
@@ -439,10 +443,25 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
|
||||
class CohereEmbeddingsModel(BedrockEmbeddingsModel):
|
||||
|
||||
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
|
||||
texts = []
|
||||
if isinstance(embeddings_request.input, str):
|
||||
texts = [embeddings_request.input]
|
||||
elif isinstance(embeddings_request.input, list):
|
||||
texts = embeddings_request.input
|
||||
elif isinstance(embeddings_request.input, Iterable):
|
||||
# For encoded input
|
||||
# The workaround is to use tiktoken to decode to get the original text.
|
||||
encodings = []
|
||||
for inner in embeddings_request.input:
|
||||
if isinstance(inner, int):
|
||||
# Iterable[int]
|
||||
encodings.append(inner)
|
||||
else:
|
||||
# Iterable[Iterable[int]]
|
||||
text = ENCODER.decode(list(inner))
|
||||
texts.append(text)
|
||||
if encodings:
|
||||
texts.append(ENCODER.decode(encodings))
|
||||
|
||||
# Maximum of 2048 characters
|
||||
args = {
|
||||
|
||||
Reference in New Issue
Block a user