dray
commited on
Commit
·
3e27f73
1
Parent(s):
331ab5a
Refactor: Add support for retrieving all models when the model name is "*"
Browse files
main.py
CHANGED
|
@@ -419,6 +419,12 @@ class ModelRequestHandler:
|
|
| 419 |
provider_rules = []
|
| 420 |
|
| 421 |
for model in config['api_keys'][api_index]['model']:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
if "/" in model:
|
| 423 |
if model.startswith("<") and model.endswith(">"):
|
| 424 |
model = model[1:-1]
|
|
|
|
| 419 |
provider_rules = []
|
| 420 |
|
| 421 |
for model in config['api_keys'][api_index]['model']:
|
| 422 |
+
if model == "*":
|
| 423 |
+
# 如果模型名为 *,则返回所有模型
|
| 424 |
+
for provider in config["providers"]:
|
| 425 |
+
for model in provider["model"].keys():
|
| 426 |
+
provider_rules.append(provider["provider"] + "/" + model)
|
| 427 |
+
break
|
| 428 |
if "/" in model:
|
| 429 |
if model.startswith("<") and model.endswith(">"):
|
| 430 |
model = model[1:-1]
|
utils.py
CHANGED
|
@@ -62,7 +62,7 @@ async def load_config(app=None):
|
|
| 62 |
# is_quoted = not token.plain
|
| 63 |
# print(f"值: {value}, 是否被引号包裹: {is_quoted}")
|
| 64 |
|
| 65 |
-
with open(
|
| 66 |
# 判断是否为空文件
|
| 67 |
conf = yaml.safe_load(f)
|
| 68 |
# conf = None
|
|
@@ -170,6 +170,10 @@ def post_all_models(token, config, api_list):
|
|
| 170 |
api_index = api_list.index(token)
|
| 171 |
if config['api_keys'][api_index]['model']:
|
| 172 |
for model in config['api_keys'][api_index]['model']:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
if "/" in model:
|
| 174 |
provider = model.split("/")[0]
|
| 175 |
model = model.split("/")[1]
|
|
|
|
| 62 |
# is_quoted = not token.plain
|
| 63 |
# print(f"值: {value}, 是否被引号包裹: {is_quoted}")
|
| 64 |
|
| 65 |
+
with open("./api.yaml", "r", encoding="utf-8") as f:
|
| 66 |
# 判断是否为空文件
|
| 67 |
conf = yaml.safe_load(f)
|
| 68 |
# conf = None
|
|
|
|
| 170 |
api_index = api_list.index(token)
|
| 171 |
if config['api_keys'][api_index]['model']:
|
| 172 |
for model in config['api_keys'][api_index]['model']:
|
| 173 |
+
if model == "*":
|
| 174 |
+
# 如果模型名为 *,则返回所有模型
|
| 175 |
+
all_models = get_all_models(config)
|
| 176 |
+
return all_models
|
| 177 |
if "/" in model:
|
| 178 |
provider = model.split("/")[0]
|
| 179 |
model = model.split("/")[1]
|