🐛 Bug: Fix the bug where weight polling did not check if the weight channel conforms to the request model.
Browse files
main.py
CHANGED
|
@@ -647,98 +647,105 @@ def lottery_scheduling(weights):
|
|
| 647 |
break
|
| 648 |
return selections
|
| 649 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
import asyncio
|
| 651 |
class ModelRequestHandler:
|
| 652 |
def __init__(self):
|
| 653 |
self.last_provider_indices = defaultdict(lambda: -1)
|
| 654 |
self.locks = defaultdict(asyncio.Lock)
|
| 655 |
|
| 656 |
-
def
|
| 657 |
config = app.state.config
|
| 658 |
-
# api_keys_db = app.state.api_keys_db
|
| 659 |
api_list = app.state.api_list
|
| 660 |
api_index = api_list.index(token)
|
|
|
|
| 661 |
if not safe_get(config, 'api_keys', api_index, 'model'):
|
| 662 |
raise HTTPException(status_code=404, detail="No matching model found")
|
| 663 |
-
provider_rules = []
|
| 664 |
-
|
| 665 |
-
for model in config['api_keys'][api_index]['model']:
|
| 666 |
-
if model == "all":
|
| 667 |
-
# 如果模型名为 all,则返回所有模型
|
| 668 |
-
for provider in config["providers"]:
|
| 669 |
-
model_dict = get_model_dict(provider)
|
| 670 |
-
for model in model_dict.keys():
|
| 671 |
-
provider_rules.append(provider["provider"] + "/" + model)
|
| 672 |
-
break
|
| 673 |
-
if "/" in model:
|
| 674 |
-
if model.startswith("<") and model.endswith(">"):
|
| 675 |
-
model = model[1:-1]
|
| 676 |
-
# 处理带斜杠的模型名
|
| 677 |
-
for provider in config['providers']:
|
| 678 |
-
model_dict = get_model_dict(provider)
|
| 679 |
-
if model in model_dict.keys():
|
| 680 |
-
provider_rules.append(provider['provider'] + "/" + model)
|
| 681 |
-
else:
|
| 682 |
-
provider_name = model.split("/")[0]
|
| 683 |
-
model_name_split = "/".join(model.split("/")[1:])
|
| 684 |
-
models_list = []
|
| 685 |
-
for provider in config['providers']:
|
| 686 |
-
model_dict = get_model_dict(provider)
|
| 687 |
-
if provider['provider'] == provider_name:
|
| 688 |
-
models_list.extend(list(model_dict.keys()))
|
| 689 |
-
# print("models_list", models_list)
|
| 690 |
-
# print("model_name", model_name)
|
| 691 |
-
# print("model_name_split", model_name_split)
|
| 692 |
-
# print("model", model)
|
| 693 |
-
|
| 694 |
-
# api_keys 中 model 为 provider_name/* 时,表示所有模型都匹配
|
| 695 |
-
if model_name_split == "*":
|
| 696 |
-
if model_name in models_list:
|
| 697 |
-
provider_rules.append(provider_name + "/" + model_name)
|
| 698 |
-
|
| 699 |
-
# 如果请求模型名: gpt-4* ,则匹配所有以模型名开头且不以 * 结尾的模型
|
| 700 |
-
for models_list_model in models_list:
|
| 701 |
-
if model_name.endswith("*") and models_list_model.startswith(model_name.rstrip("*")):
|
| 702 |
-
provider_rules.append(provider_name + "/" + models_list_model)
|
| 703 |
-
|
| 704 |
-
# api_keys 中 model 为 provider_name/model_name 时,表示模型名完全匹配
|
| 705 |
-
elif model_name_split == model_name \
|
| 706 |
-
or (model_name.endswith("*") and model_name_split.startswith(model_name.rstrip("*"))): # api_keys 中 model 为 provider_name/model_name 时,请求模型名: model_name*
|
| 707 |
-
if model_name_split in models_list:
|
| 708 |
-
provider_rules.append(provider_name + "/" + model_name_split)
|
| 709 |
-
|
| 710 |
-
else:
|
| 711 |
-
for provider in config["providers"]:
|
| 712 |
-
model_dict = get_model_dict(provider)
|
| 713 |
-
if model in model_dict.keys():
|
| 714 |
-
provider_rules.append(provider["provider"] + "/" + model)
|
| 715 |
-
|
| 716 |
-
provider_list = []
|
| 717 |
-
# print("provider_rules", provider_rules)
|
| 718 |
-
for item in provider_rules:
|
| 719 |
-
for provider in config['providers']:
|
| 720 |
-
if "/" in item and provider['provider'] == item.split("/")[0]:
|
| 721 |
-
new_provider = copy.deepcopy(provider)
|
| 722 |
-
model_dict = get_model_dict(provider)
|
| 723 |
-
model_name_split = "/".join(item.split("/")[1:])
|
| 724 |
-
# old: new
|
| 725 |
-
new_provider["model"] = [{model_dict[model_name_split]: model_name}]
|
| 726 |
-
if model_name in model_dict.keys() and model_name_split == model_name:
|
| 727 |
-
provider_list.append(new_provider)
|
| 728 |
-
|
| 729 |
-
elif model_name.endswith("*") and model_name_split.startswith(model_name.rstrip("*")):
|
| 730 |
-
provider_list.append(new_provider)
|
| 731 |
-
|
| 732 |
-
# print("provider_list", provider_list)
|
| 733 |
-
return provider_list
|
| 734 |
-
|
| 735 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
|
| 736 |
-
config = app.state.config
|
| 737 |
-
api_list = app.state.api_list
|
| 738 |
-
api_index = api_list.index(token)
|
| 739 |
|
| 740 |
-
|
| 741 |
-
matching_providers =
|
| 742 |
num_matching_providers = len(matching_providers)
|
| 743 |
|
| 744 |
if not matching_providers:
|
|
@@ -757,6 +764,13 @@ class ModelRequestHandler:
|
|
| 757 |
intersection = None
|
| 758 |
if weights and all_providers:
|
| 759 |
weight_keys = set(weights.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 760 |
# 步骤 3: 计算交集
|
| 761 |
intersection = all_providers.intersection(weight_keys)
|
| 762 |
|
|
@@ -769,6 +783,7 @@ class ModelRequestHandler:
|
|
| 769 |
weighted_provider_name_list = lottery_scheduling(weights)
|
| 770 |
else:
|
| 771 |
weighted_provider_name_list = list(weights.keys())
|
|
|
|
| 772 |
|
| 773 |
new_matching_providers = []
|
| 774 |
for provider_name in weighted_provider_name_list:
|
|
@@ -786,9 +801,9 @@ class ModelRequestHandler:
|
|
| 786 |
|
| 787 |
start_index = 0
|
| 788 |
if scheduling_algorithm != "fixed_priority":
|
| 789 |
-
async with self.locks[
|
| 790 |
-
self.last_provider_indices[
|
| 791 |
-
start_index = self.last_provider_indices[
|
| 792 |
|
| 793 |
auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
|
| 794 |
|
|
|
|
| 647 |
break
|
| 648 |
return selections
|
| 649 |
|
| 650 |
+
def get_provider_rules(model_rule, config, request_model):
|
| 651 |
+
provider_rules = []
|
| 652 |
+
if model_rule == "all":
|
| 653 |
+
# 如果模型名为 all,则返回所有模型
|
| 654 |
+
for provider in config["providers"]:
|
| 655 |
+
model_dict = get_model_dict(provider)
|
| 656 |
+
for model in model_dict.keys():
|
| 657 |
+
provider_rules.append(provider["provider"] + "/" + model)
|
| 658 |
+
|
| 659 |
+
elif "/" in model_rule:
|
| 660 |
+
if model_rule.startswith("<") and model_rule.endswith(">"):
|
| 661 |
+
model_rule = model_rule[1:-1]
|
| 662 |
+
# 处理带斜杠的模型名
|
| 663 |
+
for provider in config['providers']:
|
| 664 |
+
model_dict = get_model_dict(provider)
|
| 665 |
+
if model_rule in model_dict.keys():
|
| 666 |
+
provider_rules.append(provider['provider'] + "/" + model_rule)
|
| 667 |
+
else:
|
| 668 |
+
provider_name = model_rule.split("/")[0]
|
| 669 |
+
model_name_split = "/".join(model_rule.split("/")[1:])
|
| 670 |
+
models_list = []
|
| 671 |
+
for provider in config['providers']:
|
| 672 |
+
model_dict = get_model_dict(provider)
|
| 673 |
+
if provider['provider'] == provider_name:
|
| 674 |
+
models_list.extend(list(model_dict.keys()))
|
| 675 |
+
# print("models_list", models_list)
|
| 676 |
+
# print("model_name", model_name)
|
| 677 |
+
# print("model_name_split", model_name_split)
|
| 678 |
+
# print("model", model)
|
| 679 |
+
|
| 680 |
+
# api_keys 中 model 为 provider_name/* 时,表示所有模型都匹配
|
| 681 |
+
if model_name_split == "*":
|
| 682 |
+
if request_model in models_list:
|
| 683 |
+
provider_rules.append(provider_name + "/" + request_model)
|
| 684 |
+
|
| 685 |
+
# 如果请求模型名: gpt-4* ,则匹配所有以模型名开头且不以 * 结尾的模型
|
| 686 |
+
for models_list_model in models_list:
|
| 687 |
+
if request_model.endswith("*") and models_list_model.startswith(request_model.rstrip("*")):
|
| 688 |
+
provider_rules.append(provider_name + "/" + models_list_model)
|
| 689 |
+
|
| 690 |
+
# api_keys 中 model 为 provider_name/model_name 时,表示模型名完全匹配
|
| 691 |
+
elif model_name_split == request_model \
|
| 692 |
+
or (request_model.endswith("*") and model_name_split.startswith(request_model.rstrip("*"))): # api_keys 中 model 为 provider_name/model_name 时,请求模型名: model_name*
|
| 693 |
+
if model_name_split in models_list:
|
| 694 |
+
provider_rules.append(provider_name + "/" + model_name_split)
|
| 695 |
+
|
| 696 |
+
else:
|
| 697 |
+
for provider in config["providers"]:
|
| 698 |
+
model_dict = get_model_dict(provider)
|
| 699 |
+
if model_rule in model_dict.keys():
|
| 700 |
+
provider_rules.append(provider["provider"] + "/" + model_rule)
|
| 701 |
+
|
| 702 |
+
return provider_rules
|
| 703 |
+
|
| 704 |
+
def get_provider_list(provider_rules, config, request_model):
|
| 705 |
+
provider_list = []
|
| 706 |
+
# print("provider_rules", provider_rules)
|
| 707 |
+
for item in provider_rules:
|
| 708 |
+
for provider in config['providers']:
|
| 709 |
+
if "/" in item and provider['provider'] == item.split("/")[0]:
|
| 710 |
+
new_provider = copy.deepcopy(provider)
|
| 711 |
+
model_dict = get_model_dict(provider)
|
| 712 |
+
model_name_split = "/".join(item.split("/")[1:])
|
| 713 |
+
# old: new
|
| 714 |
+
new_provider["model"] = [{model_dict[model_name_split]: request_model}]
|
| 715 |
+
if request_model in model_dict.keys() and model_name_split == request_model:
|
| 716 |
+
provider_list.append(new_provider)
|
| 717 |
+
|
| 718 |
+
elif request_model.endswith("*") and model_name_split.startswith(request_model.rstrip("*")):
|
| 719 |
+
provider_list.append(new_provider)
|
| 720 |
+
return provider_list
|
| 721 |
+
|
| 722 |
+
def get_matching_providers(request_model, config, api_index):
|
| 723 |
+
provider_rules = []
|
| 724 |
+
|
| 725 |
+
for model_rule in config['api_keys'][api_index]['model']:
|
| 726 |
+
provider_rules.extend(get_provider_rules(model_rule, config, request_model))
|
| 727 |
+
|
| 728 |
+
provider_list = get_provider_list(provider_rules, config, request_model)
|
| 729 |
+
|
| 730 |
+
# print("provider_list", provider_list)
|
| 731 |
+
return provider_list
|
| 732 |
+
|
| 733 |
import asyncio
|
| 734 |
class ModelRequestHandler:
|
| 735 |
def __init__(self):
|
| 736 |
self.last_provider_indices = defaultdict(lambda: -1)
|
| 737 |
self.locks = defaultdict(asyncio.Lock)
|
| 738 |
|
| 739 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
|
| 740 |
config = app.state.config
|
|
|
|
| 741 |
api_list = app.state.api_list
|
| 742 |
api_index = api_list.index(token)
|
| 743 |
+
|
| 744 |
if not safe_get(config, 'api_keys', api_index, 'model'):
|
| 745 |
raise HTTPException(status_code=404, detail="No matching model found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 746 |
|
| 747 |
+
request_model = request.model
|
| 748 |
+
matching_providers = get_matching_providers(request_model, config, api_index)
|
| 749 |
num_matching_providers = len(matching_providers)
|
| 750 |
|
| 751 |
if not matching_providers:
|
|
|
|
| 764 |
intersection = None
|
| 765 |
if weights and all_providers:
|
| 766 |
weight_keys = set(weights.keys())
|
| 767 |
+
provider_rules = []
|
| 768 |
+
for model_rule in weight_keys:
|
| 769 |
+
provider_rules.extend(get_provider_rules(model_rule, config, request_model))
|
| 770 |
+
provider_list = get_provider_list(provider_rules, config, request_model)
|
| 771 |
+
weight_keys = set([provider['provider'] for provider in provider_list])
|
| 772 |
+
# print("all_providers", all_providers)
|
| 773 |
+
# print("weights", weight_keys)
|
| 774 |
# 步骤 3: 计算交集
|
| 775 |
intersection = all_providers.intersection(weight_keys)
|
| 776 |
|
|
|
|
| 783 |
weighted_provider_name_list = lottery_scheduling(weights)
|
| 784 |
else:
|
| 785 |
weighted_provider_name_list = list(weights.keys())
|
| 786 |
+
# print("weighted_provider_name_list", weighted_provider_name_list)
|
| 787 |
|
| 788 |
new_matching_providers = []
|
| 789 |
for provider_name in weighted_provider_name_list:
|
|
|
|
| 801 |
|
| 802 |
start_index = 0
|
| 803 |
if scheduling_algorithm != "fixed_priority":
|
| 804 |
+
async with self.locks[request_model]:
|
| 805 |
+
self.last_provider_indices[request_model] = (self.last_provider_indices[request_model] + 1) % num_matching_providers
|
| 806 |
+
start_index = self.last_provider_indices[request_model]
|
| 807 |
|
| 808 |
auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
|
| 809 |
|
utils.py
CHANGED
|
@@ -109,9 +109,9 @@ def update_config(config_data, use_config_url=False):
|
|
| 109 |
for model in api_key.get('model'):
|
| 110 |
if isinstance(model, dict):
|
| 111 |
key, value = list(model.items())[0]
|
| 112 |
-
provider_name = key.split("/")[0]
|
| 113 |
if "/" in key:
|
| 114 |
-
weights_dict.update({
|
| 115 |
models.append(key)
|
| 116 |
if isinstance(model, str):
|
| 117 |
models.append(model)
|
|
|
|
| 109 |
for model in api_key.get('model'):
|
| 110 |
if isinstance(model, dict):
|
| 111 |
key, value = list(model.items())[0]
|
| 112 |
+
# provider_name = key.split("/")[0]
|
| 113 |
if "/" in key:
|
| 114 |
+
weights_dict.update({key: int(value)})
|
| 115 |
models.append(key)
|
| 116 |
if isinstance(model, str):
|
| 117 |
models.append(model)
|