✨ Feature: Support model-level cooldown. When a model under a channel reports an error, only that model is cooled down; other models under the same channel are not affected.
Browse files
main.py
CHANGED
|
@@ -161,35 +161,43 @@ async def parse_request_body(request: Request):
|
|
| 161 |
|
| 162 |
class ChannelManager:
|
| 163 |
def __init__(self, cooldown_period: int = 300): # 默认冷却时间5分钟
|
| 164 |
-
self.
|
| 165 |
self._lock = asyncio.Lock()
|
| 166 |
self.cooldown_period = cooldown_period
|
| 167 |
|
| 168 |
-
async def
|
| 169 |
-
"""
|
| 170 |
async with self._lock:
|
| 171 |
-
|
|
|
|
| 172 |
|
| 173 |
-
async def
|
| 174 |
-
"""
|
| 175 |
async with self._lock:
|
| 176 |
-
|
|
|
|
| 177 |
return False
|
| 178 |
|
| 179 |
-
excluded_time = self.
|
| 180 |
if datetime.now() - excluded_time > timedelta(seconds=self.cooldown_period):
|
| 181 |
# 已超过冷却时间,移除限制
|
| 182 |
-
del self.
|
| 183 |
return False
|
| 184 |
return True
|
| 185 |
|
| 186 |
async def get_available_providers(self, providers: list) -> list:
|
| 187 |
-
"""过滤出可用的providers"""
|
| 188 |
available_providers = []
|
| 189 |
for provider in providers:
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
available_providers.append(provider)
|
|
|
|
| 193 |
return available_providers
|
| 194 |
|
| 195 |
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|
@@ -992,7 +1000,9 @@ class ModelRequestHandler:
|
|
| 992 |
|
| 993 |
channel_id = f"{provider['provider']}"
|
| 994 |
if app.state.channel_manager.cooldown_period > 0:
|
| 995 |
-
|
|
|
|
|
|
|
| 996 |
matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
|
| 997 |
num_matching_providers = len(matching_providers)
|
| 998 |
index = 0
|
|
|
|
| 161 |
|
| 162 |
class ChannelManager:
    """Track per-model cooldowns so a failing model does not sideline its channel.

    Each excluded model is stored under the key ``"<provider>/<model>"``
    together with the time at which it was excluded. A model counts as
    excluded until more than ``cooldown_period`` seconds have passed.
    """

    def __init__(self, cooldown_period: int = 300):  # default cooldown: 5 minutes
        self.cooldown_period = cooldown_period
        self._lock = asyncio.Lock()
        # Maps "<provider>/<model>" to the moment the model was excluded.
        self._excluded_models: Dict[str, datetime] = {}

    @staticmethod
    def _model_key(provider: str, model: str) -> str:
        """Build the dictionary key identifying one model of one provider."""
        return f"{provider}/{model}"

    async def exclude_model(self, provider: str, model: str) -> None:
        """Start (or restart) the cooldown of ``model`` under ``provider``."""
        key = self._model_key(provider, model)
        async with self._lock:
            # NOTE(review): datetime.now() is naive local wall-clock time, so
            # clock adjustments can shorten or lengthen a cooldown; a
            # monotonic clock would avoid that.
            self._excluded_models[key] = datetime.now()

    async def is_model_excluded(self, provider: str, model: str) -> bool:
        """Return True while ``model`` under ``provider`` is still cooling down.

        An entry whose cooldown has elapsed is removed as a side effect, and
        the model is then reported as not excluded.
        """
        key = self._model_key(provider, model)
        async with self._lock:
            try:
                excluded_at = self._excluded_models[key]
            except KeyError:
                # Never excluded, or already released.
                return False

            cooldown = timedelta(seconds=self.cooldown_period)
            still_cooling = datetime.now() - excluded_at <= cooldown
            if not still_cooling:
                # Cooldown elapsed: lift the restriction.
                del self._excluded_models[key]
            return still_cooling

    async def get_available_providers(self, providers: list) -> list:
        """Return the providers whose model is not currently excluded.

        Each provider is a mapping with a ``'provider'`` name and a
        ``'model'`` list. The first key of the first entry in that list is
        used as the model name for the exclusion lookup. Input order is kept.
        """
        return [
            candidate
            for candidate in providers
            if not await self.is_model_excluded(
                candidate['provider'],
                list(candidate['model'][0])[0],
            )
        ]
|
| 202 |
|
| 203 |
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|
|
|
| 1000 |
|
| 1001 |
channel_id = f"{provider['provider']}"
|
| 1002 |
if app.state.channel_manager.cooldown_period > 0:
|
| 1003 |
+
# 获取源模型名称(实际配置的模型名)
|
| 1004 |
+
source_model = list(provider['model'][0].keys())[0]
|
| 1005 |
+
await app.state.channel_manager.exclude_model(channel_id, source_model)
|
| 1006 |
matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
|
| 1007 |
num_matching_providers = len(matching_providers)
|
| 1008 |
index = 0
|