✨ Feature: Add API channel success-rate statistics and channel status records.
Files changed:
- main.py (+63 −22)
- request.py (+6 −6)
- response.py (+2 −2)
- test/test_nostream.py (+1 −1)
main.py

@@ -58,6 +58,8 @@ class StatsMiddleware(BaseHTTPMiddleware):
         self.request_times = defaultdict(float)
         self.ip_counts = defaultdict(lambda: defaultdict(int))
         self.request_arrivals = defaultdict(list)
+        self.channel_success_counts = defaultdict(int)
+        self.channel_failure_counts = defaultdict(int)
         self.lock = asyncio.Lock()
         self.exclude_paths = set(exclude_paths or [])
         self.save_interval = save_interval
@@ -101,7 +103,11 @@ class StatsMiddleware(BaseHTTPMiddleware):
             "request_counts": dict(self.request_counts),
             "request_times": dict(self.request_times),
             "ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
-            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()}
+            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
+            "channel_success_counts": dict(self.channel_success_counts),
+            "channel_failure_counts": dict(self.channel_failure_counts),
+            "channel_success_percentages": self.calculate_success_percentages(),
+            "channel_failure_percentages": self.calculate_failure_percentages()
         }

         filename = self.filename
@@ -109,10 +115,28 @@ class StatsMiddleware(BaseHTTPMiddleware):
             await f.write(json.dumps(stats, indent=2))

         self.last_save_time = current_time
-
+
+    def calculate_success_percentages(self):
+        percentages = {}
+        for channel, success_count in self.channel_success_counts.items():
+            total_count = success_count + self.channel_failure_counts[channel]
+            if total_count > 0:
+                percentages[channel] = success_count / total_count * 100
+            else:
+                percentages[channel] = 0
+        return percentages
+
+    def calculate_failure_percentages(self):
+        percentages = {}
+        for channel, failure_count in self.channel_failure_counts.items():
+            total_count = failure_count + self.channel_success_counts[channel]
+            if total_count > 0:
+                percentages[channel] = failure_count / total_count * 100
+            else:
+                percentages[channel] = 0
+        return percentages

     async def cleanup_old_data(self):
-        # cutoff_time = datetime.now() - timedelta(seconds=30)
         cutoff_time = datetime.now() - timedelta(hours=24)
         async with self.lock:
             for endpoint in list(self.request_arrivals.keys()):
@@ -139,10 +163,10 @@ app.add_middleware(

 app.add_middleware(StatsMiddleware, exclude_paths=["/stats", "/generate-api-key"])

+# Update the success and failure counts in the process_request function
 async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
     url = provider['base_url']
     parsed_url = urlparse(url)
-    # print(parsed_url)
     engine = None
     if parsed_url.netloc == 'generativelanguage.googleapis.com':
         engine = "gemini"
@@ -160,6 +184,12 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
         and "gemini" not in provider['model'][request.model]:
         engine = "openrouter"

+    if "claude" in provider['model'][request.model] and engine == "vertex":
+        engine = "vertex-claude"
+
+    if "gemini" in provider['model'][request.model] and engine == "vertex":
+        engine = "vertex-gemini"
+
     if endpoint == "/v1/images/generations":
         engine = "dalle"
         request.stream = False
@@ -171,21 +201,28 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],

     url, headers, payload = await get_payload(request, engine, provider)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        if request.stream:
+            model = provider['model'][request.model]
+            generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
+            wrapped_generator = await error_handling_wrapper(generator, status_code=500)
+            response = StreamingResponse(wrapped_generator, media_type="text/event-stream")
+        else:
+            response = await anext(fetch_response(app.state.client, url, headers, payload))
+
+        # Update the success count
+        async with app.middleware_stack.app.lock:
+            app.middleware_stack.app.channel_success_counts[provider['provider']] += 1
+
+        return response
+    except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
+        logger.error(f"Error with provider {provider['provider']}: {str(e)}")
+
+        # Update the failure count
+        async with app.middleware_stack.app.lock:
+            app.middleware_stack.app.channel_failure_counts[provider['provider']] += 1
+
+        raise e

 import asyncio
 class ModelRequestHandler:
@@ -270,10 +307,10 @@ class ModelRequestHandler:

         return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)

+    # Handle the failure path in the try_all_providers function
     async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
         num_providers = len(providers)
         start_index = self.last_provider_index + 1 if use_round_robin else 0
-
         for i in range(num_providers + 1):
             self.last_provider_index = (start_index + i) % num_providers
             provider = providers[self.last_provider_index]
@@ -287,7 +324,6 @@ class ModelRequestHandler:
             else:
                 raise HTTPException(status_code=500, detail="Error: Current provider response failed!")

-
         raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}")

 model_handler = ModelRequestHandler()
@@ -341,6 +377,7 @@ def generate_api_key():
     api_key = "sk-" + secrets.token_urlsafe(36)
     return JSONResponse(content={"api_key": api_key})

+# Return the success and failure percentages from the /stats route
 @app.get("/stats")
 async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
     middleware = app.middleware_stack.app
@@ -350,7 +387,11 @@ async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)
             "request_counts": dict(middleware.request_counts),
             "request_times": dict(middleware.request_times),
             "ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
-            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()}
+            "request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()},
+            "channel_success_counts": dict(middleware.channel_success_counts),
+            "channel_failure_counts": dict(middleware.channel_failure_counts),
+            "channel_success_percentages": middleware.calculate_success_percentages(),
+            "channel_failure_percentages": middleware.calculate_failure_percentages()
         }
         return JSONResponse(content=stats)
     return {"error": "StatsMiddleware not found"}
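The new success-rate math is easy to verify in isolation. Below is a minimal, self-contained sketch of the bookkeeping; the counter and method names mirror the diff, while the provider names and counts are made up for illustration:

    from collections import defaultdict

    channel_success_counts = defaultdict(int)
    channel_failure_counts = defaultdict(int)

    # Hypothetical outcomes, as process_request would record them
    channel_success_counts["provider-a"] += 8
    channel_failure_counts["provider-a"] += 2
    channel_failure_counts["provider-b"] += 1

    def calculate_success_percentages():
        percentages = {}
        # Iterating over the success counts means a channel with only
        # failures ("provider-b") never shows up here; the failure-side
        # method has the mirror-image property.
        for channel, success_count in channel_success_counts.items():
            total_count = success_count + channel_failure_counts[channel]
            percentages[channel] = success_count / total_count * 100 if total_count > 0 else 0
        return percentages

    print(calculate_success_percentages())  # {'provider-a': 80.0}

Exporting the raw counts next to the derived percentages lets a consumer of /stats weight a 100% success rate over one request differently from 99% over ten thousand.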
request.py

@@ -10,7 +10,7 @@ async def get_image_message(base64_image, engine = None):
                 "url": base64_image,
             }
         }
-    if "claude" == engine:
+    if "claude" == engine or "vertex-claude" == engine:
         return {
             "type": "image",
             "source": {
@@ -19,7 +19,7 @@ async def get_image_message(base64_image, engine = None):
                 "data": base64_image.split(",")[1],
             }
         }
-    if "gemini" == engine:
+    if "gemini" == engine or "vertex-gemini" == engine:
         return {
             "inlineData": {
                 "mimeType": "image/jpeg",
@@ -29,9 +29,9 @@ async def get_image_message(base64_image, engine = None):
     raise ValueError("Unknown engine")

 async def get_text_message(role, message, engine = None):
-    if "gpt" == engine or "claude" == engine or "openrouter" == engine:
+    if "gpt" == engine or "claude" == engine or "openrouter" == engine or "vertex-claude" == engine:
         return {"type": "text", "text": message}
-    if "gemini" == engine:
+    if "gemini" == engine or "vertex-gemini" == engine:
         return {"text": message}
     raise ValueError("Unknown engine")

@@ -794,9 +794,9 @@ async def get_dalle_payload(request, engine, provider):
 async def get_payload(request: RequestModel, engine, provider):
     if engine == "gemini":
         return await get_gemini_payload(request, engine, provider)
-    elif engine == "vertex":
+    elif engine == "vertex-gemini":
         return await get_vertex_gemini_payload(request, engine, provider)
-    elif engine == "vertex":
+    elif engine == "vertex-claude":
         return await get_vertex_claude_payload(request, engine, provider)
     elif engine == "claude":
         return await get_claude_payload(request, engine, provider)
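As a quick sanity check of the widened dispatch, here is a hedged, standalone sketch of get_text_message; the message shapes come from the diff, the call values are illustrative, and the membership tests are just a compact restatement of the == chains:

    import asyncio

    async def get_text_message(role, message, engine=None):
        # role is accepted but unused, matching the signature in the diff.
        # OpenAI/Anthropic-style engines take typed content parts;
        # Gemini-style engines take Gemini's part format.
        if engine in ("gpt", "claude", "openrouter", "vertex-claude"):
            return {"type": "text", "text": message}
        if engine in ("gemini", "vertex-gemini"):
            return {"text": message}
        raise ValueError("Unknown engine")

    print(asyncio.run(get_text_message("user", "hi", engine="vertex-claude")))  # {'type': 'text', 'text': 'hi'}
    print(asyncio.run(get_text_message("user", "hi", engine="vertex-gemini")))  # {'text': 'hi'}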
response.py

@@ -248,10 +248,10 @@ async def fetch_response(client, url, headers, payload):

 async def fetch_response_stream(client, url, headers, payload, engine, model):
     try:
-        if engine == "gemini" or engine == "vertex":
+        if engine == "gemini" or engine == "vertex-gemini":
             async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
                 yield chunk
-        elif engine == "claude" or engine == "vertex":
+        elif engine == "claude" or engine == "vertex-claude":
             async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
                 yield chunk
         elif engine == "gpt":
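The stream dispatch treats each vertex variant as an alias for its base family. The same routing can be read as a lookup table; this is a restatement for clarity, not the structure the diff uses, and the parser bodies below are stubs standing in for the real ones in response.py:

    import asyncio

    async def fetch_gemini_response_stream(client, url, headers, payload, model):
        yield "gemini-style chunk"  # stub

    async def fetch_claude_response_stream(client, url, headers, payload, model):
        yield "claude-style chunk"  # stub

    STREAM_PARSERS = {
        "gemini": fetch_gemini_response_stream,
        "vertex-gemini": fetch_gemini_response_stream,
        "claude": fetch_claude_response_stream,
        "vertex-claude": fetch_claude_response_stream,
    }

    async def demo(engine):
        async for chunk in STREAM_PARSERS[engine](None, None, None, None, None):
            print(engine, "->", chunk)

    asyncio.run(demo("vertex-claude"))  # vertex-claude -> claude-style chunk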
test/test_nostream.py

@@ -66,7 +66,7 @@ def get_model_response(image_base64):
         # "stream": True,
         "tools": tools,
         "tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
-        "max_tokens":
+        "max_tokens": 1000
     }

     try: