Initial commit

Aiden Dai
2024-03-27 15:20:24 +08:00
parent f77df2c536
commit f974cb2728
21 changed files with 2149 additions and 5 deletions

src/Dockerfile Normal file

@@ -0,0 +1,9 @@
FROM public.ecr.aws/lambda/python:3.12

COPY ./api ./api
COPY requirements.txt .

RUN pip3 install -r requirements.txt -U --no-cache-dir

CMD [ "api.app.handler" ]

src/api/__init__.py Normal file (empty)

src/api/app.py Normal file

@@ -0,0 +1,52 @@
import logging

import uvicorn
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from mangum import Mangum

from api.routers import model, chat
from api.setting import API_ROUTE_PREFIX, TITLE, DESCRIPTION, SUMMARY, VERSION

config = {
    "title": TITLE,
    "description": DESCRIPTION,
    "summary": SUMMARY,
    "version": VERSION,
}

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

app = FastAPI(**config)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(model.router, prefix=API_ROUTE_PREFIX)
app.include_router(chat.router, prefix=API_ROUTE_PREFIX)


@app.get("/health")
async def health():
    """Health check endpoint."""
    return {"status": "OK"}


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return PlainTextResponse(str(exc), status_code=400)


handler = Mangum(app)  # AWS Lambda entry point (see Dockerfile CMD "api.app.handler")

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
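
A minimal smoke test for this app, assuming it is already being served locally on port 8000; the requests package is an extra dependency for this sketch, not part of this commit's requirements.txt:

# Hypothetical smoke test (not part of the diff)
import requests

resp = requests.get("http://localhost:8000/health")
print(resp.status_code, resp.json())  # expected: 200 {'status': 'OK'}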

src/api/auth.py Normal file

@@ -0,0 +1,28 @@
import os
from typing import Annotated

import boto3
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from api.setting import DEFAULT_API_KEYS

api_key_param = os.environ.get("API_KEY_PARAM_NAME")
if api_key_param:
    # Fetch the API key from SSM Parameter Store when a parameter name is set.
    ssm = boto3.client("ssm")
    api_key = ssm.get_parameter(Name=api_key_param, WithDecryption=True)["Parameter"][
        "Value"
    ]
else:
    api_key = DEFAULT_API_KEYS

security = HTTPBearer()


def api_key_auth(
    credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]
):
    if credentials.credentials != api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key"
        )
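
A sketch of how the SSM-backed branch above could be provisioned, assuming a SecureString parameter is created ahead of time; the parameter name and key value here are made up for illustration:

# Hypothetical one-time setup (not part of the diff)
import boto3

ssm = boto3.client("ssm")
ssm.put_parameter(
    Name="/bedrock-proxy/api-key",  # assumed parameter name
    Value="my-secret-key",          # assumed key value
    Type="SecureString",
    Overwrite=True,
)
# Then set API_KEY_PARAM_NAME=/bedrock-proxy/api-key in the runtime environment;
# clients authenticate with the header "Authorization: Bearer my-secret-key".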

src/api/models/__init__.py Normal file

@@ -0,0 +1 @@
from api.models.bedrock import ClaudeModel, SUPPORTED_BEDROCK_MODELS, get_model

src/api/models/bedrock.py Normal file

@@ -0,0 +1,391 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import AsyncIterable

import boto3

from api.schema import (
    ChatResponse,
    ChatRequest,
    ChatRequestMessage,
    Choice,
    ChatResponseMessage,
    Usage,
    ChatStreamResponse,
    ChoiceDelta,
)
from api.setting import DEBUG, AWS_REGION

logger = logging.getLogger(__name__)

bedrock_runtime = boto3.client(
    service_name="bedrock-runtime",
    region_name=AWS_REGION,
)

SUPPORTED_BEDROCK_MODELS = {
"anthropic.claude-instant-v1": "Claude Instant",
"anthropic.claude-v2:1": "Claude",
"anthropic.claude-v2": "Claude",
"anthropic.claude-3-sonnet-20240229-v1:0": "Claude 3 Sonnet",
"anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku",
"meta.llama2-13b-chat-v1": "Llama 2 Chat 13B",
"meta.llama2-70b-chat-v1": "Llama 2 Chat 70B",
"mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
}

class BaseChatModel(ABC):
    """Represents a basic chat model.

    Currently only Bedrock models are supported, but this could also serve
    SageMaker models if needed.
    """

    @abstractmethod
    def chat(self, chat_request: ChatRequest) -> ChatResponse:
        """Handle a basic chat completion request."""
        pass

    @abstractmethod
    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
        """Handle a basic chat completion request with a streamed response."""
        pass

    def _generate_message_id(self) -> str:
        return "chatcmpl-" + str(uuid.uuid4())[:8]

    def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
        # Serialize one chunk as a server-sent event: "data: {json}\n\n"
        return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")


# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
class BedrockModel(BaseChatModel):
    accept = "application/json"
    content_type = "application/json"

    def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
        body = json.dumps(args)
        if DEBUG:
            logger.info("Invoke Bedrock Model: " + model_id)
            logger.info("Bedrock request body: " + body)
        if with_stream:
            return bedrock_runtime.invoke_model_with_response_stream(
                body=body,
                modelId=model_id,
                accept=self.accept,
                contentType=self.content_type,
            )
        return bedrock_runtime.invoke_model(
            body=body,
            modelId=model_id,
            accept=self.accept,
            contentType=self.content_type,
        )
    def _create_response(
        self,
        model: str,
        message: str,
        message_id: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
    ) -> ChatResponse:
        choice = Choice(
            index=0,
            message=ChatResponseMessage(
                role="assistant",
                content=message,
            ),
            finish_reason="stop",
        )
        response = ChatResponse(
            id=message_id,
            model=model,
            choices=[choice],
            usage=Usage(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
            ),
        )
        if DEBUG:
            logger.info("Proxy response: " + response.model_dump_json())
        return response

    def _create_response_stream(
        self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
    ) -> ChatStreamResponse:
        choice = ChoiceDelta(
            index=0,
            delta=ChatResponseMessage(
                role="assistant",
                content=chunk_message,
            ),
            finish_reason=finish_reason,
        )
        response = ChatStreamResponse(
            id=message_id,
            model=model,
            choices=[choice],
        )
        if DEBUG:
            logger.info("Proxy response: " + response.model_dump_json())
        return response


def get_model(model_id: str) -> BedrockModel:
    model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
    if DEBUG:
        logger.info("model name is " + model_name)
    if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku"]:
        return ClaudeModel()
    elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
        return Llama2Model()
    elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct"]:
        return MistralModel()
    else:
        logger.error("Unsupported model id " + model_id)
        raise ValueError("Invalid model ID")


class ClaudeModel(BedrockModel):
    anthropic_version = "bedrock-2023-05-31"

    def _parse_args(self, chat_request: ChatRequest) -> dict:
        args = {
            "anthropic_version": self.anthropic_version,
            "max_tokens": chat_request.max_tokens,
            "top_p": chat_request.top_p,
            "temperature": chat_request.temperature,
        }
        # Anthropic models take the system prompt as a top-level "system"
        # field rather than as a message.
        if chat_request.messages[0].role == "system":
            args["system"] = chat_request.messages[0].content
            args["messages"] = [
                {"role": msg.role, "content": msg.content}
                for msg in chat_request.messages[1:]
            ]
        else:
            args["messages"] = [
                {"role": msg.role, "content": msg.content}
                for msg in chat_request.messages
            ]
        return args

    def chat(self, chat_request: ChatRequest) -> ChatResponse:
        response = self._invoke_model(
            args=self._parse_args(chat_request), model_id=chat_request.model
        )
        response_body = json.loads(response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))
        return self._create_response(
            model=chat_request.model,
            message=response_body["content"][0]["text"],
            message_id=response_body["id"],
            input_tokens=response_body["usage"]["input_tokens"],
            output_tokens=response_body["usage"]["output_tokens"],
        )

    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
        response = self._invoke_model(
            args=self._parse_args(chat_request),
            model_id=chat_request.model,
            with_stream=True,
        )
        msg_id = ""
        for event in response.get("body"):
            if DEBUG:
                logger.info("Bedrock response chunk: " + str(event))
            chunk = json.loads(event["chunk"]["bytes"])
            if chunk["type"] == "message_start":
                msg_id = chunk["message"]["id"]
                continue
            if chunk["type"] == "message_delta":
                chunk_message = ""
                finish_reason = "stop"
            elif chunk["type"] == "content_block_delta":
                chunk_message = chunk["delta"]["text"]
                finish_reason = None
            else:
                continue
            response = self._create_response_stream(
                model=chat_request.model,
                message_id=msg_id,
                chunk_message=chunk_message,
                finish_reason=finish_reason,
            )
            yield self._stream_response_to_bytes(response)


class Llama2Model(BedrockModel):
    def _convert_prompt(self, messages: list[ChatRequestMessage]) -> str:
        """Create a prompt string following the example below:

        <s>[INST] <<SYS>>\n{your_system_message}\n<</SYS>>\n\n{user_message_1} [/INST] {model_reply_1}</s>
        <s>[INST] {user_message_2} [/INST]
        """
        if DEBUG:
            logger.info("Converting the messages below to a prompt for Llama 2: ")
            for msg in messages:
                logger.info(msg.model_dump_json())
        bos_token = "<s>"
        eos_token = "</s>"
        prompt = bos_token + "[INST] "
        start = 0
        end_turn = False
        if messages[0].role == "system":
            prompt += "<<SYS>>\n" + messages[0].content + "\n<</SYS>>\n\n"
            start = 1
        # TODO: Add validation
        for i in range(start, len(messages)):
            msg = messages[i]
            if msg.role == "user":
                if end_turn:
                    prompt += bos_token + "[INST] "
                prompt += msg.content + " [/INST] "
                end_turn = False
            else:
                prompt += msg.content + eos_token
                end_turn = True
        if DEBUG:
            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
        return prompt

    def _parse_args(self, chat_request: ChatRequest) -> dict:
        prompt = self._convert_prompt(chat_request.messages)
        return {
            "prompt": prompt,
            "max_gen_len": chat_request.max_tokens,
            "temperature": chat_request.temperature,
            "top_p": chat_request.top_p,
        }

    def chat(self, chat_request: ChatRequest) -> ChatResponse:
        response = self._invoke_model(
            args=self._parse_args(chat_request), model_id=chat_request.model
        )
        response_body = json.loads(response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))
        message_id = self._generate_message_id()
        return self._create_response(
            model=chat_request.model,
            message=response_body["generation"],
            message_id=message_id,
            input_tokens=response_body["prompt_token_count"],
            output_tokens=response_body["generation_token_count"],
        )

    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
        response = self._invoke_model(
            args=self._parse_args(chat_request),
            model_id=chat_request.model,
            with_stream=True,
        )
        # Llama 2 response chunks carry no message id, so generate one.
        msg_id = self._generate_message_id()
        for event in response.get("body"):
            if DEBUG:
                logger.info("Bedrock response chunk: " + str(event))
            chunk = json.loads(event["chunk"]["bytes"])
            response = self._create_response_stream(
                model=chat_request.model,
                message_id=msg_id,
                chunk_message=chunk["generation"],
                finish_reason=chunk["stop_reason"],
            )
            yield self._stream_response_to_bytes(response)


class MistralModel(BedrockModel):
    def _convert_prompt(self, messages: list[ChatRequestMessage]) -> str:
        """Create a prompt string following the example below:

        <s>[INST] {your_system_message}\n{user_message_1} [/INST] {model_reply_1}</s>
        <s>[INST] {user_message_2} [/INST]
        """
        if DEBUG:
            logger.info("Converting the messages below to a prompt for Mistral: ")
            for msg in messages:
                logger.info(msg.model_dump_json())
        bos_token = "<s>"
        eos_token = "</s>"
        prompt = bos_token + "[INST] "
        start = 0
        end_turn = False
        if messages[0].role == "system":
            # Mistral has no dedicated system tag; prepend it to the first turn.
            prompt += messages[0].content + "\n"
            start = 1
        # TODO: Add validation
        for i in range(start, len(messages)):
            msg = messages[i]
            if msg.role == "user":
                if end_turn:
                    prompt += bos_token + "[INST] "
                prompt += msg.content + " [/INST] "
                end_turn = False
            else:
                prompt += msg.content + eos_token
                end_turn = True
        if DEBUG:
            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
        return prompt

    def _parse_args(self, chat_request: ChatRequest) -> dict:
        prompt = self._convert_prompt(chat_request.messages)
        return {
            "prompt": prompt,
            "max_tokens": chat_request.max_tokens,
            "temperature": chat_request.temperature,
            "top_p": chat_request.top_p,
        }

    def chat(self, chat_request: ChatRequest) -> ChatResponse:
        response = self._invoke_model(
            args=self._parse_args(chat_request), model_id=chat_request.model
        )
        response_body = json.loads(response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))
        message_id = self._generate_message_id()
        # The Mistral response body carries no token counts, so usage falls
        # back to the zero defaults in _create_response.
        return self._create_response(
            model=chat_request.model,
            message=response_body["outputs"][0]["text"],
            message_id=message_id,
        )

    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
        response = self._invoke_model(
            args=self._parse_args(chat_request),
            model_id=chat_request.model,
            with_stream=True,
        )
        # Mistral response chunks carry no message id, so generate one.
        msg_id = self._generate_message_id()
        for event in response.get("body"):
            if DEBUG:
                logger.info("Bedrock response chunk: " + str(event))
            chunk = json.loads(event["chunk"]["bytes"])
            response = self._create_response_stream(
                model=chat_request.model,
                message_id=msg_id,
                chunk_message=chunk["outputs"][0]["text"],
                finish_reason=chunk["outputs"][0]["stop_reason"],
            )
            yield self._stream_response_to_bytes(response)
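
To see what the _convert_prompt implementations produce, here is a hypothetical three-turn conversation run through the Llama 2 format (the Mistral variant differs only in how the system message is folded in); this sketch is not part of the commit:

# Hypothetical illustration of Llama2Model._convert_prompt
from api.models.bedrock import Llama2Model
from api.schema import ChatRequestMessage

messages = [
    ChatRequestMessage(role="system", content="You are a helpful assistant."),
    ChatRequestMessage(role="user", content="Hello!"),
    ChatRequestMessage(role="assistant", content="Hi! How can I help?"),
    ChatRequestMessage(role="user", content="Tell me a joke."),
]
print(Llama2Model()._convert_prompt(messages))
# Prints (newlines shown escaped):
# <s>[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\nHello! [/INST] Hi! How can I help?</s><s>[INST] Tell me a joke. [/INST]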

src/api/routers/__init__.py Normal file (empty)

src/api/routers/chat.py Normal file

@@ -0,0 +1,51 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Body, HTTPException
from fastapi.responses import StreamingResponse

from api.auth import api_key_auth
from api.models import get_model, SUPPORTED_BEDROCK_MODELS
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
from api.setting import DEFAULT_MODEL

router = APIRouter(
    prefix="/chat",
    tags=["chat"],
    dependencies=[Depends(api_key_auth)],
    # responses={404: {"description": "Not found"}},
)


@router.post("/completions", response_model=ChatResponse | ChatStreamResponse)
async def chat_completions(
    chat_request: Annotated[
        ChatRequest,
        Body(
            examples=[
                {
                    "model": "anthropic.claude-3-sonnet-20240229-v1:0",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Hello!"},
                    ],
                }
            ],
        ),
    ]
):
    # Fall back to the default model for any OpenAI "gpt-*" model id.
    if chat_request.model.lower().startswith("gpt-"):
        chat_request.model = DEFAULT_MODEL
    if chat_request.model not in SUPPORTED_BEDROCK_MODELS:
        raise HTTPException(
            status_code=400, detail="Unsupported Model Id " + chat_request.model
        )
    try:
        model = get_model(chat_request.model)
        if chat_request.stream:
            return StreamingResponse(
                content=model.chat_stream(chat_request), media_type="text/event-stream"
            )
        return model.chat(chat_request)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
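
For reference, a hypothetical client call against this route, assuming a local server on port 8000 and the default "bedrock" API key; requests is not in this repo's requirements.txt:

# Hypothetical client for /chat/completions (not part of the diff)
import json
import requests

BASE = "http://localhost:8000/api/v1"
HEADERS = {"Authorization": "Bearer bedrock"}

payload = {
    "model": "anthropic.claude-3-sonnet-20240229-v1:0",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
}

# Non-streaming: a single ChatResponse JSON body.
resp = requests.post(f"{BASE}/chat/completions", json=payload, headers=HEADERS)
print(resp.json()["choices"][0]["message"]["content"])

# Streaming: server-sent events, one "data: {...}" line per chunk.
payload["stream"] = True
with requests.post(f"{BASE}/chat/completions", json=payload,
                   headers=HEADERS, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])
            delta = chunk["choices"][0]["delta"].get("content")
            if delta:
                print(delta, end="", flush=True)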

src/api/routers/model.py Normal file

@@ -0,0 +1,41 @@
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Path

from api.auth import api_key_auth
from api.models import SUPPORTED_BEDROCK_MODELS
from api.schema import Models, Model

router = APIRouter(
    prefix="/models",
    tags=["models"],
    dependencies=[Depends(api_key_auth)],
    # responses={404: {"description": "Not found"}},
)


async def validate_model_id(model_id: str):
    if model_id not in SUPPORTED_BEDROCK_MODELS:
        raise HTTPException(status_code=400, detail="Unsupported Model Id")


@router.get("/", response_model=Models)
async def list_models():
    model_list = [Model(id=model_id) for model_id in SUPPORTED_BEDROCK_MODELS]
    return Models(data=model_list)


@router.get(
    "/{model_id}",
    response_model=Model,
)
async def get_model(
    model_id: Annotated[
        str,
        Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"),
    ]
):
    await validate_model_id(model_id)
    return Model(id=model_id)
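
Hypothetical client calls for these routes, again assuming a local server and the default API key:

# Not part of the diff; requests is an extra dependency for this sketch.
import requests

headers = {"Authorization": "Bearer bedrock"}
base = "http://localhost:8000/api/v1"

models = requests.get(f"{base}/models", headers=headers).json()
print([m["id"] for m in models["data"]])

one = requests.get(f"{base}/models/anthropic.claude-3-sonnet-20240229-v1:0",
                   headers=headers).json()
print(one["object"], one["owned_by"])  # model bedrock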

src/api/schema.py Normal file

@@ -0,0 +1,80 @@
import time
from typing import Literal

from pydantic import BaseModel, Field


class Model(BaseModel):
    id: str
    created: int = Field(default_factory=lambda: int(time.time()))
    object: str | None = "model"
    owned_by: str | None = "bedrock"


class Models(BaseModel):
    object: str | None = "list"
    data: list[Model] = []


class ChatRequestMessage(BaseModel):
    name: str | None = None
    role: Literal["user", "assistant", "system"]
    content: str


class ChatRequest(BaseModel):
    messages: list[ChatRequestMessage]
    model: str
    frequency_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0)  # Not used
    presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0)  # Not used
    stream: bool | None = False
    temperature: float | None = Field(default=1.0, le=2.0, ge=0.0)
    top_p: float | None = Field(default=1.0, le=1.0, ge=0.0)
    user: str | None = None  # Not used
    max_tokens: int | None = 2048
    n: int | None = 1  # Not used


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatResponseMessage(BaseModel):
    # tool_calls
    role: Literal["assistant"] | None = None
    content: str | None = None


class BaseChoice(BaseModel):
    index: int
    finish_reason: str | None
    logprobs: dict | None = None


class Choice(BaseChoice):
    message: ChatResponseMessage


class ChoiceDelta(BaseChoice):
    delta: ChatResponseMessage


class BaseChatResponse(BaseModel):
    # id: str = Field(default_factory=lambda: "chatcmpl-" + str(uuid.uuid4())[:8])
    id: str
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    system_fingerprint: str = "fp_e97c09dd4e26"


class ChatResponse(BaseChatResponse):
    choices: list[Choice]
    object: Literal["chat.completion"] = "chat.completion"
    usage: Usage


class ChatStreamResponse(BaseChatResponse):
    choices: list[ChoiceDelta]
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
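
A quick look at the request defaults, e.g. in a REPL; hypothetical usage, not part of the commit:

from api.schema import ChatRequest

req = ChatRequest(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(req.stream, req.temperature, req.top_p, req.max_tokens)
# False 1.0 1.0 2048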

src/api/setting.py Normal file

@@ -0,0 +1,27 @@
import os

DEFAULT_API_KEYS = "bedrock"
API_ROUTE_PREFIX = "/api/v1"

TITLE = "Amazon Bedrock Proxy APIs"
SUMMARY = "OpenAI-Compatible RESTful APIs for Amazon Bedrock"
VERSION = "0.1.0"
DESCRIPTION = """
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.

List of Amazon Bedrock models currently supported:

- anthropic.claude-instant-v1
- anthropic.claude-v2:1
- anthropic.claude-v2
- anthropic.claude-3-sonnet-20240229-v1:0
- anthropic.claude-3-haiku-20240307-v1:0
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
- mistral.mistral-7b-instruct-v0:2
- mistral.mixtral-8x7b-instruct-v0:1
"""

DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
DEFAULT_MODEL = os.environ.get(
    "DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0"
)
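
Note that the DEBUG flag treats any value other than "false" (case-insensitive) as on; a quick illustration:

import os

for value in ("true", "1", "yes", "False", "false"):
    os.environ["DEBUG"] = value
    print(value, os.environ.get("DEBUG", "false").lower() != "false")
# true/1/yes -> True; False/false -> False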

src/requirements.txt Normal file

@@ -0,0 +1,4 @@
fastapi==0.103.0
pydantic==2.6.3
uvicorn==0.27.0.post1
mangum==0.17.0