diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..364a711 --- /dev/null +++ b/.flake8 @@ -0,0 +1,19 @@ +[flake8] +max-line-length = 120 +ignore = + E203,W191,W503 +exclude = + build + .git + __pycache__ + .tox + venv + .venv + .venv-test + tmp* + deployment + cdk.out + node_modules + +max-complexity = 10 +require-code = True \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d8b355e --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+Config
\ No newline at end of file
diff --git a/README.md b/README.md
index 7f92204..22163db 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,245 @@
-## My Project
+[中文](./README_CN.md)
-TODO: Fill this README out!
+# Bedrock Access Gateway
-Be sure to:
+OpenAI-Compatible RESTful APIs for Amazon Bedrock
-* Change the title in this README
-* Edit your repository description on GitHub
+## Overview
+
+Amazon Bedrock offers a wide range of foundation models (such as Claude 3 Sonnet/Haiku, Llama 2, Mistral/Mixtral, etc.)
+and a broad set of capabilities for you to build generative AI applications.
+Check [Amazon Bedrock](https://aws.amazon.com/bedrock) for more details.
+
+Sometimes, you might have applications developed using OpenAI APIs or SDKs, and you want to experiment with Amazon
+Bedrock without modifying your codebase. Or you may simply wish to evaluate the capabilities of these foundation models
+in tools like AutoGen. Well, this repository allows you to access Amazon Bedrock models seamlessly through OpenAI
+APIs and SDKs, enabling you to test these models without code changes.
+
+If you find this GitHub repository useful, please consider giving it a free star to show your appreciation and support
+for the project.
+
+Features:
+
+- [x] Support streaming response via server-sent events (SSE)
+- [x] Support Model APIs
+- [x] Support Chat Completion APIs
+- [ ] Support Function Call/Tool Call
+- [ ] Support Embedding APIs
+- [ ] Support Image APIs
+
+> NOTE: 1. The legacy [text completion](https://platform.openai.com/docs/api-reference/completions) API is not
+> supported; you should move to the Chat Completion API. 2. Other APIs, such as fine-tuning and the Assistants API,
+> may be supported in the future.
+
+Supported Amazon Bedrock models (Model IDs):
+
+- anthropic.claude-instant-v1
+- anthropic.claude-v2:1
+- anthropic.claude-v2
+- anthropic.claude-3-sonnet-20240229-v1:0
+- anthropic.claude-3-haiku-20240307-v1:0
+- meta.llama2-13b-chat-v1
+- meta.llama2-70b-chat-v1
+- mistral.mistral-7b-instruct-v0:2
+- mistral.mixtral-8x7b-instruct-v0:1
+
+> Note: The default model is set to `anthropic.claude-3-sonnet-20240229-v1:0`. You can change it via Lambda environment
+> variables.
+
+## Get Started
+
+### Prerequisites
+
+Please make sure you have met the prerequisites below:
+
+- Access to Amazon Bedrock foundation models.
+
+If you don't have model access yet, please follow
+the [Set Up](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html) guide.
+
+### Architecture
+
+The following diagram illustrates the solution architecture. Note that it also includes a new **VPC** with two public
+subnets only for the Application Load Balancer (ALB).
+
+![Architecture](assets/arch.svg)
+
+### Deployment
+
+Please follow the steps below to deploy the Bedrock Proxy APIs into your AWS account. Only regions where Amazon
+Bedrock is available (such as us-west-2) are supported. The deployment will take approximately 3-5 minutes.
+
+**Step 1: Create your own custom API key (Optional)**
+
+> NOTE: In this step, you use any string (without spaces) you like to create a custom API Key (credential) that will be
+> used to access the proxy API later. This key does not have to match your actual OpenAI key, and you don't even need to
+> have an OpenAI API key. It is recommended that you take this step and ensure that you keep the key safe and private.
+
+1. Open the AWS Management Console and navigate to the Systems Manager service.
+2. In the left-hand navigation pane, click on "Parameter Store".
+3. Click on the "Create parameter" button.
+4. In the "Create parameter" window, select the following options:
+   - Name: Enter a descriptive name for your parameter (e.g., "BedrockProxyAPIKey").
+   - Description: Optionally, provide a description for the parameter.
+   - Tier: Select **Standard**.
+   - Type: Select **SecureString**.
+   - Value: Any string (without spaces).
+5. Click "Create parameter".
+6. Make a note of the parameter name you used (e.g., "BedrockProxyAPIKey"). You'll need this in the next step.
+
+**Step 2: Deploy the CloudFormation stack**
+
+1. Sign in to the AWS Management Console and switch to the region in which you want to deploy the CloudFormation stack.
+2. Click the following button to launch the CloudFormation stack in that region.
+
+   [![Launch Stack](assets/launch-stack.png)](https://console.aws.amazon.com/cloudformation/home#/stacks/create/template?stackName=BedrockProxyAPI&templateURL=https://aws-gcr-solutions.s3.amazonaws.com/bedrock-proxy-api/latest/BedrockProxy.template)
+
+3. Click "Next".
+4. On the "Specify stack details" page, provide the following information:
+   - Stack name: Change the stack name if needed.
+   - ApiKeyParam (if you set up an API key in Step 1): Enter the parameter name you used for storing the API key
+     (e.g., "BedrockProxyAPIKey"). If you did not set up an API key, leave this field blank.
+   Click "Next".
+5. On the "Configure stack options" page, you can leave the default settings or customize them according to your needs.
+6. Click "Next".
+7. On the "Review" page, review the details of the stack you're about to create. Check the "I acknowledge that AWS
+   CloudFormation might create IAM resources" checkbox at the bottom.
+8. Click "Create stack".
+
+That's it! Once deployed, open the CloudFormation stack, go to the **Outputs** tab, and find the API Base URL in
+`APIBaseUrl`; the value should look like `http://xxxx.xxx.elb.amazonaws.com/api/v1`.
+
+### SDK/API Usage
+
+All you need are the API Key and the API Base URL. If you didn't
+set up your own key, then the default API Key `bedrock` will be used.
+
+Now, you can try out the proxy APIs. Let's say you want to test the Claude 3 Sonnet model, then
+use `anthropic.claude-3-sonnet-20240229-v1:0` as the Model ID.
+
+- **Example API Usage**
+
+```bash
+curl https://<API base url>/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer <API key>" \
+  -d '{
+    "model": "anthropic.claude-3-sonnet-20240229-v1:0",
+    "messages": [
+      {
+        "role": "user",
+        "content": "Hello!"
+      }
+    ]
+  }'
+```
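+
+The response follows the OpenAI chat completion response schema. An illustrative example (field values such as
+`id`, `created`, and the token counts below are placeholders, not real output):
+
+```json
+{
+  "id": "chatcmpl-0123abcd",
+  "created": 1711000000,
+  "model": "anthropic.claude-3-sonnet-20240229-v1:0",
+  "system_fingerprint": "fp_e97c09dd4e26",
+  "object": "chat.completion",
+  "choices": [
+    {
+      "index": 0,
+      "message": {"role": "assistant", "content": "Hello! How can I help you today?"},
+      "finish_reason": "stop",
+      "logprobs": null
+    }
+  ],
+  "usage": {"prompt_tokens": 10, "completion_tokens": 9, "total_tokens": 19}
+}
+```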
+
+- **Example SDK Usage**
+
+```bash
+export OPENAI_API_KEY=<API key>
+export OPENAI_API_BASE=<API base url>
+```
+
+```python
+from openai import OpenAI
+
+client = OpenAI()
+completion = client.chat.completions.create(
+    model="anthropic.claude-3-sonnet-20240229-v1:0",
+    messages=[{"role": "user", "content": "Hello!"}],
+)
+
+print(completion.choices[0].message.content)
+```
+
+## Other Examples
+
+### AutoGen
+
+Below is an image of setting up the model in AutoGen studio.
+
+![AutoGen Model](assets/autogen-model.png)
+
+### LangChain
+
+Make sure you use `ChatOpenAI(...)` instead of `OpenAI(...)`.
+
+```python
+# pip install langchain-openai
+from langchain.chains import LLMChain
+from langchain.prompts import PromptTemplate
+from langchain_openai import ChatOpenAI
+
+chat = ChatOpenAI(
+    model="anthropic.claude-3-sonnet-20240229-v1:0",
+    temperature=0,
+    openai_api_key="xxxx",
+    openai_api_base="http://xxx.elb.amazonaws.com/api/v1",
+)
+
+template = """Question: {question}
+
+Answer: Let's think step by step."""
+
+prompt = PromptTemplate.from_template(template)
+llm_chain = LLMChain(prompt=prompt, llm=chat)
+
+question = "What NFL team won the Super Bowl in the year Justin Bieber was born?"
+response = llm_chain.invoke(question)
+print(response)
+```
+
+## FAQs
+
+### About Privacy
+
+This application does not collect any of your data. Furthermore, it does not log any requests or responses by default.
+
+### Why not use API Gateway instead of an Application Load Balancer?
+
+The short answer is that API Gateway does not support server-sent events (SSE), which are required for streaming responses.
+
+### Which regions are supported?
+
+This solution only supports the regions where Amazon Bedrock is available, currently:
+
+- US East (N. Virginia)
+- US West (Oregon)
+- Asia Pacific (Singapore)
+- Asia Pacific (Tokyo)
+- Europe (Frankfurt)
+
+Note that not all models are available in those regions.
+
+### Can I build and use my own ECR image?
+
+Yes, you can clone the repo and build the container image yourself (`src/Dockerfile`) and then push it to your ECR repo.
+
+Replace the repo URL in the CloudFormation template before you deploy.
+
+### Can I run this locally?
+
+Yes, you can run this locally (see the sketch at the end of this section); the API base URL will then look
+like `http://localhost:8000/api/v1`.
+
+### Any performance sacrifice or latency increase by using the proxy APIs?
+
+This has not been benchmarked yet; you should use this solution for PoC purposes only.
+
+### Any plan to support SageMaker models?
+
+Currently, there is no plan to support SageMaker models. This may change if there is customer demand.
+
+### Any plan to support Bedrock custom models?
+
+Fine-tuned models and models with Provisioned Throughput are not supported. You can clone the repo and make the
+customizations if needed.
+
+### How to upgrade?
+
+If there are no changes to the architecture, you can simply deploy the latest image to your Lambda function to pick up
+new features manually, without redeploying the whole CloudFormation stack, for example:
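+
+A minimal sketch with the AWS CLI (the function name is a placeholder — copy the real one from the CloudFormation
+stack's resources; the image URI shown assumes a us-west-2 deployment and is taken from the region mapping in the
+template):
+
+```bash
+aws lambda update-function-code \
+  --function-name <your BedrockProxy Lambda function name> \
+  --image-uri 366590864501.dkr.ecr.us-west-2.amazonaws.com/bedrock-proxy-api:latest
+aws lambda wait function-updated --function-name <your BedrockProxy Lambda function name>
+```
+
+And to run the solution locally (as mentioned in the FAQ above), a minimal sketch, assuming Python 3.12, AWS
+credentials with Bedrock access in your environment, and `boto3` installed alongside `src/requirements.txt`:
+
+```bash
+cd src
+pip install -r requirements.txt boto3
+export AWS_REGION=us-west-2  # a region where you have Bedrock model access
+uvicorn api.app:app --host 0.0.0.0 --port 8000
+```
+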
## Security
diff --git a/README_CN.md b/README_CN.md
new file mode 100644
index 0000000..3d36451
--- /dev/null
+++ b/README_CN.md
@@ -0,0 +1,236 @@
+[English](./README.md)
+
+# Bedrock Access Gateway
+
+使用兼容OpenAI的API访问Amazon Bedrock
+
+## 概述
+
+Amazon Bedrock提供了广泛的基础模型(如Claude 3 Sonnet/Haiku、Llama 2、Mistral/Mixtral等),
+以及构建生成式AI应用程序的多种功能。更多详细信息,请查看[Amazon Bedrock](https://aws.amazon.com/bedrock)。
+
+有时,您可能已经使用OpenAI的API或SDK构建了应用程序,并希望在不修改代码的情况下试用Amazon
+Bedrock的模型。或者,您可能只是希望在AutoGen等工具中评估这些基础模型的功能。好消息是,这里提供了一种方便的途径,让您可以使用
+OpenAI 的 API 或 SDK 无缝集成并试用 Amazon Bedrock 的模型,而无需对现有代码进行修改。
+
+如果您觉得这个项目有用,请考虑给它点一个免费的小星星。
+
+功能列表:
+
+- [x] 支持 server-sent events (SSE) 的流式响应
+- [x] 支持 Model APIs
+- [x] 支持 Chat Completion APIs
+- [ ] 支持 Function Call/Tool Call
+- [ ] 支持 Embedding APIs
+- [ ] 支持 Image APIs
+
+> 注意:1. 不支持旧的 [text completion](https://platform.openai.com/docs/api-reference/completions) API,请更改为使用Chat
+> Completion API。2. 未来可能支持其他API,如Fine-tune、Assistants API等。
+
+支持的Amazon Bedrock模型列表(Model IDs):
+
+- anthropic.claude-instant-v1
+- anthropic.claude-v2:1
+- anthropic.claude-v2
+- anthropic.claude-3-sonnet-20240229-v1:0
+- anthropic.claude-3-haiku-20240307-v1:0
+- meta.llama2-13b-chat-v1
+- meta.llama2-70b-chat-v1
+- mistral.mistral-7b-instruct-v0:2
+- mistral.mixtral-8x7b-instruct-v0:1
+
+> 注意:默认模型为 `anthropic.claude-3-sonnet-20240229-v1:0`,可以通过Lambda环境变量进行更改。
+
+## 使用指南
+
+### 前提条件
+
+请确保您已满足以下先决条件:
+
+- 可以访问Amazon Bedrock基础模型。
+
+如果您还没有获得模型访问权限,请参考[配置](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html)指南。
+
+### 架构图
+
+下图展示了本方案的架构。请注意,它还包括一个新的**VPC**,其中只有两个公共子网,用于应用程序负载均衡器(ALB)。
+
+![Architecture](assets/arch.svg)
+
+### 部署
+
+请按以下步骤将Bedrock代理API部署到您的AWS账户中。仅支持Amazon Bedrock可用的区域(如us-west-2)。
+
+**第一步: 自定义您的API Key(可选)**
+
+> 注意:这一步是使用任意字符串(不带空格)创建一个自定义的API Key(凭证),将用于后续访问代理API。此API Key不必与您实际的OpenAI
+> Key一致,您甚至无需拥有OpenAI API Key。建议您执行此步操作,并确保妥善保管好此API Key。
+
+1. 打开AWS管理控制台,导航到Systems Manager服务。
+2. 在左侧导航窗格中,单击"参数存储"。
+3. 单击"创建参数"按钮。
+4. 在"创建参数"窗口中,选择以下选项:
+   - 名称:输入参数的描述性名称(例如"BedrockProxyAPIKey")。
+   - 描述:可选,为参数提供描述。
+   - 层级:选择**标准**。
+   - 类型:选择**SecureString**。
+   - 值:任意字符串(不带空格)。
+5. 单击"创建参数"。
+6. 记录您使用的参数名称(例如"BedrockProxyAPIKey")。您将在下一步中需要它。
+
+**第二步: 部署CloudFormation堆栈**
+
+1. 登录AWS管理控制台,切换到要部署CloudFormation堆栈的区域。
+2. 单击以下按钮在该区域启动CloudFormation堆栈。
+
+   [![Launch Stack](assets/launch-stack.png)](https://console.aws.amazon.com/cloudformation/home#/stacks/create/template?stackName=BedrockProxyAPI&templateURL=https://aws-gcr-solutions.s3.amazonaws.com/bedrock-proxy-api/latest/BedrockProxy.template)
+
+3. 单击"下一步"。
+4. 在"指定堆栈详细信息"页面,提供以下信息:
+   - 堆栈名称:可以根据需要更改名称。
+   - ApiKeyParam(如果在第一步中设置了API Key):输入您用于存储API Key的参数名称(例如"BedrockProxyAPIKey")。
+     如果您没有设置API Key,请将此字段留空。
+   单击"下一步"。
+5. 在"配置堆栈选项"页面,您可以保留默认设置或根据需要进行自定义。
+6. 单击"下一步"。
+7. 在"审核"页面,查看您即将创建的堆栈详细信息。勾选底部的"我确认,AWS CloudFormation 可能创建 IAM 资源。"复选框。
+8. 单击"创建堆栈"。
+
+仅此而已!部署完成后,点击CloudFormation堆栈,进入"输出"选项卡,你可以从`APIBaseUrl`
+中找到API Base URL,它应该类似于`http://xxxx.xxx.elb.amazonaws.com/api/v1`这样的格式。
+
+### SDK/API使用
+
+你只需要API Key和API Base URL。如果你没有设置自己的密钥,那么默认将使用API Key `bedrock`。
+
+现在,你可以尝试使用代理API了。假设你想测试Claude 3 Sonnet模型,那么使用`anthropic.claude-3-sonnet-20240229-v1:0`作为模型ID。
+
+- **API 使用示例**
+
+```bash
+curl https://<API base url>/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer <API key>" \
+  -d '{
+    "model": "anthropic.claude-3-sonnet-20240229-v1:0",
+    "messages": [
+      {
+        "role": "user",
+        "content": "Hello!"
+      }
+    ]
+  }'
+```
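+
+响应遵循OpenAI Chat Completion的响应格式。下面是一个示意性的示例(其中`id`、`created`和token数量等字段值均为占位符,并非真实输出):
+
+```json
+{
+  "id": "chatcmpl-0123abcd",
+  "created": 1711000000,
+  "model": "anthropic.claude-3-sonnet-20240229-v1:0",
+  "system_fingerprint": "fp_e97c09dd4e26",
+  "object": "chat.completion",
+  "choices": [
+    {
+      "index": 0,
+      "message": {"role": "assistant", "content": "Hello! How can I help you today?"},
+      "finish_reason": "stop",
+      "logprobs": null
+    }
+  ],
+  "usage": {"prompt_tokens": 10, "completion_tokens": 9, "total_tokens": 19}
+}
+```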
+
+- **SDK 使用示例**
+
+```bash
+export OPENAI_API_KEY=<API key>
+export OPENAI_API_BASE=<API base url>
+```
+
+```python
+from openai import OpenAI
+
+client = OpenAI()
+completion = client.chat.completions.create(
+    model="anthropic.claude-3-sonnet-20240229-v1:0",
+    messages=[{"role": "user", "content": "Hello!"}],
+)
+
+print(completion.choices[0].message.content)
+```
+
+## 其他例子
+
+### AutoGen
+
+下图是在AutoGen Studio中配置模型的示例。
+
+![AutoGen Model](assets/autogen-model.png)
+
+### LangChain
+
+请确保使用的是`ChatOpenAI(...)`,而不是`OpenAI(...)`。
+
+```python
+# pip install langchain-openai
+from langchain.chains import LLMChain
+from langchain.prompts import PromptTemplate
+from langchain_openai import ChatOpenAI
+
+chat = ChatOpenAI(
+    model="anthropic.claude-3-sonnet-20240229-v1:0",
+    temperature=0,
+    openai_api_key="xxxx",
+    openai_api_base="http://xxx.elb.amazonaws.com/api/v1",
+)
+
+template = """Question: {question}
+
+Answer: Let's think step by step."""
+
+prompt = PromptTemplate.from_template(template)
+llm_chain = LLMChain(prompt=prompt, llm=chat)
+
+question = "What NFL team won the Super Bowl in the year Justin Bieber was born?"
+response = llm_chain.invoke(question)
+print(response)
+```
+
+## FAQs
+
+### 关于隐私
+
+本方案不会收集您的任何数据。而且,它默认情况下也不会记录任何请求或响应。
+
+### 为什么没有使用API Gateway,而是使用了Application Load Balancer?
+
+简单来说,API Gateway不支持用于流式响应的server-sent events (SSE)。
+
+### 支持哪些区域?
+
+只支持Amazon Bedrock可用的区域,即:
+
+- 美国东部(弗吉尼亚北部)
+- 美国西部(俄勒冈州)
+- 亚太地区(新加坡)
+- 亚太地区(东京)
+- 欧洲(法兰克福)
+
+注意,并非所有模型都在上述区域可用。
+
+### 我可以构建并使用自己的ECR镜像吗?
+
+是的,你可以克隆repo并自行构建容器镜像(src/Dockerfile),然后推送到你自己的ECR仓库。
+
+在部署之前,请在CloudFormation模板中替换镜像仓库URL。
+
+### 我可以在本地运行吗?
+
+是的,你可以在本地运行,此时API Base URL类似于`http://localhost:8000/api/v1`。
+
+### 使用代理API会有性能牺牲或延迟增加吗?
+
+这还有待测试。但你应该只将此解决方案用于概念验证。
+
+### 有计划支持SageMaker模型吗?
+
+目前没有支持SageMaker模型的计划。这取决于是否有客户需求。
+
+### 有计划支持Bedrock自定义模型吗?
+
+不支持微调模型和设置了预置吞吐量(Provisioned Throughput)的模型。如有需要,你可以克隆repo并进行自定义。
+
+### 如何升级?
+
+如果架构没有变化,你可以手动将最新镜像部署到Lambda中以使用新功能,而无需重新部署整个CloudFormation堆栈。
+
+## 安全
+
+更多信息,请参阅[CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications)。
+
+## 许可证
+
+本项目根据MIT-0许可证获得许可。请参阅LICENSE文件。
diff --git a/assets/arch.svg b/assets/arch.svg
new file mode 100644
index 0000000..a872ef3
--- /dev/null
+++ b/assets/arch.svg
@@ -0,0 +1,4 @@
+[draw.io architecture diagram: Client / SDK -> Application Load Balancer -> AWS Lambda (Proxy) -> Amazon Bedrock, with an optional Parameter Store (API Key), all inside the AWS Cloud]
\ No newline at end of file diff --git a/assets/autogen-agent.png b/assets/autogen-agent.png new file mode 100644 index 0000000..823e1dc Binary files /dev/null and b/assets/autogen-agent.png differ diff --git a/assets/autogen-model.png b/assets/autogen-model.png new file mode 100644 index 0000000..bbc8f91 Binary files /dev/null and b/assets/autogen-model.png differ diff --git a/assets/launch-stack.png b/assets/launch-stack.png new file mode 100644 index 0000000..2745adf Binary files /dev/null and b/assets/launch-stack.png differ diff --git a/deployment/BedrockProxy.template b/deployment/BedrockProxy.template new file mode 100644 index 0000000..0ec4e23 --- /dev/null +++ b/deployment/BedrockProxy.template @@ -0,0 +1,805 @@ +{ + "Parameters": { + "ApiKeyParam": { + "Type": "String", + "Default": "", + "Description": "The parameter name in System Manager used to store the API Key, leave blank to use a default key" + } + }, + "Resources": { + "VPCB9E5F0B4": { + "Type": "AWS::EC2::VPC", + "Properties": { + "CidrBlock": "10.250.0.0/16", + "EnableDnsHostnames": true, + "EnableDnsSupport": true, + "InstanceTenancy": "default", + "Tags": [ + { + "Key": "Name", + "Value": "BedrockProxy/VPC" + } + ] + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/Resource" + } + }, + "VPCPublicSubnet1SubnetB4246D30": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "AvailabilityZone": { + "Fn::Select": [ + 0, + { + "Fn::GetAZs": "" + } + ] + }, + "CidrBlock": "10.250.0.0/24", + "MapPublicIpOnLaunch": true, + "Tags": [ + { + "Key": "aws-cdk:subnet-name", + "Value": "Public" + }, + { + "Key": "aws-cdk:subnet-type", + "Value": "Public" + }, + { + "Key": "Name", + "Value": "BedrockProxy/VPC/PublicSubnet1" + } + ], + "VpcId": { + "Ref": "VPCB9E5F0B4" + } + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/PublicSubnet1/Subnet" + } + }, + "VPCPublicSubnet1RouteTableFEE4B781": { + "Type": "AWS::EC2::RouteTable", + "Properties": { + "Tags": [ + { + "Key": "Name", + "Value": "BedrockProxy/VPC/PublicSubnet1" + } + ], + "VpcId": { + "Ref": "VPCB9E5F0B4" + } + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/PublicSubnet1/RouteTable" + } + }, + "VPCPublicSubnet1RouteTableAssociation0B0896DC": { + "Type": "AWS::EC2::SubnetRouteTableAssociation", + "Properties": { + "RouteTableId": { + "Ref": "VPCPublicSubnet1RouteTableFEE4B781" + }, + "SubnetId": { + "Ref": "VPCPublicSubnet1SubnetB4246D30" + } + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/PublicSubnet1/RouteTableAssociation" + } + }, + "VPCPublicSubnet1DefaultRoute91CEF279": { + "Type": "AWS::EC2::Route", + "Properties": { + "DestinationCidrBlock": "0.0.0.0/0", + "GatewayId": { + "Ref": "VPCIGWB7E252D3" + }, + "RouteTableId": { + "Ref": "VPCPublicSubnet1RouteTableFEE4B781" + } + }, + "DependsOn": [ + "VPCVPCGW99B986DC" + ], + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/PublicSubnet1/DefaultRoute" + } + }, + "VPCPublicSubnet2Subnet74179F39": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "AvailabilityZone": { + "Fn::Select": [ + 1, + { + "Fn::GetAZs": "" + } + ] + }, + "CidrBlock": "10.250.1.0/24", + "MapPublicIpOnLaunch": true, + "Tags": [ + { + "Key": "aws-cdk:subnet-name", + "Value": "Public" + }, + { + "Key": "aws-cdk:subnet-type", + "Value": "Public" + }, + { + "Key": "Name", + "Value": "BedrockProxy/VPC/PublicSubnet2" + } + ], + "VpcId": { + "Ref": "VPCB9E5F0B4" + } + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/PublicSubnet2/Subnet" + } + }, + "VPCPublicSubnet2RouteTable6F1A15F1": { + "Type": "AWS::EC2::RouteTable", + 
"Properties": { + "Tags": [ + { + "Key": "Name", + "Value": "BedrockProxy/VPC/PublicSubnet2" + } + ], + "VpcId": { + "Ref": "VPCB9E5F0B4" + } + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/PublicSubnet2/RouteTable" + } + }, + "VPCPublicSubnet2RouteTableAssociation5A808732": { + "Type": "AWS::EC2::SubnetRouteTableAssociation", + "Properties": { + "RouteTableId": { + "Ref": "VPCPublicSubnet2RouteTable6F1A15F1" + }, + "SubnetId": { + "Ref": "VPCPublicSubnet2Subnet74179F39" + } + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/PublicSubnet2/RouteTableAssociation" + } + }, + "VPCPublicSubnet2DefaultRouteB7481BBA": { + "Type": "AWS::EC2::Route", + "Properties": { + "DestinationCidrBlock": "0.0.0.0/0", + "GatewayId": { + "Ref": "VPCIGWB7E252D3" + }, + "RouteTableId": { + "Ref": "VPCPublicSubnet2RouteTable6F1A15F1" + } + }, + "DependsOn": [ + "VPCVPCGW99B986DC" + ], + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/PublicSubnet2/DefaultRoute" + } + }, + "VPCIGWB7E252D3": { + "Type": "AWS::EC2::InternetGateway", + "Properties": { + "Tags": [ + { + "Key": "Name", + "Value": "BedrockProxy/VPC" + } + ] + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/IGW" + } + }, + "VPCVPCGW99B986DC": { + "Type": "AWS::EC2::VPCGatewayAttachment", + "Properties": { + "InternetGatewayId": { + "Ref": "VPCIGWB7E252D3" + }, + "VpcId": { + "Ref": "VPCB9E5F0B4" + } + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/VPC/VPCGW" + } + }, + "ProxyApiHandlerServiceRoleBE71BFB1": { + "Type": "AWS::IAM::Role", + "Properties": { + "AssumeRolePolicyDocument": { + "Statement": [ + { + "Action": "sts:AssumeRole", + "Effect": "Allow", + "Principal": { + "Service": "lambda.amazonaws.com" + } + } + ], + "Version": "2012-10-17" + }, + "ManagedPolicyArns": [ + { + "Fn::Join": [ + "", + [ + "arn:", + { + "Ref": "AWS::Partition" + }, + ":iam::aws:policy/service-role/AWSLambdaBasicExecutionRole" + ] + ] + } + ] + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/Proxy/ApiHandler/ServiceRole/Resource" + } + }, + "ProxyApiHandlerServiceRoleDefaultPolicy86681202": { + "Type": "AWS::IAM::Policy", + "Properties": { + "PolicyDocument": { + "Statement": [ + { + "Action": [ + "bedrock:InvokeModel", + "bedrock:InvokeModelWithResponseStream" + ], + "Effect": "Allow", + "Resource": "arn:aws:bedrock:*::foundation-model/*" + }, + { + "Action": [ + "ssm:DescribeParameters", + "ssm:GetParameters", + "ssm:GetParameter", + "ssm:GetParameterHistory" + ], + "Effect": "Allow", + "Resource": { + "Fn::Join": [ + "", + [ + "arn:", + { + "Ref": "AWS::Partition" + }, + ":ssm:", + { + "Ref": "AWS::Region" + }, + ":", + { + "Ref": "AWS::AccountId" + }, + ":parameter/", + { + "Ref": "ApiKeyParam" + } + ] + ] + } + } + ], + "Version": "2012-10-17" + }, + "PolicyName": "ProxyApiHandlerServiceRoleDefaultPolicy86681202", + "Roles": [ + { + "Ref": "ProxyApiHandlerServiceRoleBE71BFB1" + } + ] + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/Proxy/ApiHandler/ServiceRole/DefaultPolicy/Resource" + } + }, + "ProxyApiHandlerEC15A492": { + "Type": "AWS::Lambda::Function", + "Properties": { + "Architectures": [ + "arm64" + ], + "Code": { + "ImageUri": { + "Fn::Join": [ + "", + [ + { + "Fn::Select": [ + 4, + { + "Fn::Split": [ + ":", + { + "Fn::FindInMap": [ + "ProxyRegionTable03E5BEB3", + { + "Ref": "AWS::Region" + }, + "repoArn" + ] + } + ] + } + ] + }, + ".dkr.ecr.", + { + "Fn::Select": [ + 3, + { + "Fn::Split": [ + ":", + { + "Fn::FindInMap": [ + "ProxyRegionTable03E5BEB3", + { + "Ref": "AWS::Region" + }, + "repoArn" + ] + } + ] + } + ] + }, + 
".", + { + "Ref": "AWS::URLSuffix" + }, + "/bedrock-proxy-api:latest" + ] + ] + } + }, + "Description": "Bedrock Proxy API Handler", + "Environment": { + "Variables": { + "API_KEY_PARAM_NAME": { + "Ref": "ApiKeyParam" + }, + "DEBUG": "false", + "DEFAULT_MODEL": { + "Fn::FindInMap": [ + "ProxyRegionTable03E5BEB3", + { + "Ref": "AWS::Region" + }, + "model" + ] + } + } + }, + "MemorySize": 1024, + "PackageType": "Image", + "Role": { + "Fn::GetAtt": [ + "ProxyApiHandlerServiceRoleBE71BFB1", + "Arn" + ] + }, + "Timeout": 300 + }, + "DependsOn": [ + "ProxyApiHandlerServiceRoleDefaultPolicy86681202", + "ProxyApiHandlerServiceRoleBE71BFB1" + ], + "Metadata": { + "aws:cdk:path": "BedrockProxy/Proxy/ApiHandler/Resource" + } + }, + "ProxyApiHandlerInvoke2UTWxhlfyqbT5FTn5jvgbLgjFfJwzswGk55DU1HYF6C33779": { + "Type": "AWS::Lambda::Permission", + "Properties": { + "Action": "lambda:InvokeFunction", + "FunctionName": { + "Fn::GetAtt": [ + "ProxyApiHandlerEC15A492", + "Arn" + ] + }, + "Principal": "elasticloadbalancing.amazonaws.com" + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/Proxy/ApiHandler/Invoke2UTWxhlfyqbT5FTn--5jvgbLgj+FfJwzswGk55DU1H--Y=" + } + }, + "ProxyALB87756780": { + "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", + "Properties": { + "LoadBalancerAttributes": [ + { + "Key": "deletion_protection.enabled", + "Value": "false" + } + ], + "Scheme": "internet-facing", + "SecurityGroups": [ + { + "Fn::GetAtt": [ + "ProxyALBSecurityGroup0D6CA3DA", + "GroupId" + ] + } + ], + "Subnets": [ + { + "Ref": "VPCPublicSubnet1SubnetB4246D30" + }, + { + "Ref": "VPCPublicSubnet2Subnet74179F39" + } + ], + "Type": "application" + }, + "DependsOn": [ + "VPCPublicSubnet1DefaultRoute91CEF279", + "VPCPublicSubnet1RouteTableAssociation0B0896DC", + "VPCPublicSubnet2DefaultRouteB7481BBA", + "VPCPublicSubnet2RouteTableAssociation5A808732" + ], + "Metadata": { + "aws:cdk:path": "BedrockProxy/Proxy/ALB/Resource" + } + }, + "ProxyALBSecurityGroup0D6CA3DA": { + "Type": "AWS::EC2::SecurityGroup", + "Properties": { + "GroupDescription": "Automatically created Security Group for ELB BedrockProxyALB1CE4CAD1", + "SecurityGroupEgress": [ + { + "CidrIp": "255.255.255.255/32", + "Description": "Disallow all traffic", + "FromPort": 252, + "IpProtocol": "icmp", + "ToPort": 86 + } + ], + "SecurityGroupIngress": [ + { + "CidrIp": "0.0.0.0/0", + "Description": "Allow from anyone on port 80", + "FromPort": 80, + "IpProtocol": "tcp", + "ToPort": 80 + } + ], + "VpcId": { + "Ref": "VPCB9E5F0B4" + } + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/Proxy/ALB/SecurityGroup/Resource" + } + }, + "ProxyALBListener933E9515": { + "Type": "AWS::ElasticLoadBalancingV2::Listener", + "Properties": { + "DefaultActions": [ + { + "TargetGroupArn": { + "Ref": "ProxyALBListenerTargetsGroup187739FA" + }, + "Type": "forward" + } + ], + "LoadBalancerArn": { + "Ref": "ProxyALB87756780" + }, + "Port": 80, + "Protocol": "HTTP" + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/Proxy/ALB/Listener/Resource" + } + }, + "ProxyALBListenerTargetsGroup187739FA": { + "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", + "Properties": { + "HealthCheckEnabled": false, + "TargetType": "lambda", + "Targets": [ + { + "Id": { + "Fn::GetAtt": [ + "ProxyApiHandlerEC15A492", + "Arn" + ] + } + } + ] + }, + "DependsOn": [ + "ProxyApiHandlerInvoke2UTWxhlfyqbT5FTn5jvgbLgjFfJwzswGk55DU1HYF6C33779" + ], + "Metadata": { + "aws:cdk:path": "BedrockProxy/Proxy/ALB/Listener/TargetsGroup/Resource" + } + }, + "CDKMetadata": { + "Type": "AWS::CDK::Metadata", + "Properties": 
{ + "Analytics": "v2:deflate64:H4sIAAAAAAAA/1VRXW/CMAz8LbyHDMovAKZNSJtWFcTr5LpeZ0iTKHFAqOp/n1q+uief7y7ynZLp+WKhZxM4xylWx6nhUrdbATyq9Y/NIUBDQkHBOX63hJlu9x57aZ+vVZ5Kw7hNpSXpuScqXBLaQWnoyT+5ZYwOGYSdfZh7sLFCwZK8g9AZLrczt20pAvjbkBW1JUyB5fIeXPLDgTHRKcKgC/IusrhwWUEkZaApK9Dtq8MjhU0DNb0li/cIY5xTaDhGdrZTDI1uC3etMczcGcYh2hV1igxEYTQOqhIMWGRbnzLdLr03jEPLDwfVatAo9E//7WMfRyF789zxSN9BqEketUdr16mCoksBh6if4D3buodfSXy6fsrIsHa2Yhk6WleRPsSXUzbT87meTQ6ReRqSFW5IF9f5B/Z2H8goAgAA" + }, + "Metadata": { + "aws:cdk:path": "BedrockProxy/CDKMetadata/Default" + }, + "Condition": "CDKMetadataAvailable" + } + }, + "Mappings": { + "ProxyRegionTable03E5BEB3": { + "us-east-1": { + "repoArn": "arn:aws:ecr:us-east-1:366590864501:repository/bedrock-proxy-api", + "model": "anthropic.claude-3-sonnet-20240229-v1:0" + }, + "us-west-2": { + "repoArn": "arn:aws:ecr:us-west-2:366590864501:repository/bedrock-proxy-api", + "model": "anthropic.claude-3-sonnet-20240229-v1:0" + }, + "ap-southeast-1": { + "repoArn": "arn:aws:ecr:ap-southeast-1:366590864501:repository/bedrock-proxy-api", + "model": "anthropic.claude-v2" + }, + "ap-northeast-1": { + "repoArn": "arn:aws:ecr:ap-northeast-1:366590864501:repository/bedrock-proxy-api", + "model": "anthropic.claude-v2:1" + }, + "eu-central-1": { + "repoArn": "arn:aws:ecr:eu-central-1:366590864501:repository/bedrock-proxy-api", + "model": "anthropic.claude-v2:1" + } + } + }, + "Outputs": { + "APIBaseUrl": { + "Description": "Proxy API Base URL (OPENAI_API_BASE)", + "Value": { + "Fn::Join": [ + "", + [ + "http://", + { + "Fn::GetAtt": [ + "ProxyALB87756780", + "DNSName" + ] + }, + "/api/v1" + ] + ] + } + } + }, + "Conditions": { + "CDKMetadataAvailable": { + "Fn::Or": [ + { + "Fn::Or": [ + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "af-south-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "ap-east-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "ap-northeast-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "ap-northeast-2" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "ap-south-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "ap-southeast-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "ap-southeast-2" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "ca-central-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "cn-north-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "cn-northwest-1" + ] + } + ] + }, + { + "Fn::Or": [ + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "eu-central-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "eu-north-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "eu-south-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "eu-west-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "eu-west-2" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "eu-west-3" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "il-central-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "me-central-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "me-south-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "sa-east-1" + ] + } + ] + }, + { + "Fn::Or": [ + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "us-east-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "us-east-2" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + 
"us-west-1" + ] + }, + { + "Fn::Equals": [ + { + "Ref": "AWS::Region" + }, + "us-west-2" + ] + } + ] + } + ] + } + } +} \ No newline at end of file diff --git a/src/Dockerfile b/src/Dockerfile new file mode 100644 index 0000000..920a01e --- /dev/null +++ b/src/Dockerfile @@ -0,0 +1,9 @@ +FROM public.ecr.aws/lambda/python:3.12 + +COPY ./api ./api + +COPY requirements.txt . + +RUN pip3 install -r requirements.txt -U --no-cache-dir + +CMD [ "api.app.handler" ] \ No newline at end of file diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/api/app.py b/src/api/app.py new file mode 100644 index 0000000..cafdf70 --- /dev/null +++ b/src/api/app.py @@ -0,0 +1,52 @@ +import logging + +import uvicorn +from fastapi import FastAPI +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import PlainTextResponse +from mangum import Mangum + +from api.routers import model, chat +from api.setting import API_ROUTE_PREFIX, TITLE, DESCRIPTION, SUMMARY, VERSION + +config = { + "title": TITLE, + "description": DESCRIPTION, + "summary": SUMMARY, + "version": VERSION, +} + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", +) +app = FastAPI(**config) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.include_router(model.router, prefix=API_ROUTE_PREFIX) +app.include_router(chat.router, prefix=API_ROUTE_PREFIX) + + +@app.get("/health") +async def health(): + """For health check if needed""" + return {"status": "OK"} + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return PlainTextResponse(str(exc), status_code=400) + + +handler = Mangum(app) + +if __name__ == "__main__": + uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) diff --git a/src/api/auth.py b/src/api/auth.py new file mode 100644 index 0000000..1593375 --- /dev/null +++ b/src/api/auth.py @@ -0,0 +1,28 @@ +import os +from typing import Annotated + +import boto3 +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials + +from api.setting import DEFAULT_API_KEYS + +api_key_param = os.environ.get("API_KEY_PARAM_NAME") +if api_key_param: + ssm = boto3.client("ssm") + api_key = ssm.get_parameter(Name=api_key_param, WithDecryption=True)["Parameter"][ + "Value" + ] +else: + api_key = DEFAULT_API_KEYS + +security = HTTPBearer() + + +def api_key_auth( + credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)] +): + if credentials.credentials != api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key" + ) diff --git a/src/api/models/__init__.py b/src/api/models/__init__.py new file mode 100644 index 0000000..6cb3b00 --- /dev/null +++ b/src/api/models/__init__.py @@ -0,0 +1 @@ +from api.models.bedrock import ClaudeModel, SUPPORTED_BEDROCK_MODELS, get_model diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py new file mode 100644 index 0000000..6bdd132 --- /dev/null +++ b/src/api/models/bedrock.py @@ -0,0 +1,391 @@ +import json +import logging +import uuid +from abc import ABC, abstractmethod +from typing import AsyncIterable + +import boto3 + +from api.schema import ( + ChatResponse, + ChatRequest, + ChatRequestMessage, + Choice, + ChatResponseMessage, + Usage, + ChatStreamResponse, + 
ChoiceDelta,
+)
+from api.setting import DEBUG, AWS_REGION
+
+logger = logging.getLogger(__name__)
+
+bedrock_runtime = boto3.client(
+    service_name="bedrock-runtime",
+    region_name=AWS_REGION,
+)
+
+SUPPORTED_BEDROCK_MODELS = {
+    "anthropic.claude-instant-v1": "Claude Instant",
+    "anthropic.claude-v2:1": "Claude",
+    "anthropic.claude-v2": "Claude",
+    "anthropic.claude-3-sonnet-20240229-v1:0": "Claude 3 Sonnet",
+    "anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku",
+    "meta.llama2-13b-chat-v1": "Llama 2 Chat 13B",
+    "meta.llama2-70b-chat-v1": "Llama 2 Chat 70B",
+    "mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
+    "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
+}
+
+
+class BaseChatModel(ABC):
+    """Represents a basic chat model.
+
+    Currently, only Bedrock models are supported, but this may be extended to SageMaker models if needed.
+    """
+
+    @abstractmethod
+    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+        """Handle a basic chat completion request."""
+        pass
+
+    @abstractmethod
+    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        """Handle a basic chat completion request with a streamed response."""
+        pass
+
+    def _generate_message_id(self) -> str:
+        return "chatcmpl-" + str(uuid.uuid4())[:8]
+
+    def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
+        # Frame one chunk as a single server-sent event record.
+        return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
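+
+# Wire format note: each streamed chunk above is one SSE record. An illustrative
+# (not verbatim) frame as sent to the client:
+#
+#   data: {"id": "chatcmpl-0123abcd", "created": 1711000000,
+#          "model": "anthropic.claude-3-sonnet-20240229-v1:0",
+#          "system_fingerprint": "fp_e97c09dd4e26",
+#          "choices": [{"index": 0, "finish_reason": null, "logprobs": null,
+#                       "delta": {"role": "assistant", "content": "Hi"}}],
+#          "object": "chat.completion.chunk"}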
+
+
+# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
+class BedrockModel(BaseChatModel):
+    accept = "application/json"
+    content_type = "application/json"
+
+    def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
+        body = json.dumps(args)
+        if DEBUG:
+            logger.info("Invoke Bedrock Model: " + model_id)
+            logger.info("Bedrock request body: " + body)
+        if with_stream:
+            return bedrock_runtime.invoke_model_with_response_stream(
+                body=body,
+                modelId=model_id,
+                accept=self.accept,
+                contentType=self.content_type,
+            )
+        return bedrock_runtime.invoke_model(
+            body=body,
+            modelId=model_id,
+            accept=self.accept,
+            contentType=self.content_type,
+        )
+
+    def _create_response(
+        self,
+        model: str,
+        message: str,
+        message_id: str,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
+    ) -> ChatResponse:
+        choice = Choice(
+            index=0,
+            message=ChatResponseMessage(
+                role="assistant",
+                content=message,
+            ),
+            finish_reason="stop",
+        )
+        response = ChatResponse(
+            id=message_id,
+            model=model,
+            choices=[choice],
+            usage=Usage(
+                prompt_tokens=input_tokens,
+                completion_tokens=output_tokens,
+                total_tokens=input_tokens + output_tokens,
+            ),
+        )
+        if DEBUG:
+            logger.info("Proxy response: " + response.model_dump_json())
+        return response
+
+    def _create_response_stream(
+        self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
+    ) -> ChatStreamResponse:
+        choice = ChoiceDelta(
+            index=0,
+            delta=ChatResponseMessage(
+                role="assistant",
+                content=chunk_message,
+            ),
+            finish_reason=finish_reason,
+        )
+        response = ChatStreamResponse(
+            id=message_id,
+            model=model,
+            choices=[choice],
+        )
+        if DEBUG:
+            logger.info("Proxy response: " + response.model_dump_json())
+        return response
+
+
+def get_model(model_id: str) -> BedrockModel:
+    model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
+    if DEBUG:
+        logger.info("model name is " + model_name)
+    if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku"]:
+        return ClaudeModel()
+    elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
+        return Llama2Model()
+    elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct"]:
+        return MistralModel()
+    else:
+        logger.error("Unsupported model id " + model_id)
+        raise ValueError("Invalid model ID")
+
+
+class ClaudeModel(BedrockModel):
+    anthropic_version = "bedrock-2023-05-31"
+
+    def _parse_args(self, chat_request: ChatRequest) -> dict:
+        args = {
+            "anthropic_version": self.anthropic_version,
+            "max_tokens": chat_request.max_tokens,
+            "top_p": chat_request.top_p,
+            "temperature": chat_request.temperature,
+        }
+        if chat_request.messages[0].role == "system":
+            # The Anthropic Messages API takes the system prompt as a top-level field.
+            args["system"] = chat_request.messages[0].content
+            args["messages"] = [
+                {"role": msg.role, "content": msg.content}
+                for msg in chat_request.messages[1:]
+            ]
+        else:
+            args["messages"] = [
+                {"role": msg.role, "content": msg.content}
+                for msg in chat_request.messages
+            ]
+
+        return args
+
+    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+        response = self._invoke_model(
+            args=self._parse_args(chat_request), model_id=chat_request.model
+        )
+        response_body = json.loads(response.get("body").read())
+        if DEBUG:
+            logger.info("Bedrock response body: " + str(response_body))
+
+        return self._create_response(
+            model=chat_request.model,
+            message=response_body["content"][0]["text"],
+            message_id=response_body["id"],
+            input_tokens=response_body["usage"]["input_tokens"],
+            output_tokens=response_body["usage"]["output_tokens"],
+        )
+
+    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        response = self._invoke_model(
+            args=self._parse_args(chat_request),
+            model_id=chat_request.model,
+            with_stream=True,
+        )
+        msg_id = ""
+        chunk_id = 0
+        for event in response.get("body"):
+            if DEBUG:
+                logger.info("Bedrock response chunk: " + str(event))
+            chunk = json.loads(event["chunk"]["bytes"])
+            chunk_id += 1
+            if chunk["type"] == "message_start":
+                msg_id = chunk["message"]["id"]
+                continue
+
+            if chunk["type"] == "message_delta":
+                chunk_message = ""
+                finish_reason = "stop"
+
+            elif chunk["type"] == "content_block_delta":
+                chunk_message = chunk["delta"]["text"]
+                finish_reason = None
+            else:
+                continue
+            response = self._create_response_stream(
+                model=chat_request.model,
+                message_id=msg_id,
+                chunk_message=chunk_message,
+                finish_reason=finish_reason,
+            )
+
+            yield self._stream_response_to_bytes(response)
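+
+# An illustrative Bedrock request body produced by ClaudeModel._parse_args for a chat
+# with a system prompt (the numeric values are the schema defaults, not real traffic):
+#
+#   {"anthropic_version": "bedrock-2023-05-31", "max_tokens": 2048, "top_p": 1.0,
+#    "temperature": 1.0, "system": "You are a helpful assistant.",
+#    "messages": [{"role": "user", "content": "Hello!"}]}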
+
+
+class Llama2Model(BedrockModel):
+
+    def _convert_prompt(self, messages: list[ChatRequestMessage]) -> str:
+        """Create a prompt message following the example below:
+
+        <s>[INST] <<SYS>>\n{your_system_message}\n<</SYS>>\n\n{user_message_1} [/INST] {model_reply_1}</s>
+        <s>[INST] {user_message_2} [/INST]
+        """
+        if DEBUG:
+            logger.info("Convert the messages below to a prompt for Llama 2:")
+            for msg in messages:
+                logger.info(msg.model_dump_json())
+        bos_token = "<s>"
+        eos_token = "</s>"
+        prompt = bos_token + "[INST] "
+        start = 0
+        end_turn = False
+        if messages[0].role == "system":
+            prompt += "<<SYS>>\n" + messages[0].content + "\n<</SYS>>\n\n"
+            start = 1
+        # TODO: Add validation
+        for i in range(start, len(messages)):
+            msg = messages[i]
+            if msg.role == "user":
+                if end_turn:
+                    prompt += bos_token + "[INST] "
+                prompt += msg.content + " [/INST] "
+                end_turn = False
+            else:
+                prompt += msg.content + eos_token
+                end_turn = True
+        if DEBUG:
+            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
+        return prompt
+
+    def _parse_args(self, chat_request: ChatRequest) -> dict:
+        prompt = self._convert_prompt(chat_request.messages)
+        return {
+            "prompt": prompt,
+            "max_gen_len": chat_request.max_tokens,
+            "temperature": chat_request.temperature,
+            "top_p": chat_request.top_p,
+        }
+
+    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+        response = self._invoke_model(
+            args=self._parse_args(chat_request), model_id=chat_request.model
+        )
+        response_body = json.loads(response.get("body").read())
+        if DEBUG:
+            logger.info("Bedrock response body: " + str(response_body))
+        message_id = self._generate_message_id()
+
+        return self._create_response(
+            model=chat_request.model,
+            message=response_body["generation"],
+            message_id=message_id,
+            input_tokens=response_body["prompt_token_count"],
+            output_tokens=response_body["generation_token_count"],
+        )
+
+    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        response = self._invoke_model(
+            args=self._parse_args(chat_request),
+            model_id=chat_request.model,
+            with_stream=True,
+        )
+        # Llama 2 chunks carry no message id, so generate one for the whole stream.
+        msg_id = self._generate_message_id()
+        chunk_id = 0
+        for event in response.get("body"):
+            if DEBUG:
+                logger.info("Bedrock response chunk: " + str(event))
+            chunk = json.loads(event["chunk"]["bytes"])
+            chunk_id += 1
+            response = self._create_response_stream(
+                model=chat_request.model,
+                message_id=msg_id,
+                chunk_message=chunk["generation"],
+                finish_reason=chunk["stop_reason"],
+            )
+            yield self._stream_response_to_bytes(response)
+
+
+class MistralModel(BedrockModel):
+    def _convert_prompt(self, messages: list[ChatRequestMessage]) -> str:
+        """Create a prompt message following the example below:
+
+        <s>[INST] {your_system_message}\n{user_message_1} [/INST] {model_reply_1}</s>
+        <s>[INST] {user_message_2} [/INST]
+        """
+        if DEBUG:
+            logger.info("Convert the messages below to a prompt for Mistral:")
+            for msg in messages:
+                logger.info(msg.model_dump_json())
+        bos_token = "<s>"
+        eos_token = "</s>"
+        prompt = bos_token + "[INST] "
+        start = 0
+        end_turn = False
+        if messages[0].role == "system":
+            prompt += messages[0].content + "\n"
+            start = 1
+        # TODO: Add validation
+        for i in range(start, len(messages)):
+            msg = messages[i]
+            if msg.role == "user":
+                if end_turn:
+                    prompt += bos_token + "[INST] "
+                prompt += msg.content + " [/INST] "
+                end_turn = False
+            else:
+                prompt += msg.content + eos_token
+                end_turn = True
+        if DEBUG:
+            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
+        return prompt
+
+    def _parse_args(self, chat_request: ChatRequest) -> dict:
+        prompt = self._convert_prompt(chat_request.messages)
+        return {
+            "prompt": prompt,
+            "max_tokens": chat_request.max_tokens,
+            "temperature": chat_request.temperature,
+            "top_p": chat_request.top_p,
+        }
+
+    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+        response = self._invoke_model(
+            args=self._parse_args(chat_request), model_id=chat_request.model
+        )
+        response_body = json.loads(response.get("body").read())
+        if DEBUG:
+            logger.info("Bedrock response body: " + str(response_body))
+        message_id = self._generate_message_id()
+
+        return self._create_response(
+            model=chat_request.model,
+            message=response_body["outputs"][0]["text"],
+            message_id=message_id,
+        )
+
+    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        response = self._invoke_model(
+            args=self._parse_args(chat_request),
+            model_id=chat_request.model,
+            with_stream=True,
+        )
+        # Mistral chunks carry no message id either, so generate one for the whole stream.
+        msg_id = self._generate_message_id()
+        chunk_id = 0
+        for event in response.get("body"):
+            if DEBUG:
+                logger.info("Bedrock response chunk: " + str(event))
+            chunk = json.loads(event["chunk"]["bytes"])
+            chunk_id += 1
+            response = self._create_response_stream(
+                model=chat_request.model,
+                message_id=msg_id,
+                chunk_message=chunk["outputs"][0]["text"],
+                finish_reason=chunk["outputs"][0]["stop_reason"],
+            )
+            yield self._stream_response_to_bytes(response)
diff --git a/src/api/routers/__init__.py b/src/api/routers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
new file mode 100644
index 0000000..d9b2fa6
--- /dev/null
+++ b/src/api/routers/chat.py
@@ -0,0 +1,51 @@
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, Body, HTTPException
+from fastapi.responses import StreamingResponse
+
+from api.auth import api_key_auth
+from api.models import get_model, SUPPORTED_BEDROCK_MODELS
+from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
+from api.setting import DEFAULT_MODEL
+
+router = APIRouter(
+    prefix="/chat",
+    tags=["chat"],
+    dependencies=[Depends(api_key_auth)],
+    # responses={404: {"description": "Not found"}},
+)
+
+
+@router.post("/completions", response_model=ChatResponse | ChatStreamResponse)
+async def chat_completions(
+    chat_request: Annotated[
+        ChatRequest,
+        Body(
+            examples=[
+                {
+                    "model": "anthropic.claude-3-sonnet-20240229-v1:0",
+                    "messages": [
+                        {"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": "Hello!"},
+                    ],
+                }
+            ],
+        ),
+    ]
+):
+    # Map any OpenAI "gpt-*" model id to the configured default Bedrock model.
+    if chat_request.model.lower().startswith("gpt-"):
+        chat_request.model = DEFAULT_MODEL
+    if chat_request.model not in SUPPORTED_BEDROCK_MODELS:
+        raise HTTPException(status_code=400, detail="Unsupported Model Id " + chat_request.model)
+    try:
+        model = get_model(chat_request.model)
+
+        if chat_request.stream:
+            return StreamingResponse(
+                content=model.chat_stream(chat_request), media_type="text/event-stream"
+            )
+        return model.chat(chat_request)
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
diff --git a/src/api/routers/model.py b/src/api/routers/model.py
new file mode 100644
index 0000000..4d10f98
--- /dev/null
+++ b/src/api/routers/model.py
@@ -0,0 +1,41 @@
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, HTTPException, Path
+
+from api.auth import api_key_auth
+from api.models import SUPPORTED_BEDROCK_MODELS
+from api.schema import Models, Model
+
+router = APIRouter(
+    prefix="/models",
+    tags=["models"],
+    dependencies=[Depends(api_key_auth)],
+    # responses={404: {"description": "Not found"}},
+)
+
+
+async def validate_model_id(model_id: str):
+    if model_id not in SUPPORTED_BEDROCK_MODELS:
+        raise HTTPException(status_code=400, detail="Unsupported Model Id")
+
+
+@router.get("/", response_model=Models)
+async def list_models():
+    model_list = [Model(id=model_id) for model_id in SUPPORTED_BEDROCK_MODELS]
+    return Models(data=model_list)
+
+
+@router.get(
+    "/{model_id}",
+    response_model=Model,
+)
+async def get_model(
+    model_id: Annotated[
+        str,
+        Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"),
+    ]
+):
+    await validate_model_id(model_id)
+    return Model(id=model_id)
diff --git a/src/api/schema.py b/src/api/schema.py
new file mode 100644
index 0000000..53732c8
--- /dev/null
+++ b/src/api/schema.py
@@ -0,0 +1,80 @@
+import time
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class Model(BaseModel):
+    id: str
+    created: int = Field(default_factory=lambda: int(time.time()))
+    object: str | None = "model"
+    owned_by: str | None = "bedrock"
+
+
+class Models(BaseModel):
+    object: str | None = "list"
+    data: list[Model] = []
+
+
+class 
ChatRequestMessage(BaseModel): + name: str | None = None + role: Literal["user", "assistant", "system"] + content: str + + +class ChatRequest(BaseModel): + messages: list[ChatRequestMessage] + model: str + frequency_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used + presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used + stream: bool | None = False + temperature: float | None = Field(default=1.0, le=2.0, ge=0.0) + top_p: float | None = Field(default=1.0, le=1.0, ge=0.0) + user: str | None = None # Not used + max_tokens: int | None = 2048 + n: int | None = 1 # Not used + + +class Usage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatResponseMessage(BaseModel): + # tool_calls + role: Literal["assistant"] | None = None + content: str | None = None + + +class BaseChoice(BaseModel): + index: int + finish_reason: str | None + logprobs: dict | None = None + + +class Choice(BaseChoice): + message: ChatResponseMessage + + +class ChoiceDelta(BaseChoice): + delta: ChatResponseMessage + + +class BaseChatResponse(BaseModel): + # id: str = Field(default_factory=lambda: "chatcmpl-" + str(uuid.uuid4())[:8]) + id: str + created: int = Field(default_factory=lambda: int(time.time())) + model: str + system_fingerprint: str = "fp_e97c09dd4e26" + + +class ChatResponse(BaseChatResponse): + choices: list[Choice] + object: Literal["chat.completion"] = "chat.completion" + usage: Usage + + +class ChatStreamResponse(BaseChatResponse): + choices: list[ChoiceDelta] + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" diff --git a/src/api/setting.py b/src/api/setting.py new file mode 100644 index 0000000..f183aa2 --- /dev/null +++ b/src/api/setting.py @@ -0,0 +1,27 @@ +import os + +DEFAULT_API_KEYS = "bedrock" + +API_ROUTE_PREFIX = "/api/v1" + +TITLE = "Amazon Bedrock Proxy APIs" +SUMMARY = "OpenAI-Compatible RESTful APIs for Amazon Bedrock" +VERSION = "0.1.0" +DESCRIPTION = """ +Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models. + +List of Amazon Bedrock models currently supported: +- anthropic.claude-instant-v1 +- anthropic.claude-v2:1 +- anthropic.claude-v2 +- anthropic.claude-3-sonnet-20240229-v1:0 +- anthropic.claude-3-haiku-20240307-v1:0 +- meta.llama2-13b-chat-v1 +- meta.llama2-70b-chat-v1 +- mistral.mistral-7b-instruct-v0:2 +- mistral.mixtral-8x7b-instruct-v0:1 +""" + +DEBUG = os.environ.get("DEBUG", "false").lower() != "false" +AWS_REGION = os.environ.get("AWS_REGION", "us-west-2") +DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0") diff --git a/src/requirements.txt b/src/requirements.txt new file mode 100644 index 0000000..fec17d7 --- /dev/null +++ b/src/requirements.txt @@ -0,0 +1,4 @@ +fastapi==0.103.0 +pydantic==2.6.3 +uvicorn==0.27.0.post1 +mangum==0.17.0 \ No newline at end of file