🐛 Bug: 1. Fix the bug where the API key is not found when rate limiting.
Browse files2. Fix the bug where the characters before the slash in the model name with a slash are parsed as the channel name.
main.py
CHANGED
|
@@ -275,20 +275,26 @@ class ModelRequestHandler:
|
|
| 275 |
for model in config['api_keys'][api_index]['model']:
|
| 276 |
if "/" in model:
|
| 277 |
provider_name = model.split("/")[0]
|
| 278 |
-
|
| 279 |
models_list = []
|
| 280 |
for provider in config['providers']:
|
| 281 |
if provider['provider'] == provider_name:
|
| 282 |
models_list.extend(list(provider['model'].keys()))
|
| 283 |
# print("models_list", models_list)
|
| 284 |
# print("model_name", model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
# print("model", model)
|
| 286 |
-
if (
|
| 287 |
provider_rules.append(provider_name)
|
| 288 |
else:
|
| 289 |
for provider in config['providers']:
|
| 290 |
if model in provider['model'].keys():
|
| 291 |
-
provider_rules.append(provider['provider'] + "/" +
|
| 292 |
|
| 293 |
provider_list = []
|
| 294 |
# print("provider_rules", provider_rules)
|
|
@@ -297,7 +303,7 @@ class ModelRequestHandler:
|
|
| 297 |
# print("provider", provider, provider['provider'] == item, item)
|
| 298 |
if "/" in item:
|
| 299 |
if provider['provider'] == item.split("/")[0]:
|
| 300 |
-
if model_name in provider['model'].keys() and item.split("/")[1] == model_name:
|
| 301 |
provider_list.append(provider)
|
| 302 |
elif provider['provider'] == item:
|
| 303 |
if model_name in provider['model'].keys():
|
|
@@ -422,15 +428,13 @@ class InMemoryRateLimiter:
|
|
| 422 |
|
| 423 |
rate_limiter = InMemoryRateLimiter()
|
| 424 |
|
| 425 |
-
async def get_user_rate_limit(
|
| 426 |
# 这里应该实现根据 token 获取用户速率限制的逻辑
|
| 427 |
# 示例: 返回 (次数, 秒数)
|
| 428 |
config = app.state.config
|
| 429 |
-
api_list = app.state.api_list
|
| 430 |
-
api_index = api_list.index(token)
|
| 431 |
raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
|
| 432 |
|
| 433 |
-
if not
|
| 434 |
return (60, 60)
|
| 435 |
|
| 436 |
rate_limit = parse_rate_limit(raw_rate_limit)
|
|
@@ -439,8 +443,14 @@ async def get_user_rate_limit(token: str = None):
|
|
| 439 |
security = HTTPBearer()
|
| 440 |
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
| 441 |
token = credentials.credentials if credentials else None
|
| 442 |
-
|
| 443 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
|
| 445 |
# 使用 IP 地址和 token(如果有)作为限制键
|
| 446 |
client_ip = request.client.host
|
|
|
|
| 275 |
for model in config['api_keys'][api_index]['model']:
|
| 276 |
if "/" in model:
|
| 277 |
provider_name = model.split("/")[0]
|
| 278 |
+
model_name_split = "/".join(model.split("/")[1:])
|
| 279 |
models_list = []
|
| 280 |
for provider in config['providers']:
|
| 281 |
if provider['provider'] == provider_name:
|
| 282 |
models_list.extend(list(provider['model'].keys()))
|
| 283 |
# print("models_list", models_list)
|
| 284 |
# print("model_name", model_name)
|
| 285 |
+
|
| 286 |
+
# 处理带斜杠的模型名
|
| 287 |
+
for provider in config['providers']:
|
| 288 |
+
if model in provider['model'].keys():
|
| 289 |
+
provider_rules.append(provider['provider'] + "/" + model)
|
| 290 |
+
|
| 291 |
# print("model", model)
|
| 292 |
+
if (model_name_split and model_name in models_list) or (model_name_split == "*" and model_name in models_list):
|
| 293 |
provider_rules.append(provider_name)
|
| 294 |
else:
|
| 295 |
for provider in config['providers']:
|
| 296 |
if model in provider['model'].keys():
|
| 297 |
+
provider_rules.append(provider['provider'] + "/" + model_name_split)
|
| 298 |
|
| 299 |
provider_list = []
|
| 300 |
# print("provider_rules", provider_rules)
|
|
|
|
| 303 |
# print("provider", provider, provider['provider'] == item, item)
|
| 304 |
if "/" in item:
|
| 305 |
if provider['provider'] == item.split("/")[0]:
|
| 306 |
+
if model_name in provider['model'].keys() and "/".join(item.split("/")[1:]) == model_name:
|
| 307 |
provider_list.append(provider)
|
| 308 |
elif provider['provider'] == item:
|
| 309 |
if model_name in provider['model'].keys():
|
|
|
|
| 428 |
|
| 429 |
rate_limiter = InMemoryRateLimiter()
|
| 430 |
|
| 431 |
+
async def get_user_rate_limit(api_index: str = None):
|
| 432 |
# 这里应该实现根据 token 获取用户速率限制的逻辑
|
| 433 |
# 示例: 返回 (次数, 秒数)
|
| 434 |
config = app.state.config
|
|
|
|
|
|
|
| 435 |
raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
|
| 436 |
|
| 437 |
+
if not api_index or not raw_rate_limit:
|
| 438 |
return (60, 60)
|
| 439 |
|
| 440 |
rate_limit = parse_rate_limit(raw_rate_limit)
|
|
|
|
| 443 |
security = HTTPBearer()
|
| 444 |
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
| 445 |
token = credentials.credentials if credentials else None
|
| 446 |
+
api_list = app.state.api_list
|
| 447 |
+
try:
|
| 448 |
+
api_index = api_list.index(token)
|
| 449 |
+
except ValueError:
|
| 450 |
+
print("error: Invalid or missing API Key:", token)
|
| 451 |
+
api_index = None
|
| 452 |
+
token = None
|
| 453 |
+
limit, period = await get_user_rate_limit(api_index)
|
| 454 |
|
| 455 |
# 使用 IP 地址和 token(如果有)作为限制键
|
| 456 |
client_ip = request.client.host
|