🐛 Bug: Fixed the bug where the cooling model could not take effect in this request.
Browse files
main.py
CHANGED
|
@@ -866,48 +866,28 @@ def get_matching_providers(request_model, config, api_index):
|
|
| 866 |
# print("provider_list", provider_list)
|
| 867 |
return provider_list
|
| 868 |
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
def __init__(self):
|
| 872 |
-
self.last_provider_indices = defaultdict(lambda: -1)
|
| 873 |
-
self.locks = defaultdict(asyncio.Lock)
|
| 874 |
-
|
| 875 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
|
| 876 |
-
config = app.state.config
|
| 877 |
-
api_list = app.state.api_list
|
| 878 |
-
api_index = api_list.index(token)
|
| 879 |
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
request_model = request.model
|
| 884 |
-
matching_providers = get_matching_providers(request_model, config, api_index)
|
| 885 |
|
|
|
|
|
|
|
| 886 |
if not matching_providers:
|
| 887 |
-
raise HTTPException(status_code=
|
| 888 |
-
|
| 889 |
-
if app.state.channel_manager.cooldown_period > 0:
|
| 890 |
-
matching_providers = await app.state.channel_manager.get_available_providers(matching_providers)
|
| 891 |
-
if not matching_providers:
|
| 892 |
-
raise HTTPException(status_code=503, detail="No available providers at the moment")
|
| 893 |
|
|
|
|
|
|
|
| 894 |
num_matching_providers = len(matching_providers)
|
|
|
|
| 895 |
|
|
|
|
| 896 |
|
| 897 |
-
|
| 898 |
-
scheduling_algorithm = safe_get(config, 'api_keys', api_index, "preferences", "SCHEDULING_ALGORITHM", default="fixed_priority")
|
| 899 |
-
if scheduling_algorithm == "random":
|
| 900 |
-
matching_providers = random.sample(matching_providers, num_matching_providers)
|
| 901 |
-
|
| 902 |
-
weights = safe_get(config, 'api_keys', api_index, "weights")
|
| 903 |
-
|
| 904 |
-
# 步骤 1: 提取 matching_providers 中的所有 provider 值
|
| 905 |
-
# print("matching_providers", matching_providers)
|
| 906 |
-
# print(type(matching_providers[0]['model'][0].keys()), list(matching_providers[0]['model'][0].keys())[0], matching_providers[0]['model'][0].keys())
|
| 907 |
-
all_providers = set(provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in matching_providers)
|
| 908 |
-
|
| 909 |
intersection = None
|
| 910 |
-
|
|
|
|
| 911 |
weight_keys = set(weights.keys())
|
| 912 |
provider_rules = []
|
| 913 |
for model_rule in weight_keys:
|
|
@@ -922,7 +902,7 @@ class ModelRequestHandler:
|
|
| 922 |
intersection = all_providers.intersection(weight_keys)
|
| 923 |
# print("intersection", intersection)
|
| 924 |
|
| 925 |
-
if
|
| 926 |
filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
|
| 927 |
# print("filtered_weights", filtered_weights)
|
| 928 |
|
|
@@ -941,9 +921,31 @@ class ModelRequestHandler:
|
|
| 941 |
new_matching_providers.append(provider)
|
| 942 |
matching_providers = new_matching_providers
|
| 943 |
|
| 944 |
-
|
| 945 |
-
|
| 946 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 947 |
|
| 948 |
status_code = 500
|
| 949 |
error_message = None
|
|
@@ -956,8 +958,12 @@ class ModelRequestHandler:
|
|
| 956 |
|
| 957 |
auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
|
| 958 |
|
| 959 |
-
|
| 960 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 961 |
provider = matching_providers[current_index]
|
| 962 |
try:
|
| 963 |
response = await process_request(request, provider, endpoint, token)
|
|
@@ -987,6 +993,10 @@ class ModelRequestHandler:
|
|
| 987 |
channel_id = f"{provider['provider']}"
|
| 988 |
if app.state.channel_manager.cooldown_period > 0:
|
| 989 |
await app.state.channel_manager.exclude_channel(channel_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 990 |
logger.error(f"Error {status_code} with provider {channel_id}: {error_message}")
|
| 991 |
if is_debug:
|
| 992 |
import traceback
|
|
|
|
| 866 |
# print("provider_list", provider_list)
|
| 867 |
return provider_list
|
| 868 |
|
| 869 |
+
async def get_right_order_providers(request_model, config, api_index, scheduling_algorithm):
|
| 870 |
+
matching_providers = get_matching_providers(request_model, config, api_index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 871 |
|
| 872 |
+
if not matching_providers:
|
| 873 |
+
raise HTTPException(status_code=404, detail="No matching model found")
|
|
|
|
|
|
|
|
|
|
| 874 |
|
| 875 |
+
if app.state.channel_manager.cooldown_period > 0:
|
| 876 |
+
matching_providers = await app.state.channel_manager.get_available_providers(matching_providers)
|
| 877 |
if not matching_providers:
|
| 878 |
+
raise HTTPException(status_code=503, detail="No available providers at the moment")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 879 |
|
| 880 |
+
# 检查是否启用轮询
|
| 881 |
+
if scheduling_algorithm == "random":
|
| 882 |
num_matching_providers = len(matching_providers)
|
| 883 |
+
matching_providers = random.sample(matching_providers, num_matching_providers)
|
| 884 |
|
| 885 |
+
weights = safe_get(config, 'api_keys', api_index, "weights")
|
| 886 |
|
| 887 |
+
if weights:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 888 |
intersection = None
|
| 889 |
+
all_providers = set(provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in matching_providers)
|
| 890 |
+
if all_providers:
|
| 891 |
weight_keys = set(weights.keys())
|
| 892 |
provider_rules = []
|
| 893 |
for model_rule in weight_keys:
|
|
|
|
| 902 |
intersection = all_providers.intersection(weight_keys)
|
| 903 |
# print("intersection", intersection)
|
| 904 |
|
| 905 |
+
if intersection:
|
| 906 |
filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
|
| 907 |
# print("filtered_weights", filtered_weights)
|
| 908 |
|
|
|
|
| 921 |
new_matching_providers.append(provider)
|
| 922 |
matching_providers = new_matching_providers
|
| 923 |
|
| 924 |
+
if is_debug:
|
| 925 |
+
for provider in matching_providers:
|
| 926 |
+
logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
| 927 |
+
|
| 928 |
+
return matching_providers
|
| 929 |
+
|
| 930 |
+
import asyncio
|
| 931 |
+
class ModelRequestHandler:
|
| 932 |
+
def __init__(self):
|
| 933 |
+
self.last_provider_indices = defaultdict(lambda: -1)
|
| 934 |
+
self.locks = defaultdict(asyncio.Lock)
|
| 935 |
+
|
| 936 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
|
| 937 |
+
config = app.state.config
|
| 938 |
+
api_list = app.state.api_list
|
| 939 |
+
api_index = api_list.index(token)
|
| 940 |
+
|
| 941 |
+
if not safe_get(config, 'api_keys', api_index, 'model'):
|
| 942 |
+
raise HTTPException(status_code=404, detail="No matching model found")
|
| 943 |
+
|
| 944 |
+
request_model = request.model
|
| 945 |
+
scheduling_algorithm = safe_get(config, 'api_keys', api_index, "preferences", "SCHEDULING_ALGORITHM", default="fixed_priority")
|
| 946 |
+
|
| 947 |
+
matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
|
| 948 |
+
num_matching_providers = len(matching_providers)
|
| 949 |
|
| 950 |
status_code = 500
|
| 951 |
error_message = None
|
|
|
|
| 958 |
|
| 959 |
auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
|
| 960 |
|
| 961 |
+
index = 0
|
| 962 |
+
while True:
|
| 963 |
+
if index >= num_matching_providers:
|
| 964 |
+
break
|
| 965 |
+
current_index = (start_index + index) % num_matching_providers
|
| 966 |
+
index += 1
|
| 967 |
provider = matching_providers[current_index]
|
| 968 |
try:
|
| 969 |
response = await process_request(request, provider, endpoint, token)
|
|
|
|
| 993 |
channel_id = f"{provider['provider']}"
|
| 994 |
if app.state.channel_manager.cooldown_period > 0:
|
| 995 |
await app.state.channel_manager.exclude_channel(channel_id)
|
| 996 |
+
matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
|
| 997 |
+
num_matching_providers = len(matching_providers)
|
| 998 |
+
index = 0
|
| 999 |
+
|
| 1000 |
logger.error(f"Error {status_code} with provider {channel_id}: {error_message}")
|
| 1001 |
if is_debug:
|
| 1002 |
import traceback
|