Add polling
- .env.example  +1 -0
- .gitignore    +2 -1
- main.py       +85 -46
.env.example (ADDED)

@@ -0,0 +1 @@
+USE_ROUND_ROBIN=true
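The new flag is read in main.py with os.environ.get (see the second hunk below), but nothing in this diff loads the .env file into the environment. A minimal sketch of that missing step, assuming python-dotenv is available (the project may instead rely on the shell or a process manager to export the variable):

# Assumed setup, not part of this commit: load .env before the flag is read.
import os
from dotenv import load_dotenv

load_dotenv()  # copies entries from ./.env into os.environ when the file exists

# The same check the new ModelRequestHandler.request_model() performs:
use_round_robin = os.environ.get('USE_ROUND_ROBIN', 'false').lower() == 'true'
print(use_round_robin)  # True only when the value is the string "true" (case-insensitive)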
.gitignore (CHANGED)

@@ -1,2 +1,3 @@
 api.json
-api.yaml
+api.yaml
+.env
main.py (CHANGED)

@@ -1,3 +1,4 @@
+import os
 import httpx
 import yaml
 from contextlib import asynccontextmanager

@@ -82,56 +83,94 @@ async def fetch_response(client, url, headers, payload):
     # print(response.text)
     return response.json()

-[old lines 85-125 removed; their content is not legible in this capture]
+async def process_request(request: RequestModel, provider: Dict):
+    print("provider: ", provider['provider'])
+    url = provider['base_url']
+    headers = {
+        'Authorization': f"Bearer {provider['api']}",
+        'Content-Type': 'application/json'
+    }
+
+    # Convert the message format
+    messages = []
+    for msg in request.messages:
+        if isinstance(msg.content, list):
+            content = " ".join([item.text for item in msg.content if item.type == "text"])
+        else:
+            content = msg.content
+        messages.append({"role": msg.role, "content": content})
+
+    payload = {
+        "model": request.model,
+        "messages": messages
+    }
+
+    # Only add these parameters to the payload when they are present and not None
+    if request.stream is not None:
+        payload["stream"] = request.stream
+    if request.include_usage is not None:
+        payload["include_usage"] = request.include_usage
+
+    if provider['provider'] == 'anthropic':
+        payload["max_tokens"] = 1000  # you may want to make this configurable
+    else:
+        if request.logprobs is not None:
+            payload["logprobs"] = request.logprobs
+        if request.top_logprobs is not None:
+            payload["top_logprobs"] = request.top_logprobs
+
+    if request.stream:
+        return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload), media_type="text/event-stream")
+    else:
+        return await fetch_response(app.state.client, url, headers, payload)
+
+class ModelRequestHandler:
+    def __init__(self):
+        self.last_provider_index = -1
+
+    def get_matching_providers(self, model_name):
+        return [provider for provider in config if model_name in provider['model']]
+
+    async def request_model(self, request: RequestModel, token: str):
+        model_name = request.model
+        matching_providers = self.get_matching_providers(model_name)
+        print("matching_providers", matching_providers)
+
+        if not matching_providers:
+            raise HTTPException(status_code=404, detail="No matching model found")
+
+        # Check whether round-robin polling is enabled
+        use_round_robin = os.environ.get('USE_ROUND_ROBIN', 'false').lower() == 'true'
+
+        if use_round_robin:
+            return await self.round_robin_request(request, matching_providers)
+        else:
+            # Use the first matching provider
+            provider = matching_providers[0]
             try:
-                [old line 127 removed; not legible in this capture]
-                    return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload), media_type="text/event-stream")
-                else:
-                    return await fetch_response(app.state.client, url, headers, payload)
+                return await process_request(request, provider)
             except Exception as e:
                 raise HTTPException(status_code=500, detail=f"Error calling API: {str(e)}")

+    async def round_robin_request(self, request: RequestModel, providers: List[Dict]):
+        num_providers = len(providers)
+        for i in range(num_providers):
+            self.last_provider_index = (self.last_provider_index + 1) % num_providers
+            # print(f"Trying provider {self.last_provider_index}")
+            provider = providers[self.last_provider_index]
+            try:
+                response = await process_request(request, provider)
+                return response
+            except Exception as e:
+                print(f"Error with provider {provider['provider']}: {str(e)}")
+                continue
+        raise HTTPException(status_code=500, detail="All providers failed")
+
+model_handler = ModelRequestHandler()
+
+@app.post("/v1/chat/completions")
+async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
+    return await model_handler.request_model(request, token)

 if __name__ == '__main__':
     import uvicorn
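For reference, the new code only assumes a few keys on each entry of config: get_matching_providers() checks model_name in provider['model'], and process_request() reads provider, base_url and api. The real config file (api.yaml / api.json) is git-ignored and not part of this commit, so the structure below is only an assumed shape with placeholder values, written as the Python object yaml.safe_load would produce:

# Hypothetical provider list; key names come from the diff, all values are placeholders.
config = [
    {
        "provider": "openai",
        "base_url": "https://api.openai.com/v1/chat/completions",  # used verbatim as the request URL
        "api": "sk-...",                                            # forwarded as the Bearer token
        "model": ["gpt-4o", "gpt-4o-mini"],                         # request.model must appear here
    },
    {
        "provider": "anthropic",                                    # triggers the max_tokens branch
        "base_url": "https://api.anthropic.com/v1/messages",
        "api": "sk-ant-...",
        "model": ["claude-3-5-sonnet"],
    },
]

Note that model_name in provider['model'] is an exact membership test when 'model' is a list, but a substring match when it is a plain string.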
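A quick way to exercise the new endpoint and the round-robin path (with USE_ROUND_ROBIN=true) is a direct client call. Host, port, token and model name below are placeholders, and the exact header verify_api_key expects is not shown in this diff; a Bearer token is assumed:

# Hypothetical smoke test against the running proxy.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Authorization": "Bearer YOUR_PROXY_KEY"},
    json={
        "model": "gpt-4o",  # must match a 'model' entry of at least one provider
        "messages": [{"role": "user", "content": "ping"}],
        "stream": False,    # True would return a text/event-stream response instead
    },
    timeout=60,
)
print(resp.status_code, resp.json())

If the first matching provider fails, round_robin_request() simply moves on to the next one, so a single bad upstream key surfaces as a printed error on the server rather than a 500, as long as another provider can serve the requested model.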