✨ Feature: 1. Add support for experimental frontend.
Browse files2. Add support for configuration files without listing models.
- .gitignore +1 -1
- components/provider_table.py +230 -0
- main.py +24 -15
- models.py +10 -4
- request.py +42 -28
- requirements.txt +2 -1
- test/test_ruamel_yaml.py +44 -0
- test/xue/test_dropdown_sheet.py +75 -0
- test/xue/test_form_uni_api.py +116 -0
- test/xue/test_home.py +476 -0
- utils.py +80 -61
.gitignore
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
api.json
|
| 2 |
-
|
| 3 |
.env
|
| 4 |
__pycache__
|
| 5 |
.vscode
|
|
|
|
| 1 |
api.json
|
| 2 |
+
*.yaml
|
| 3 |
.env
|
| 4 |
__pycache__
|
| 5 |
.vscode
|
components/provider_table.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from xue import Div, Table, Thead, Tbody, Tr, Th, Td, Button, Input, Script, Head, Style, Span
|
| 2 |
+
from xue.components.checkbox import checkbox
|
| 3 |
+
from xue.components.dropdown import dropdown_menu, dropdown_menu_content
|
| 4 |
+
from xue.components.button import button
|
| 5 |
+
from xue.components.input import input
|
| 6 |
+
|
| 7 |
+
Head.add_default_children([
|
| 8 |
+
Style("""
|
| 9 |
+
.data-table-container {
|
| 10 |
+
width: 100%;
|
| 11 |
+
overflow-x: auto;
|
| 12 |
+
border: 1px solid #e2e8f0;
|
| 13 |
+
border-radius: 0.5rem;
|
| 14 |
+
overflow-x: visible !important;
|
| 15 |
+
}
|
| 16 |
+
.data-table {
|
| 17 |
+
width: 100%;
|
| 18 |
+
border-collapse: separate;
|
| 19 |
+
border-spacing: 0;
|
| 20 |
+
}
|
| 21 |
+
.data-table th, .data-table td {
|
| 22 |
+
padding: 0.75rem 1rem;
|
| 23 |
+
text-align: left;
|
| 24 |
+
border-bottom: 1px solid #e2e8f0;
|
| 25 |
+
}
|
| 26 |
+
.data-table th {
|
| 27 |
+
font-weight: 500;
|
| 28 |
+
font-size: 0.875rem;
|
| 29 |
+
color: #4b5563;
|
| 30 |
+
height: 2.5rem;
|
| 31 |
+
transition: background-color 0.2s;
|
| 32 |
+
}
|
| 33 |
+
.data-table thead tr:hover th,
|
| 34 |
+
.data-table tbody tr:hover {
|
| 35 |
+
background-color: #f8fafc;
|
| 36 |
+
}
|
| 37 |
+
.data-table tbody tr:last-child td {
|
| 38 |
+
border-bottom: none;
|
| 39 |
+
}
|
| 40 |
+
.sortable-header {
|
| 41 |
+
cursor: pointer;
|
| 42 |
+
user-select: none;
|
| 43 |
+
display: inline-flex;
|
| 44 |
+
align-items: center;
|
| 45 |
+
padding: 0.25rem 0.5rem;
|
| 46 |
+
border-radius: 0.25rem;
|
| 47 |
+
transition: background-color 0.2s;
|
| 48 |
+
}
|
| 49 |
+
.sortable-header:hover {
|
| 50 |
+
background-color: #e5e7eb;
|
| 51 |
+
}
|
| 52 |
+
.sort-icon {
|
| 53 |
+
display: inline-block;
|
| 54 |
+
width: 1rem;
|
| 55 |
+
height: 1rem;
|
| 56 |
+
margin-left: 0.25rem;
|
| 57 |
+
transition: transform 0.2s;
|
| 58 |
+
opacity: 0;
|
| 59 |
+
}
|
| 60 |
+
.sortable-header:hover .sort-icon,
|
| 61 |
+
.sort-asc .sort-icon,
|
| 62 |
+
.sort-desc .sort-icon {
|
| 63 |
+
opacity: 1;
|
| 64 |
+
}
|
| 65 |
+
.sort-asc .sort-icon {
|
| 66 |
+
transform: rotate(180deg);
|
| 67 |
+
}
|
| 68 |
+
.table-header {
|
| 69 |
+
display: flex;
|
| 70 |
+
justify-content: space-between;
|
| 71 |
+
align-items: center;
|
| 72 |
+
margin-bottom: 1rem;
|
| 73 |
+
}
|
| 74 |
+
.table-footer {
|
| 75 |
+
display: flex;
|
| 76 |
+
justify-content: space-between;
|
| 77 |
+
align-items: center;
|
| 78 |
+
margin-top: 1rem;
|
| 79 |
+
}
|
| 80 |
+
.pagination {
|
| 81 |
+
display: flex;
|
| 82 |
+
gap: 0.5rem;
|
| 83 |
+
}
|
| 84 |
+
@media (prefers-color-scheme: dark) {
|
| 85 |
+
.data-table-container {
|
| 86 |
+
border-color: #4b5563;
|
| 87 |
+
}
|
| 88 |
+
.data-table th, .data-table td {
|
| 89 |
+
border-color: #4b5563;
|
| 90 |
+
}
|
| 91 |
+
.data-table th {
|
| 92 |
+
color: #d1d5db;
|
| 93 |
+
}
|
| 94 |
+
.data-table thead tr:hover th,
|
| 95 |
+
.data-table tbody tr:hover {
|
| 96 |
+
background-color: #1f2937;
|
| 97 |
+
}
|
| 98 |
+
.sortable-header:hover {
|
| 99 |
+
background-color: #374151;
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
""", id="data-table-style"),
|
| 103 |
+
Script("""
|
| 104 |
+
function toggleAllRows(checked) {
|
| 105 |
+
const checkboxes = document.querySelectorAll('.row-checkbox');
|
| 106 |
+
checkboxes.forEach(cb => cb.checked = checked);
|
| 107 |
+
updateSelectedCount();
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
function updateSelectedCount() {
|
| 111 |
+
const selectedCount = document.querySelectorAll('.row-checkbox:checked').length;
|
| 112 |
+
const totalCount = document.querySelectorAll('.row-checkbox').length;
|
| 113 |
+
document.getElementById('selected-count').textContent = `${selectedCount} of ${totalCount} row(s) selected.`;
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
function sortTable(columnIndex, accessor) {
|
| 117 |
+
const table = document.querySelector('.data-table');
|
| 118 |
+
const header = table.querySelector(`th[data-accessor="${accessor}"]`);
|
| 119 |
+
const isAscending = !header.classList.contains('sort-asc');
|
| 120 |
+
|
| 121 |
+
// Update sort direction
|
| 122 |
+
table.querySelectorAll('th').forEach(th => th.classList.remove('sort-asc', 'sort-desc'));
|
| 123 |
+
header.classList.add(isAscending ? 'sort-asc' : 'sort-desc');
|
| 124 |
+
|
| 125 |
+
// Sort the table
|
| 126 |
+
const rows = Array.from(table.querySelectorAll('tbody tr'));
|
| 127 |
+
rows.sort((a, b) => {
|
| 128 |
+
const aValue = a.querySelector(`td[data-accessor="${accessor}"]`).textContent;
|
| 129 |
+
const bValue = b.querySelector(`td[data-accessor="${accessor}"]`).textContent;
|
| 130 |
+
return isAscending ? aValue.localeCompare(bValue) : bValue.localeCompare(aValue);
|
| 131 |
+
});
|
| 132 |
+
|
| 133 |
+
// Update the table
|
| 134 |
+
const tbody = table.querySelector('tbody');
|
| 135 |
+
rows.forEach(row => tbody.appendChild(row));
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
document.addEventListener('change', function(event) {
|
| 139 |
+
if (event.target.classList.contains('row-checkbox')) {
|
| 140 |
+
updateSelectedCount();
|
| 141 |
+
}
|
| 142 |
+
});
|
| 143 |
+
""", id="data-table-script"),
|
| 144 |
+
])
|
| 145 |
+
|
| 146 |
+
def data_table(columns, data, id, with_filter=True):
|
| 147 |
+
return Div(
|
| 148 |
+
Div(
|
| 149 |
+
input(type="text", placeholder="Filter...", id=f"{id}-filter", class_="mr-auto"),
|
| 150 |
+
Div(
|
| 151 |
+
button(
|
| 152 |
+
"Add Provider",
|
| 153 |
+
variant="secondary",
|
| 154 |
+
hx_get="/add-provider-sheet",
|
| 155 |
+
hx_target="#sheet-container",
|
| 156 |
+
hx_swap="innerHTML",
|
| 157 |
+
class_="h-[2.625rem]"
|
| 158 |
+
),
|
| 159 |
+
dropdown_menu("Columns"),
|
| 160 |
+
),
|
| 161 |
+
class_="table-header flex items-center"
|
| 162 |
+
) if with_filter else None,
|
| 163 |
+
Div(
|
| 164 |
+
Div(
|
| 165 |
+
Table(
|
| 166 |
+
Thead(
|
| 167 |
+
Tr(
|
| 168 |
+
Th(checkbox("select-all", "", onclick="toggleAllRows(this.checked)")),
|
| 169 |
+
*[Th(
|
| 170 |
+
Div(
|
| 171 |
+
col['label'],
|
| 172 |
+
Span("▼", class_="sort-icon"),
|
| 173 |
+
class_="sortable-header" if col.get('sortable', False) else "",
|
| 174 |
+
onclick=f"sortTable({i}, '{col['value']}')" if col.get('sortable', False) else None
|
| 175 |
+
),
|
| 176 |
+
data_accessor=col['value']
|
| 177 |
+
) for i, col in enumerate(columns)],
|
| 178 |
+
Th("Actions") # 新增的操作列
|
| 179 |
+
)
|
| 180 |
+
),
|
| 181 |
+
Tbody(
|
| 182 |
+
*[Tr(
|
| 183 |
+
Td(checkbox(f"row-{i}", "", class_="row-checkbox")),
|
| 184 |
+
*[Td(row[col['value']], data_accessor=col['value']) for col in columns],
|
| 185 |
+
Td(row_actions_menu(i)), # 使用行索引作为 row_id
|
| 186 |
+
id=f"row-{i}"
|
| 187 |
+
) for i, row in enumerate(data)]
|
| 188 |
+
),
|
| 189 |
+
class_="data-table"
|
| 190 |
+
),
|
| 191 |
+
class_="data-table-container"
|
| 192 |
+
),
|
| 193 |
+
Div(
|
| 194 |
+
Div(id="selected-count", class_="text-sm text-gray-500"),
|
| 195 |
+
Div(
|
| 196 |
+
button("Previous", variant="outline", class_="mr-2"),
|
| 197 |
+
button("Next", variant="outline"),
|
| 198 |
+
class_="pagination"
|
| 199 |
+
),
|
| 200 |
+
class_="table-footer"
|
| 201 |
+
),
|
| 202 |
+
id=id
|
| 203 |
+
),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
def get_column_visibility_menu(id, columns):
|
| 207 |
+
return dropdown_menu_content(id, [
|
| 208 |
+
{"label": col['label'], "value": col['value']}
|
| 209 |
+
for col in columns if col.get('can_hide', True)
|
| 210 |
+
])
|
| 211 |
+
|
| 212 |
+
def row_actions_menu(row_id):
|
| 213 |
+
return dropdown_menu("⋮", id=f"row-actions-menu-{row_id}", hx_get=f"/dropdown-menu/dropdown-menu-⋮/{row_id}")
|
| 214 |
+
|
| 215 |
+
def get_row_actions_menu(row_id):
|
| 216 |
+
return dropdown_menu_content(f"row-actions-{row_id}", [
|
| 217 |
+
{"label": "Edit", "icon": "pencil"},
|
| 218 |
+
{"label": "Duplicate", "icon": "copy"},
|
| 219 |
+
{"label": "Delete", "icon": "trash"},
|
| 220 |
+
"separator",
|
| 221 |
+
{"label": "More...", "icon": "more-horizontal"},
|
| 222 |
+
])
|
| 223 |
+
|
| 224 |
+
def render_row(row_data, row_id, columns):
|
| 225 |
+
return Tr(
|
| 226 |
+
Td(checkbox(f"row-{row_id}", "", class_="row-checkbox")),
|
| 227 |
+
*[Td(row_data[col['value']], data_accessor=col['value']) for col in columns],
|
| 228 |
+
Td(row_actions_menu(row_id)),
|
| 229 |
+
id=f"row-{row_id}"
|
| 230 |
+
).render()
|
main.py
CHANGED
|
@@ -18,7 +18,7 @@ from fastapi.exceptions import RequestValidationError
|
|
| 18 |
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
| 19 |
from request import get_payload
|
| 20 |
from response import fetch_response, fetch_response_stream
|
| 21 |
-
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
| 22 |
|
| 23 |
from collections import defaultdict
|
| 24 |
from typing import List, Dict, Union
|
|
@@ -492,20 +492,21 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
| 492 |
else:
|
| 493 |
engine = "gpt"
|
| 494 |
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
and "
|
|
|
|
| 498 |
and parsed_url.netloc != 'api.cloudflare.com' \
|
| 499 |
and parsed_url.netloc != 'api.cohere.com':
|
| 500 |
engine = "openrouter"
|
| 501 |
|
| 502 |
-
if "claude" in
|
| 503 |
engine = "vertex-claude"
|
| 504 |
|
| 505 |
-
if "gemini" in
|
| 506 |
engine = "vertex-gemini"
|
| 507 |
|
| 508 |
-
if "o1-preview" in
|
| 509 |
engine = "o1"
|
| 510 |
request.stream = False
|
| 511 |
|
|
@@ -536,7 +537,7 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
| 536 |
current_info = request_info.get()
|
| 537 |
try:
|
| 538 |
if request.stream:
|
| 539 |
-
model =
|
| 540 |
generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
|
| 541 |
wrapped_generator, first_response_time = await error_handling_wrapper(generator)
|
| 542 |
response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
|
|
@@ -603,7 +604,8 @@ class ModelRequestHandler:
|
|
| 603 |
if model == "all":
|
| 604 |
# 如果模型名为 *,则返回所有模型
|
| 605 |
for provider in config["providers"]:
|
| 606 |
-
|
|
|
|
| 607 |
provider_rules.append(provider["provider"] + "/" + model)
|
| 608 |
break
|
| 609 |
if "/" in model:
|
|
@@ -611,15 +613,17 @@ class ModelRequestHandler:
|
|
| 611 |
model = model[1:-1]
|
| 612 |
# 处理带斜杠的模型名
|
| 613 |
for provider in config['providers']:
|
| 614 |
-
|
|
|
|
| 615 |
provider_rules.append(provider['provider'] + "/" + model)
|
| 616 |
else:
|
| 617 |
provider_name = model.split("/")[0]
|
| 618 |
model_name_split = "/".join(model.split("/")[1:])
|
| 619 |
models_list = []
|
| 620 |
for provider in config['providers']:
|
|
|
|
| 621 |
if provider['provider'] == provider_name:
|
| 622 |
-
models_list.extend(list(
|
| 623 |
# print("models_list", models_list)
|
| 624 |
# print("model_name", model_name)
|
| 625 |
# print("model_name_split", model_name_split)
|
|
@@ -632,7 +636,8 @@ class ModelRequestHandler:
|
|
| 632 |
provider_rules.append(provider_name)
|
| 633 |
else:
|
| 634 |
for provider in config['providers']:
|
| 635 |
-
|
|
|
|
| 636 |
provider_rules.append(provider['provider'] + "/" + model)
|
| 637 |
|
| 638 |
provider_list = []
|
|
@@ -642,10 +647,13 @@ class ModelRequestHandler:
|
|
| 642 |
# print("provider", provider, provider['provider'] == item, item)
|
| 643 |
if "/" in item:
|
| 644 |
if provider['provider'] == item.split("/")[0]:
|
| 645 |
-
|
|
|
|
| 646 |
provider_list.append(provider)
|
|
|
|
| 647 |
elif provider['provider'] == item:
|
| 648 |
-
|
|
|
|
| 649 |
provider_list.append(provider)
|
| 650 |
else:
|
| 651 |
pass
|
|
@@ -655,7 +663,8 @@ class ModelRequestHandler:
|
|
| 655 |
# if item.split("/")[1] == model_name:
|
| 656 |
# provider_list.append(provider)
|
| 657 |
# else:
|
| 658 |
-
#
|
|
|
|
| 659 |
# provider_list.append(provider)
|
| 660 |
if is_debug:
|
| 661 |
for provider in provider_list:
|
|
|
|
| 18 |
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
| 19 |
from request import get_payload
|
| 20 |
from response import fetch_response, fetch_response_stream
|
| 21 |
+
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict
|
| 22 |
|
| 23 |
from collections import defaultdict
|
| 24 |
from typing import List, Dict, Union
|
|
|
|
| 492 |
else:
|
| 493 |
engine = "gpt"
|
| 494 |
|
| 495 |
+
model_dict = get_model_dict(provider)
|
| 496 |
+
if "claude" not in model_dict[request.model] \
|
| 497 |
+
and "gpt" not in model_dict[request.model] \
|
| 498 |
+
and "gemini" not in model_dict[request.model] \
|
| 499 |
and parsed_url.netloc != 'api.cloudflare.com' \
|
| 500 |
and parsed_url.netloc != 'api.cohere.com':
|
| 501 |
engine = "openrouter"
|
| 502 |
|
| 503 |
+
if "claude" in model_dict[request.model] and engine == "vertex":
|
| 504 |
engine = "vertex-claude"
|
| 505 |
|
| 506 |
+
if "gemini" in model_dict[request.model] and engine == "vertex":
|
| 507 |
engine = "vertex-gemini"
|
| 508 |
|
| 509 |
+
if "o1-preview" in model_dict[request.model] or "o1-mini" in model_dict[request.model]:
|
| 510 |
engine = "o1"
|
| 511 |
request.stream = False
|
| 512 |
|
|
|
|
| 537 |
current_info = request_info.get()
|
| 538 |
try:
|
| 539 |
if request.stream:
|
| 540 |
+
model = model_dict[request.model]
|
| 541 |
generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
|
| 542 |
wrapped_generator, first_response_time = await error_handling_wrapper(generator)
|
| 543 |
response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
|
|
|
|
| 604 |
if model == "all":
|
| 605 |
# 如果模型名为 *,则返回所有模型
|
| 606 |
for provider in config["providers"]:
|
| 607 |
+
model_dict = get_model_dict(provider)
|
| 608 |
+
for model in model_dict.keys():
|
| 609 |
provider_rules.append(provider["provider"] + "/" + model)
|
| 610 |
break
|
| 611 |
if "/" in model:
|
|
|
|
| 613 |
model = model[1:-1]
|
| 614 |
# 处理带斜杠的模型名
|
| 615 |
for provider in config['providers']:
|
| 616 |
+
model_dict = get_model_dict(provider)
|
| 617 |
+
if model in model_dict.keys():
|
| 618 |
provider_rules.append(provider['provider'] + "/" + model)
|
| 619 |
else:
|
| 620 |
provider_name = model.split("/")[0]
|
| 621 |
model_name_split = "/".join(model.split("/")[1:])
|
| 622 |
models_list = []
|
| 623 |
for provider in config['providers']:
|
| 624 |
+
model_dict = get_model_dict(provider)
|
| 625 |
if provider['provider'] == provider_name:
|
| 626 |
+
models_list.extend(list(model_dict.keys()))
|
| 627 |
# print("models_list", models_list)
|
| 628 |
# print("model_name", model_name)
|
| 629 |
# print("model_name_split", model_name_split)
|
|
|
|
| 636 |
provider_rules.append(provider_name)
|
| 637 |
else:
|
| 638 |
for provider in config['providers']:
|
| 639 |
+
model_dict = get_model_dict(provider)
|
| 640 |
+
if model in model_dict.keys():
|
| 641 |
provider_rules.append(provider['provider'] + "/" + model)
|
| 642 |
|
| 643 |
provider_list = []
|
|
|
|
| 647 |
# print("provider", provider, provider['provider'] == item, item)
|
| 648 |
if "/" in item:
|
| 649 |
if provider['provider'] == item.split("/")[0]:
|
| 650 |
+
model_dict = get_model_dict(provider)
|
| 651 |
+
if model_name in model_dict.keys() and "/".join(item.split("/")[1:]) == model_name:
|
| 652 |
provider_list.append(provider)
|
| 653 |
+
# 如果 item 不包含 /,则直接匹配 provider,说明整个渠道所有模型都能用
|
| 654 |
elif provider['provider'] == item:
|
| 655 |
+
model_dict = get_model_dict(provider)
|
| 656 |
+
if model_name in model_dict.keys():
|
| 657 |
provider_list.append(provider)
|
| 658 |
else:
|
| 659 |
pass
|
|
|
|
| 663 |
# if item.split("/")[1] == model_name:
|
| 664 |
# provider_list.append(provider)
|
| 665 |
# else:
|
| 666 |
+
# model_dict = get_model_dict(provider)
|
| 667 |
+
# if model_name in model_dict.keys():
|
| 668 |
# provider_list.append(provider)
|
| 669 |
if is_debug:
|
| 670 |
for provider in provider_list:
|
models.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from io import IOBase
|
| 2 |
-
from pydantic import BaseModel, Field, model_validator
|
| 3 |
from typing import List, Dict, Optional, Union, Tuple, Literal, Any
|
| 4 |
from log_config import logger
|
| 5 |
|
|
@@ -61,10 +61,16 @@ class ToolChoice(BaseModel):
|
|
| 61 |
class BaseRequest(BaseModel):
|
| 62 |
request_type: Optional[Literal["chat", "image", "audio", "moderation"]] = Field(default=None, exclude=True)
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
|
|
|
| 68 |
class ResponseFormat(BaseModel):
|
| 69 |
type: Literal["text", "json_object", "json_schema"]
|
| 70 |
json_schema: Optional[JsonSchema] = None
|
|
|
|
| 1 |
from io import IOBase
|
| 2 |
+
from pydantic import BaseModel, Field, model_validator, ConfigDict
|
| 3 |
from typing import List, Dict, Optional, Union, Tuple, Literal, Any
|
| 4 |
from log_config import logger
|
| 5 |
|
|
|
|
| 61 |
class BaseRequest(BaseModel):
|
| 62 |
request_type: Optional[Literal["chat", "image", "audio", "moderation"]] = Field(default=None, exclude=True)
|
| 63 |
|
| 64 |
+
def create_json_schema_class():
|
| 65 |
+
class JsonSchema(BaseModel):
|
| 66 |
+
name: str
|
| 67 |
+
|
| 68 |
+
model_config = ConfigDict(protected_namespaces=())
|
| 69 |
+
|
| 70 |
+
JsonSchema.__annotations__['schema'] = Dict[str, Any]
|
| 71 |
+
return JsonSchema
|
| 72 |
|
| 73 |
+
JsonSchema = create_json_schema_class()
|
| 74 |
class ResponseFormat(BaseModel):
|
| 75 |
type: Literal["text", "json_object", "json_schema"]
|
| 76 |
json_schema: Optional[JsonSchema] = None
|
request.py
CHANGED
|
@@ -6,7 +6,7 @@ import base64
|
|
| 6 |
import urllib.parse
|
| 7 |
|
| 8 |
from models import RequestModel
|
| 9 |
-
from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
|
| 10 |
|
| 11 |
import imghdr
|
| 12 |
|
|
@@ -120,13 +120,14 @@ async def get_gemini_payload(request, engine, provider):
|
|
| 120 |
headers = {
|
| 121 |
'Content-Type': 'application/json'
|
| 122 |
}
|
| 123 |
-
|
|
|
|
| 124 |
gemini_stream = "streamGenerateContent"
|
| 125 |
url = provider['base_url']
|
| 126 |
if url.endswith("v1beta"):
|
| 127 |
-
url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=provider['
|
| 128 |
if url.endswith("v1"):
|
| 129 |
-
url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=provider['
|
| 130 |
|
| 131 |
messages = []
|
| 132 |
systemInstruction = None
|
|
@@ -312,7 +313,8 @@ async def get_vertex_gemini_payload(request, engine, provider):
|
|
| 312 |
project_id = provider.get("project_id")
|
| 313 |
|
| 314 |
gemini_stream = "streamGenerateContent"
|
| 315 |
-
|
|
|
|
| 316 |
location = gem
|
| 317 |
url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL_ID=model, stream=gemini_stream)
|
| 318 |
|
|
@@ -449,7 +451,8 @@ async def get_vertex_claude_payload(request, engine, provider):
|
|
| 449 |
if provider.get("project_id"):
|
| 450 |
project_id = provider.get("project_id")
|
| 451 |
|
| 452 |
-
|
|
|
|
| 453 |
if "claude-3-5-sonnet" in model:
|
| 454 |
location = c35s
|
| 455 |
elif "claude-3-opus" in model:
|
|
@@ -460,7 +463,7 @@ async def get_vertex_claude_payload(request, engine, provider):
|
|
| 460 |
location = c3h
|
| 461 |
|
| 462 |
claude_stream = "streamRawPredict"
|
| 463 |
-
url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
|
| 464 |
|
| 465 |
messages = []
|
| 466 |
system_prompt = None
|
|
@@ -534,7 +537,8 @@ async def get_vertex_claude_payload(request, engine, provider):
|
|
| 534 |
else:
|
| 535 |
message_index = message_index + 1
|
| 536 |
|
| 537 |
-
|
|
|
|
| 538 |
payload = {
|
| 539 |
"anthropic_version": "vertex-2023-10-16",
|
| 540 |
"messages": messages,
|
|
@@ -593,7 +597,7 @@ async def get_gpt_payload(request, engine, provider):
|
|
| 593 |
'Content-Type': 'application/json',
|
| 594 |
}
|
| 595 |
if provider.get("api"):
|
| 596 |
-
headers['Authorization'] = f"Bearer {provider['
|
| 597 |
url = provider['base_url']
|
| 598 |
|
| 599 |
messages = []
|
|
@@ -633,7 +637,8 @@ async def get_gpt_payload(request, engine, provider):
|
|
| 633 |
else:
|
| 634 |
messages.append({"role": msg.role, "content": content})
|
| 635 |
|
| 636 |
-
|
|
|
|
| 637 |
payload = {
|
| 638 |
"model": model,
|
| 639 |
"messages": messages,
|
|
@@ -659,7 +664,7 @@ async def get_openrouter_payload(request, engine, provider):
|
|
| 659 |
'Content-Type': 'application/json'
|
| 660 |
}
|
| 661 |
if provider.get("api"):
|
| 662 |
-
headers['Authorization'] = f"Bearer {provider['
|
| 663 |
|
| 664 |
url = provider['base_url']
|
| 665 |
|
|
@@ -691,7 +696,8 @@ async def get_openrouter_payload(request, engine, provider):
|
|
| 691 |
else:
|
| 692 |
messages.append({"role": msg.role, "content": content})
|
| 693 |
|
| 694 |
-
|
|
|
|
| 695 |
payload = {
|
| 696 |
"model": model,
|
| 697 |
"messages": messages,
|
|
@@ -725,7 +731,7 @@ async def get_cohere_payload(request, engine, provider):
|
|
| 725 |
'Content-Type': 'application/json'
|
| 726 |
}
|
| 727 |
if provider.get("api"):
|
| 728 |
-
headers['Authorization'] = f"Bearer {provider['
|
| 729 |
|
| 730 |
url = provider['base_url']
|
| 731 |
|
|
@@ -753,7 +759,8 @@ async def get_cohere_payload(request, engine, provider):
|
|
| 753 |
else:
|
| 754 |
messages.append({"role": role_map[msg.role], "message": content})
|
| 755 |
|
| 756 |
-
|
|
|
|
| 757 |
chat_history = messages[:-1]
|
| 758 |
query = messages[-1].get("message")
|
| 759 |
payload = {
|
|
@@ -792,9 +799,10 @@ async def get_cloudflare_payload(request, engine, provider):
|
|
| 792 |
'Content-Type': 'application/json'
|
| 793 |
}
|
| 794 |
if provider.get("api"):
|
| 795 |
-
headers['Authorization'] = f"Bearer {provider['
|
| 796 |
|
| 797 |
-
|
|
|
|
| 798 |
url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=model)
|
| 799 |
|
| 800 |
msg = request.messages[-1]
|
|
@@ -808,7 +816,7 @@ async def get_cloudflare_payload(request, engine, provider):
|
|
| 808 |
content = msg.content
|
| 809 |
name = msg.name
|
| 810 |
|
| 811 |
-
model =
|
| 812 |
payload = {
|
| 813 |
"prompt": content,
|
| 814 |
}
|
|
@@ -841,7 +849,7 @@ async def get_o1_payload(request, engine, provider):
|
|
| 841 |
'Content-Type': 'application/json'
|
| 842 |
}
|
| 843 |
if provider.get("api"):
|
| 844 |
-
headers['Authorization'] = f"Bearer {provider['
|
| 845 |
|
| 846 |
url = provider['base_url']
|
| 847 |
|
|
@@ -863,7 +871,8 @@ async def get_o1_payload(request, engine, provider):
|
|
| 863 |
elif msg.role != "system":
|
| 864 |
messages.append({"role": msg.role, "content": content})
|
| 865 |
|
| 866 |
-
|
|
|
|
| 867 |
payload = {
|
| 868 |
"model": model,
|
| 869 |
"messages": messages,
|
|
@@ -912,10 +921,11 @@ async def gpt2claude_tools_json(json_dict):
|
|
| 912 |
return json_dict
|
| 913 |
|
| 914 |
async def get_claude_payload(request, engine, provider):
|
| 915 |
-
|
|
|
|
| 916 |
headers = {
|
| 917 |
"content-type": "application/json",
|
| 918 |
-
"x-api-key": f"{provider['
|
| 919 |
"anthropic-version": "2023-06-01",
|
| 920 |
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16",
|
| 921 |
}
|
|
@@ -993,7 +1003,8 @@ async def get_claude_payload(request, engine, provider):
|
|
| 993 |
else:
|
| 994 |
message_index = message_index + 1
|
| 995 |
|
| 996 |
-
|
|
|
|
| 997 |
payload = {
|
| 998 |
"model": model,
|
| 999 |
"messages": messages,
|
|
@@ -1051,12 +1062,13 @@ async def get_claude_payload(request, engine, provider):
|
|
| 1051 |
return url, headers, payload
|
| 1052 |
|
| 1053 |
async def get_dalle_payload(request, engine, provider):
|
| 1054 |
-
|
|
|
|
| 1055 |
headers = {
|
| 1056 |
"Content-Type": "application/json",
|
| 1057 |
}
|
| 1058 |
if provider.get("api"):
|
| 1059 |
-
headers['Authorization'] = f"Bearer {provider['
|
| 1060 |
url = provider['base_url']
|
| 1061 |
url = BaseAPI(url).image_url
|
| 1062 |
|
|
@@ -1070,12 +1082,13 @@ async def get_dalle_payload(request, engine, provider):
|
|
| 1070 |
return url, headers, payload
|
| 1071 |
|
| 1072 |
async def get_whisper_payload(request, engine, provider):
|
| 1073 |
-
|
|
|
|
| 1074 |
headers = {
|
| 1075 |
# "Content-Type": "multipart/form-data",
|
| 1076 |
}
|
| 1077 |
if provider.get("api"):
|
| 1078 |
-
headers['Authorization'] = f"Bearer {provider['
|
| 1079 |
url = provider['base_url']
|
| 1080 |
url = BaseAPI(url).audio_transcriptions
|
| 1081 |
|
|
@@ -1096,12 +1109,13 @@ async def get_whisper_payload(request, engine, provider):
|
|
| 1096 |
return url, headers, payload
|
| 1097 |
|
| 1098 |
async def get_moderation_payload(request, engine, provider):
|
| 1099 |
-
|
|
|
|
| 1100 |
headers = {
|
| 1101 |
"Content-Type": "application/json",
|
| 1102 |
}
|
| 1103 |
if provider.get("api"):
|
| 1104 |
-
headers['Authorization'] = f"Bearer {provider['
|
| 1105 |
url = provider['base_url']
|
| 1106 |
url = BaseAPI(url).moderations
|
| 1107 |
|
|
|
|
| 6 |
import urllib.parse
|
| 7 |
|
| 8 |
from models import RequestModel
|
| 9 |
+
from utils import c35s, c3s, c3o, c3h, gem, BaseAPI, get_model_dict, provider_api_circular_list
|
| 10 |
|
| 11 |
import imghdr
|
| 12 |
|
|
|
|
| 120 |
headers = {
|
| 121 |
'Content-Type': 'application/json'
|
| 122 |
}
|
| 123 |
+
model_dict = get_model_dict(provider)
|
| 124 |
+
model = model_dict[request.model]
|
| 125 |
gemini_stream = "streamGenerateContent"
|
| 126 |
url = provider['base_url']
|
| 127 |
if url.endswith("v1beta"):
|
| 128 |
+
url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
|
| 129 |
if url.endswith("v1"):
|
| 130 |
+
url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
|
| 131 |
|
| 132 |
messages = []
|
| 133 |
systemInstruction = None
|
|
|
|
| 313 |
project_id = provider.get("project_id")
|
| 314 |
|
| 315 |
gemini_stream = "streamGenerateContent"
|
| 316 |
+
model_dict = get_model_dict(provider)
|
| 317 |
+
model = model_dict[request.model]
|
| 318 |
location = gem
|
| 319 |
url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL_ID=model, stream=gemini_stream)
|
| 320 |
|
|
|
|
| 451 |
if provider.get("project_id"):
|
| 452 |
project_id = provider.get("project_id")
|
| 453 |
|
| 454 |
+
model_dict = get_model_dict(provider)
|
| 455 |
+
model = model_dict[request.model]
|
| 456 |
if "claude-3-5-sonnet" in model:
|
| 457 |
location = c35s
|
| 458 |
elif "claude-3-opus" in model:
|
|
|
|
| 463 |
location = c3h
|
| 464 |
|
| 465 |
claude_stream = "streamRawPredict"
|
| 466 |
+
url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=await location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
|
| 467 |
|
| 468 |
messages = []
|
| 469 |
system_prompt = None
|
|
|
|
| 537 |
else:
|
| 538 |
message_index = message_index + 1
|
| 539 |
|
| 540 |
+
model_dict = get_model_dict(provider)
|
| 541 |
+
model = model_dict[request.model]
|
| 542 |
payload = {
|
| 543 |
"anthropic_version": "vertex-2023-10-16",
|
| 544 |
"messages": messages,
|
|
|
|
| 597 |
'Content-Type': 'application/json',
|
| 598 |
}
|
| 599 |
if provider.get("api"):
|
| 600 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
| 601 |
url = provider['base_url']
|
| 602 |
|
| 603 |
messages = []
|
|
|
|
| 637 |
else:
|
| 638 |
messages.append({"role": msg.role, "content": content})
|
| 639 |
|
| 640 |
+
model_dict = get_model_dict(provider)
|
| 641 |
+
model = model_dict[request.model]
|
| 642 |
payload = {
|
| 643 |
"model": model,
|
| 644 |
"messages": messages,
|
|
|
|
| 664 |
'Content-Type': 'application/json'
|
| 665 |
}
|
| 666 |
if provider.get("api"):
|
| 667 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
| 668 |
|
| 669 |
url = provider['base_url']
|
| 670 |
|
|
|
|
| 696 |
else:
|
| 697 |
messages.append({"role": msg.role, "content": content})
|
| 698 |
|
| 699 |
+
model_dict = get_model_dict(provider)
|
| 700 |
+
model = model_dict[request.model]
|
| 701 |
payload = {
|
| 702 |
"model": model,
|
| 703 |
"messages": messages,
|
|
|
|
| 731 |
'Content-Type': 'application/json'
|
| 732 |
}
|
| 733 |
if provider.get("api"):
|
| 734 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
| 735 |
|
| 736 |
url = provider['base_url']
|
| 737 |
|
|
|
|
| 759 |
else:
|
| 760 |
messages.append({"role": role_map[msg.role], "message": content})
|
| 761 |
|
| 762 |
+
model_dict = get_model_dict(provider)
|
| 763 |
+
model = model_dict[request.model]
|
| 764 |
chat_history = messages[:-1]
|
| 765 |
query = messages[-1].get("message")
|
| 766 |
payload = {
|
|
|
|
| 799 |
'Content-Type': 'application/json'
|
| 800 |
}
|
| 801 |
if provider.get("api"):
|
| 802 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
| 803 |
|
| 804 |
+
model_dict = get_model_dict(provider)
|
| 805 |
+
model = model_dict[request.model]
|
| 806 |
url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=model)
|
| 807 |
|
| 808 |
msg = request.messages[-1]
|
|
|
|
| 816 |
content = msg.content
|
| 817 |
name = msg.name
|
| 818 |
|
| 819 |
+
model = model_dict[request.model]
|
| 820 |
payload = {
|
| 821 |
"prompt": content,
|
| 822 |
}
|
|
|
|
| 849 |
'Content-Type': 'application/json'
|
| 850 |
}
|
| 851 |
if provider.get("api"):
|
| 852 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
| 853 |
|
| 854 |
url = provider['base_url']
|
| 855 |
|
|
|
|
| 871 |
elif msg.role != "system":
|
| 872 |
messages.append({"role": msg.role, "content": content})
|
| 873 |
|
| 874 |
+
model_dict = get_model_dict(provider)
|
| 875 |
+
model = model_dict[request.model]
|
| 876 |
payload = {
|
| 877 |
"model": model,
|
| 878 |
"messages": messages,
|
|
|
|
| 921 |
return json_dict
|
| 922 |
|
| 923 |
async def get_claude_payload(request, engine, provider):
|
| 924 |
+
model_dict = get_model_dict(provider)
|
| 925 |
+
model = model_dict[request.model]
|
| 926 |
headers = {
|
| 927 |
"content-type": "application/json",
|
| 928 |
+
"x-api-key": f"{await provider_api_circular_list[provider['provider']].next()}",
|
| 929 |
"anthropic-version": "2023-06-01",
|
| 930 |
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16",
|
| 931 |
}
|
|
|
|
| 1003 |
else:
|
| 1004 |
message_index = message_index + 1
|
| 1005 |
|
| 1006 |
+
model_dict = get_model_dict(provider)
|
| 1007 |
+
model = model_dict[request.model]
|
| 1008 |
payload = {
|
| 1009 |
"model": model,
|
| 1010 |
"messages": messages,
|
|
|
|
| 1062 |
return url, headers, payload
|
| 1063 |
|
| 1064 |
async def get_dalle_payload(request, engine, provider):
|
| 1065 |
+
model_dict = get_model_dict(provider)
|
| 1066 |
+
model = model_dict[request.model]
|
| 1067 |
headers = {
|
| 1068 |
"Content-Type": "application/json",
|
| 1069 |
}
|
| 1070 |
if provider.get("api"):
|
| 1071 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
| 1072 |
url = provider['base_url']
|
| 1073 |
url = BaseAPI(url).image_url
|
| 1074 |
|
|
|
|
| 1082 |
return url, headers, payload
|
| 1083 |
|
| 1084 |
async def get_whisper_payload(request, engine, provider):
|
| 1085 |
+
model_dict = get_model_dict(provider)
|
| 1086 |
+
model = model_dict[request.model]
|
| 1087 |
headers = {
|
| 1088 |
# "Content-Type": "multipart/form-data",
|
| 1089 |
}
|
| 1090 |
if provider.get("api"):
|
| 1091 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
| 1092 |
url = provider['base_url']
|
| 1093 |
url = BaseAPI(url).audio_transcriptions
|
| 1094 |
|
|
|
|
| 1109 |
return url, headers, payload
|
| 1110 |
|
| 1111 |
async def get_moderation_payload(request, engine, provider):
|
| 1112 |
+
model_dict = get_model_dict(provider)
|
| 1113 |
+
model = model_dict[request.model]
|
| 1114 |
headers = {
|
| 1115 |
"Content-Type": "application/json",
|
| 1116 |
}
|
| 1117 |
if provider.get("api"):
|
| 1118 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
| 1119 |
url = provider['base_url']
|
| 1120 |
url = BaseAPI(url).moderations
|
| 1121 |
|
requirements.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
pytest
|
| 3 |
uvicorn
|
| 4 |
fastapi
|
|
@@ -7,6 +7,7 @@ greenlet
|
|
| 7 |
aiosqlite
|
| 8 |
sqlalchemy
|
| 9 |
watchfiles
|
|
|
|
| 10 |
httpx[http2]
|
| 11 |
cryptography
|
| 12 |
python-multipart
|
|
|
|
| 1 |
+
xue
|
| 2 |
pytest
|
| 3 |
uvicorn
|
| 4 |
fastapi
|
|
|
|
| 7 |
aiosqlite
|
| 8 |
sqlalchemy
|
| 9 |
watchfiles
|
| 10 |
+
ruamel.yaml
|
| 11 |
httpx[http2]
|
| 12 |
cryptography
|
| 13 |
python-multipart
|
test/test_ruamel_yaml.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ruamel.yaml import YAML
|
| 2 |
+
|
| 3 |
+
# 假设我们有以下 YAML 内容
|
| 4 |
+
yaml_content = """
|
| 5 |
+
# 这是顶级注释
|
| 6 |
+
key1: value1 # 行尾注释
|
| 7 |
+
key2: value2
|
| 8 |
+
|
| 9 |
+
# 这是嵌套结构的注释
|
| 10 |
+
nested:
|
| 11 |
+
subkey1: subvalue1
|
| 12 |
+
subkey2: subvalue2 # 嵌套的行尾注释
|
| 13 |
+
|
| 14 |
+
# 列表的注释
|
| 15 |
+
list_key:
|
| 16 |
+
- item1
|
| 17 |
+
- item2 # 列表项的注释
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
# 创建 YAML 对象
|
| 21 |
+
yaml = YAML()
|
| 22 |
+
yaml.preserve_quotes = True
|
| 23 |
+
yaml.indent(mapping=2, sequence=4, offset=2)
|
| 24 |
+
|
| 25 |
+
with open('api.yaml', 'r', encoding='utf-8') as file:
|
| 26 |
+
data = yaml.load(file)
|
| 27 |
+
|
| 28 |
+
# data = yaml.load(yaml_content)
|
| 29 |
+
# 加载 YAML 数据
|
| 30 |
+
print(data)
|
| 31 |
+
|
| 32 |
+
# # 修改数据
|
| 33 |
+
# data['key1'] = 'new_value1'
|
| 34 |
+
# data['nested']['subkey1'] = 'new_subvalue1'
|
| 35 |
+
# data['list_key'].append('new_item')
|
| 36 |
+
|
| 37 |
+
# 将修改后的数据写回文件(这里我们使用 StringIO 来模拟文件操作)
|
| 38 |
+
# from io import StringIO
|
| 39 |
+
# output = StringIO()
|
| 40 |
+
# yaml.dump(data, output)
|
| 41 |
+
# print(output.getvalue())
|
| 42 |
+
|
| 43 |
+
with open('formatted.yaml', 'w', encoding='utf-8') as file:
|
| 44 |
+
yaml.dump(data, file)
|
test/xue/test_dropdown_sheet.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
from fastapi.responses import HTMLResponse
|
| 3 |
+
from xue import HTML, Head, Body, Div, xue_initialize, Script
|
| 4 |
+
from xue.components import dropdown, sheet, button, form, input
|
| 5 |
+
|
| 6 |
+
xue_initialize(tailwind=True)
|
| 7 |
+
app = FastAPI()
|
| 8 |
+
|
| 9 |
+
@app.get("/", response_class=HTMLResponse)
|
| 10 |
+
async def root():
|
| 11 |
+
result = HTML(
|
| 12 |
+
Head(
|
| 13 |
+
title="Dropdown with Edit Sheet Example",
|
| 14 |
+
),
|
| 15 |
+
Body(
|
| 16 |
+
Div(
|
| 17 |
+
dropdown.dropdown_menu("Actions"),
|
| 18 |
+
Div(id="sheet-container"), # 这里是 sheet 将被加载的地方
|
| 19 |
+
class_="container mx-auto p-4"
|
| 20 |
+
)
|
| 21 |
+
)
|
| 22 |
+
).render()
|
| 23 |
+
print(result)
|
| 24 |
+
return result
|
| 25 |
+
|
| 26 |
+
@app.get("/dropdown-menu/{menu_id}", response_class=HTMLResponse)
|
| 27 |
+
async def get_dropdown_menu_content(menu_id: str):
|
| 28 |
+
items = [
|
| 29 |
+
{
|
| 30 |
+
"icon": "pencil",
|
| 31 |
+
"label": "Edit",
|
| 32 |
+
"hx-get": "/edit-sheet",
|
| 33 |
+
"hx-target": "#sheet-container",
|
| 34 |
+
"hx-swap": "innerHTML"
|
| 35 |
+
},
|
| 36 |
+
{"icon": "trash", "label": "Delete"},
|
| 37 |
+
{"icon": "copy", "label": "Duplicate"},
|
| 38 |
+
]
|
| 39 |
+
result = dropdown.dropdown_menu_content(menu_id, items).render()
|
| 40 |
+
print("dropdown-menu result", result)
|
| 41 |
+
return result
|
| 42 |
+
|
| 43 |
+
@app.get("/edit-sheet", response_class=HTMLResponse)
|
| 44 |
+
async def get_edit_sheet():
|
| 45 |
+
edit_sheet_content = sheet.SheetContent(
|
| 46 |
+
sheet.SheetHeader(
|
| 47 |
+
sheet.SheetTitle("Edit Item"),
|
| 48 |
+
sheet.SheetDescription("Make changes to your item here.")
|
| 49 |
+
),
|
| 50 |
+
sheet.SheetBody(
|
| 51 |
+
form.Form(
|
| 52 |
+
form.FormField("Name", "name", placeholder="Enter item name"),
|
| 53 |
+
form.FormField("Description", "description", placeholder="Enter item description"),
|
| 54 |
+
Div(
|
| 55 |
+
button.button("Save", class_="bg-blue-500 text-white"),
|
| 56 |
+
button.button("Cancel", class_="bg-gray-300 text-gray-700 ml-2", data_close_sheet="true"),
|
| 57 |
+
class_="flex justify-end mt-4"
|
| 58 |
+
),
|
| 59 |
+
class_="space-y-4"
|
| 60 |
+
)
|
| 61 |
+
)
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
result = sheet.Sheet(
|
| 65 |
+
"edit-sheet",
|
| 66 |
+
Div(),
|
| 67 |
+
edit_sheet_content,
|
| 68 |
+
width="80%",
|
| 69 |
+
max_width="800px"
|
| 70 |
+
).render()
|
| 71 |
+
return result
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
|
| 74 |
+
import uvicorn
|
| 75 |
+
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
|
test/xue/test_form_uni_api.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Form as FastAPIForm
|
| 2 |
+
from fastapi.responses import HTMLResponse
|
| 3 |
+
from xue import HTML, Head, Body, Div, xue_initialize, Strong, Span, Ul, Li
|
| 4 |
+
from xue.components import form, button, checkbox, input
|
| 5 |
+
from xue.components.model_config_row import model_config_row
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
xue_initialize(tailwind=True)
|
| 10 |
+
app = FastAPI()
|
| 11 |
+
|
| 12 |
+
@app.get("/", response_class=HTMLResponse)
|
| 13 |
+
async def root():
|
| 14 |
+
result = HTML(
|
| 15 |
+
Head(
|
| 16 |
+
title="Provider Configuration Form"
|
| 17 |
+
),
|
| 18 |
+
Body(
|
| 19 |
+
Div(
|
| 20 |
+
form.Form(
|
| 21 |
+
form.FormField("Provider", "provider", placeholder="Enter provider name", required=True),
|
| 22 |
+
form.FormField("Base URL", "base_url", placeholder="Enter base URL", required=True),
|
| 23 |
+
form.FormField("API Key", "api_key", type="password", placeholder="Enter API key"),
|
| 24 |
+
Div(
|
| 25 |
+
Div("Models", class_="text-lg font-semibold mb-2"),
|
| 26 |
+
Div(
|
| 27 |
+
model_config_row("model1", "gpt-4o: deepbricks-gpt-4o-mini", True),
|
| 28 |
+
model_config_row("model2", "gpt-4o"),
|
| 29 |
+
model_config_row("model3", "gpt-3.5-turbo"),
|
| 30 |
+
model_config_row("model4", "claude-3-5-sonnet-20240620: claude-3-5-sonnet"),
|
| 31 |
+
model_config_row("model5", "o1-mini-all"),
|
| 32 |
+
model_config_row("model6", "o1-preview-all"),
|
| 33 |
+
model_config_row("model7", "whisper-1"),
|
| 34 |
+
id="models-container"
|
| 35 |
+
),
|
| 36 |
+
button.button(
|
| 37 |
+
"Add Model",
|
| 38 |
+
class_="mt-2",
|
| 39 |
+
hx_post="/add-model",
|
| 40 |
+
hx_target="#models-container",
|
| 41 |
+
hx_swap="beforeend"
|
| 42 |
+
),
|
| 43 |
+
class_="mb-4"
|
| 44 |
+
),
|
| 45 |
+
Div(
|
| 46 |
+
checkbox.checkbox("tools", "Enable Tools", checked=True),
|
| 47 |
+
class_="mb-4"
|
| 48 |
+
),
|
| 49 |
+
form.FormField("Notes", "notes", placeholder="Enter any additional notes"),
|
| 50 |
+
Div(
|
| 51 |
+
button.button("Submit", class_="bg-blue-500 text-white"),
|
| 52 |
+
button.button("Cancel", class_="bg-gray-300 text-gray-700 ml-2"),
|
| 53 |
+
class_="flex justify-end mt-4"
|
| 54 |
+
),
|
| 55 |
+
hx_post="/submit",
|
| 56 |
+
hx_swap="outerHTML",
|
| 57 |
+
class_="space-y-4"
|
| 58 |
+
),
|
| 59 |
+
class_="container mx-auto p-4 max-w-2xl"
|
| 60 |
+
)
|
| 61 |
+
)
|
| 62 |
+
).render()
|
| 63 |
+
print(result)
|
| 64 |
+
return result
|
| 65 |
+
|
| 66 |
+
@app.post("/add-model", response_class=HTMLResponse)
|
| 67 |
+
async def add_model():
|
| 68 |
+
new_model_id = f"model{hash(str(time.time()))}" # 生成一个唯一的ID
|
| 69 |
+
new_model = model_config_row(new_model_id).render()
|
| 70 |
+
return new_model
|
| 71 |
+
|
| 72 |
+
def form_success_message(provider, base_url, api_key, models, tools_enabled, notes):
|
| 73 |
+
return Div(
|
| 74 |
+
Strong("Success!", class_="font-bold"),
|
| 75 |
+
Span("Form submitted successfully.", class_="block sm:inline"),
|
| 76 |
+
Ul(
|
| 77 |
+
Li(f"Provider: {provider}"),
|
| 78 |
+
Li(f"Base URL: {base_url}"),
|
| 79 |
+
Li(f"API Key: {'*' * len(api_key)}"),
|
| 80 |
+
Li(f"Models: {', '.join(models)}"),
|
| 81 |
+
Li(f"Tools Enabled: {'Yes' if tools_enabled else 'No'}"),
|
| 82 |
+
Li(f"Notes: {notes}"),
|
| 83 |
+
class_="mt-3"
|
| 84 |
+
),
|
| 85 |
+
class_="bg-green-100 border border-green-400 text-green-700 px-4 py-3 rounded relative",
|
| 86 |
+
role="alert"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
@app.post("/submit", response_class=HTMLResponse)
|
| 90 |
+
async def submit_form(
|
| 91 |
+
provider: str = FastAPIForm(...),
|
| 92 |
+
base_url: str = FastAPIForm(...),
|
| 93 |
+
api_key: str = FastAPIForm(...),
|
| 94 |
+
models: List[str] = FastAPIForm([]),
|
| 95 |
+
tools: Optional[str] = FastAPIForm(None),
|
| 96 |
+
notes: Optional[str] = FastAPIForm(None)
|
| 97 |
+
):
|
| 98 |
+
# 处理提交的数据
|
| 99 |
+
print(f"Received: provider={provider}, base_url={base_url}, api_key={api_key}")
|
| 100 |
+
print(f"Models: {models}")
|
| 101 |
+
print(f"Tools Enabled: {tools is not None}")
|
| 102 |
+
print(f"Notes: {notes}")
|
| 103 |
+
|
| 104 |
+
# 返回处理结果
|
| 105 |
+
return form_success_message(
|
| 106 |
+
provider,
|
| 107 |
+
base_url,
|
| 108 |
+
api_key,
|
| 109 |
+
models,
|
| 110 |
+
tools is not None,
|
| 111 |
+
notes or "No notes provided"
|
| 112 |
+
).render()
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
import uvicorn
|
| 116 |
+
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
|
test/xue/test_home.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Request
|
| 2 |
+
from fastapi import Form as FastapiForm, HTTPException, Depends
|
| 3 |
+
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
| 4 |
+
from fastapi.security import APIKeyHeader
|
| 5 |
+
from typing import Optional, List
|
| 6 |
+
|
| 7 |
+
from xue import HTML, Head, Body, Div, xue_initialize, Script
|
| 8 |
+
from xue.components.menubar import (
|
| 9 |
+
Menubar, MenubarMenu, MenubarTrigger, MenubarContent,
|
| 10 |
+
MenubarItem, MenubarSeparator
|
| 11 |
+
)
|
| 12 |
+
from xue.components import input
|
| 13 |
+
from xue.components import dropdown, sheet, form, button, checkbox
|
| 14 |
+
from xue.components.model_config_row import model_config_row
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
import os
|
| 19 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
| 20 |
+
from components.provider_table import data_table
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
from ruamel.yaml import YAML
|
| 24 |
+
yaml = YAML()
|
| 25 |
+
yaml.preserve_quotes = True
|
| 26 |
+
yaml.indent(mapping=2, sequence=4, offset=2)
|
| 27 |
+
|
| 28 |
+
xue_initialize(tailwind=True)
|
| 29 |
+
|
| 30 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 31 |
+
import logging
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(level=logging.INFO)
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
class RequestBodyLoggerMiddleware(BaseHTTPMiddleware):
|
| 37 |
+
async def dispatch(self, request: Request, call_next):
|
| 38 |
+
if request.method == "POST" and request.url.path.startswith("/submit/"):
|
| 39 |
+
# if request.method == "POST":
|
| 40 |
+
body = await request.body()
|
| 41 |
+
logger.info(f"Request body for {request.url.path}: {body.decode()}")
|
| 42 |
+
|
| 43 |
+
response = await call_next(request)
|
| 44 |
+
return response
|
| 45 |
+
|
| 46 |
+
from utils import load_config
|
| 47 |
+
from contextlib import asynccontextmanager
|
| 48 |
+
@asynccontextmanager
|
| 49 |
+
async def lifespan(app: FastAPI):
|
| 50 |
+
# app.state.client = httpx.AsyncClient(timeout=timeout)
|
| 51 |
+
app.state.config, app.state.api_keys_db, app.state.api_list = await load_config()
|
| 52 |
+
for item in app.state.api_keys_db:
|
| 53 |
+
if item.get("role") == "admin":
|
| 54 |
+
app.state.admin_api_key = item.get("api")
|
| 55 |
+
if not hasattr(app.state, "admin_api_key"):
|
| 56 |
+
if len(app.state.api_keys_db) >= 1:
|
| 57 |
+
app.state.admin_api_key = app.state.api_keys_db[0].get("api")
|
| 58 |
+
else:
|
| 59 |
+
raise Exception("No admin API key found")
|
| 60 |
+
|
| 61 |
+
global data
|
| 62 |
+
# providers_data = app.state.config["providers"]
|
| 63 |
+
|
| 64 |
+
# print("data", data)
|
| 65 |
+
yield
|
| 66 |
+
# 关闭时的代码
|
| 67 |
+
await app.state.client.aclose()
|
| 68 |
+
|
| 69 |
+
app = FastAPI(lifespan=lifespan)
|
| 70 |
+
# app.add_middleware(RequestBodyLoggerMiddleware)
|
| 71 |
+
app.add_middleware(RequestBodyLoggerMiddleware)
|
| 72 |
+
|
| 73 |
+
data_table_columns = [
|
| 74 |
+
# {"label": "Status", "value": "status", "sortable": True},
|
| 75 |
+
{"label": "Provider", "value": "provider", "sortable": True},
|
| 76 |
+
{"label": "Base url", "value": "base_url", "sortable": True},
|
| 77 |
+
# {"label": "Engine", "value": "engine", "sortable": True},
|
| 78 |
+
{"label": "Tools", "value": "tools", "sortable": True},
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
API_KEY_NAME = "X-API-Key"
|
| 82 |
+
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
| 83 |
+
|
| 84 |
+
@app.get("/login", response_class=HTMLResponse)
|
| 85 |
+
async def login_page():
|
| 86 |
+
return HTML(
|
| 87 |
+
Head(title="登录"),
|
| 88 |
+
Body(
|
| 89 |
+
Div(
|
| 90 |
+
form.Form(
|
| 91 |
+
form.FormField("API Key", "x_api_key", type="password", placeholder="输入API密钥", required=True),
|
| 92 |
+
Div(id="error-message", class_="text-red-500 mt-2"),
|
| 93 |
+
Div(
|
| 94 |
+
button.button("提交", variant="primary", type="submit"),
|
| 95 |
+
class_="flex justify-end mt-4"
|
| 96 |
+
),
|
| 97 |
+
hx_post="/verify-api-key",
|
| 98 |
+
hx_target="#error-message",
|
| 99 |
+
hx_swap="innerHTML",
|
| 100 |
+
class_="space-y-4"
|
| 101 |
+
),
|
| 102 |
+
class_="container mx-auto p-4 max-w-md"
|
| 103 |
+
)
|
| 104 |
+
)
|
| 105 |
+
).render()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@app.post("/verify-api-key", response_class=HTMLResponse)
|
| 109 |
+
async def verify_api_key(x_api_key: str = FastapiForm(...)):
|
| 110 |
+
if x_api_key == app.state.admin_api_key: # 替换为实际的管理员API密钥
|
| 111 |
+
response = JSONResponse(content={"success": True})
|
| 112 |
+
response.headers["HX-Redirect"] = "/" # 添加这一行
|
| 113 |
+
response.set_cookie(
|
| 114 |
+
key="x_api_key",
|
| 115 |
+
value=x_api_key,
|
| 116 |
+
httponly=True,
|
| 117 |
+
max_age=1800, # 30分钟
|
| 118 |
+
secure=False, # 在开发环境中设置为False,生产环境中使用HTTPS时设置为True
|
| 119 |
+
samesite="lax" # 改为"lax"以允许重定向时携带cookie
|
| 120 |
+
)
|
| 121 |
+
return response
|
| 122 |
+
else:
|
| 123 |
+
return Div("无效的API密钥", class_="text-red-500").render()
|
| 124 |
+
|
| 125 |
+
async def get_api_key(request: Request, x_api_key: Optional[str] = Depends(api_key_header)):
|
| 126 |
+
if not x_api_key:
|
| 127 |
+
x_api_key = request.cookies.get("x_api_key") or request.query_params.get("x_api_key")
|
| 128 |
+
# print(f"Cookie x_api_key: {request.cookies.get('x_api_key')}") # 添加此行
|
| 129 |
+
# print(f"Query param x_api_key: {request.query_params.get('x_api_key')}") # 添加此行
|
| 130 |
+
# print(f"Header x_api_key: {x_api_key}") # 添加此��
|
| 131 |
+
# logger.info(f"x_api_key: {x_api_key} {x_api_key == 'your_admin_api_key'}")
|
| 132 |
+
|
| 133 |
+
if x_api_key == app.state.admin_api_key: # 替换为实际的管理员API密钥
|
| 134 |
+
return x_api_key
|
| 135 |
+
else:
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
@app.get("/", response_class=HTMLResponse)
|
| 139 |
+
async def root(x_api_key: str = Depends(get_api_key)):
|
| 140 |
+
if not x_api_key:
|
| 141 |
+
return RedirectResponse(url="/login", status_code=303)
|
| 142 |
+
|
| 143 |
+
result = HTML(
|
| 144 |
+
Head(
|
| 145 |
+
Script("""
|
| 146 |
+
document.addEventListener('DOMContentLoaded', function() {
|
| 147 |
+
const filterInput = document.getElementById('users-table-filter');
|
| 148 |
+
filterInput.addEventListener('input', function() {
|
| 149 |
+
const filterValue = this.value;
|
| 150 |
+
htmx.ajax('GET', `/filter-table?filter=${filterValue}`, '#users-table');
|
| 151 |
+
});
|
| 152 |
+
});
|
| 153 |
+
"""),
|
| 154 |
+
title="Menubar Example"
|
| 155 |
+
),
|
| 156 |
+
Body(
|
| 157 |
+
Div(
|
| 158 |
+
Menubar(
|
| 159 |
+
MenubarMenu(
|
| 160 |
+
MenubarTrigger("File", "file-menu"),
|
| 161 |
+
MenubarContent(
|
| 162 |
+
MenubarItem("New Tab", shortcut="⌘T"),
|
| 163 |
+
MenubarItem("New Window", shortcut="⌘N"),
|
| 164 |
+
MenubarItem("New Incognito Window", disabled=True),
|
| 165 |
+
MenubarSeparator(),
|
| 166 |
+
MenubarItem("Print...", shortcut="⌘P"),
|
| 167 |
+
),
|
| 168 |
+
id="file-menu"
|
| 169 |
+
),
|
| 170 |
+
MenubarMenu(
|
| 171 |
+
MenubarTrigger("Edit", "edit-menu"),
|
| 172 |
+
MenubarContent(
|
| 173 |
+
MenubarItem("Undo", shortcut="⌘Z"),
|
| 174 |
+
MenubarItem("Redo", shortcut="⇧⌘Z"),
|
| 175 |
+
MenubarSeparator(),
|
| 176 |
+
MenubarItem("Cut"),
|
| 177 |
+
MenubarItem("Copy"),
|
| 178 |
+
MenubarItem("Paste"),
|
| 179 |
+
),
|
| 180 |
+
id="edit-menu"
|
| 181 |
+
),
|
| 182 |
+
MenubarMenu(
|
| 183 |
+
MenubarTrigger("View", "view-menu"),
|
| 184 |
+
MenubarContent(
|
| 185 |
+
MenubarItem("Always Show Bookmarks Bar"),
|
| 186 |
+
MenubarItem("Always Show Full URLs"),
|
| 187 |
+
MenubarSeparator(),
|
| 188 |
+
MenubarItem("Reload", shortcut="⌘R"),
|
| 189 |
+
MenubarItem("Force Reload", shortcut="⇧⌘R", disabled=True),
|
| 190 |
+
MenubarSeparator(),
|
| 191 |
+
MenubarItem("Toggle Fullscreen"),
|
| 192 |
+
MenubarItem("Hide Sidebar"),
|
| 193 |
+
),
|
| 194 |
+
id="view-menu"
|
| 195 |
+
),
|
| 196 |
+
),
|
| 197 |
+
class_="p-4"
|
| 198 |
+
),
|
| 199 |
+
Div(
|
| 200 |
+
data_table(data_table_columns, app.state.config["providers"], "users-table"),
|
| 201 |
+
class_="p-4"
|
| 202 |
+
),
|
| 203 |
+
Div(id="sheet-container"), # 这里是 sheet 将被加载的地方
|
| 204 |
+
class_="container mx-auto",
|
| 205 |
+
id="body"
|
| 206 |
+
)
|
| 207 |
+
).render()
|
| 208 |
+
# print(result)
|
| 209 |
+
return result
|
| 210 |
+
|
| 211 |
+
@app.get("/dropdown-menu/{menu_id}/{row_id}", response_class=HTMLResponse)
|
| 212 |
+
async def get_columns_menu(menu_id: str, row_id: str):
|
| 213 |
+
columns = [
|
| 214 |
+
{
|
| 215 |
+
"label": "Edit",
|
| 216 |
+
"value": "edit",
|
| 217 |
+
"hx-get": f"/edit-sheet/{row_id}",
|
| 218 |
+
"hx-target": "#sheet-container",
|
| 219 |
+
"hx-swap": "innerHTML"
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"label": "Duplicate",
|
| 223 |
+
"value": "duplicate",
|
| 224 |
+
"hx-post": f"/duplicate/{row_id}",
|
| 225 |
+
"hx-target": "body",
|
| 226 |
+
"hx-swap": "outerHTML"
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"label": "Delete",
|
| 230 |
+
"value": "delete",
|
| 231 |
+
"hx-delete": f"/delete/{row_id}",
|
| 232 |
+
"hx-target": "body",
|
| 233 |
+
"hx-swap": "outerHTML",
|
| 234 |
+
"hx-confirm": "确定要删除这个配置吗?"
|
| 235 |
+
},
|
| 236 |
+
]
|
| 237 |
+
result = dropdown.dropdown_menu_content(menu_id, columns).render()
|
| 238 |
+
print(result)
|
| 239 |
+
return result
|
| 240 |
+
|
| 241 |
+
@app.get("/dropdown-menu/{menu_id}", response_class=HTMLResponse)
|
| 242 |
+
async def get_columns_menu(menu_id: str):
|
| 243 |
+
result = dropdown.dropdown_menu_content(menu_id, data_table_columns).render()
|
| 244 |
+
print(result)
|
| 245 |
+
return result
|
| 246 |
+
|
| 247 |
+
@app.get("/filter-table", response_class=HTMLResponse)
|
| 248 |
+
async def filter_table(filter: str = ""):
|
| 249 |
+
filtered_data = [
|
| 250 |
+
provider for provider in app.state.config["providers"]
|
| 251 |
+
if filter.lower() in str(provider["provider"]).lower() or
|
| 252 |
+
filter.lower() in str(provider["base_url"]).lower() or
|
| 253 |
+
filter.lower() in str(provider["tools"]).lower()
|
| 254 |
+
]
|
| 255 |
+
return data_table(data_table_columns, filtered_data, "users-table", with_filter=False).render()
|
| 256 |
+
|
| 257 |
+
@app.post("/add-model", response_class=HTMLResponse)
|
| 258 |
+
async def add_model():
|
| 259 |
+
new_model_id = f"model{hash(str(time.time()))}" # 生成一个唯一的ID
|
| 260 |
+
new_model = model_config_row(new_model_id).render()
|
| 261 |
+
return new_model
|
| 262 |
+
|
| 263 |
+
@app.get("/edit-sheet/{row_id}", response_class=HTMLResponse)
|
| 264 |
+
async def get_edit_sheet(row_id: str, x_api_key: str = Depends(get_api_key)):
|
| 265 |
+
row_data = get_row_data(row_id)
|
| 266 |
+
print("row_data", row_data)
|
| 267 |
+
|
| 268 |
+
model_list = []
|
| 269 |
+
for index, model in enumerate(row_data["model"]):
|
| 270 |
+
if isinstance(model, str):
|
| 271 |
+
model_list.append(model_config_row(f"model{index}", model, "", True))
|
| 272 |
+
if isinstance(model, dict):
|
| 273 |
+
# print("model", model, list(model.items())[0])
|
| 274 |
+
key, value = list(model.items())[0]
|
| 275 |
+
model_list.append(model_config_row(f"model{index}", key, value, True))
|
| 276 |
+
|
| 277 |
+
sheet_id = "edit-sheet"
|
| 278 |
+
edit_sheet_content = sheet.SheetContent(
|
| 279 |
+
sheet.SheetHeader(
|
| 280 |
+
sheet.SheetTitle("Edit Item"),
|
| 281 |
+
sheet.SheetDescription("Make changes to your item here.")
|
| 282 |
+
),
|
| 283 |
+
sheet.SheetBody(
|
| 284 |
+
Div(
|
| 285 |
+
form.Form(
|
| 286 |
+
form.FormField("Provider", "provider", value=row_data["provider"], placeholder="Enter provider name", required=True),
|
| 287 |
+
form.FormField("Base URL", "base_url", value=row_data["base_url"], placeholder="Enter base URL", required=True),
|
| 288 |
+
form.FormField("API Key", "api_key", value=row_data["api"], type="text", placeholder="Enter API key"),
|
| 289 |
+
Div(
|
| 290 |
+
Div("Models", class_="text-lg font-semibold mb-2"),
|
| 291 |
+
Div(
|
| 292 |
+
*model_list,
|
| 293 |
+
id="models-container"
|
| 294 |
+
),
|
| 295 |
+
button.button(
|
| 296 |
+
"Add Model",
|
| 297 |
+
class_="mt-2",
|
| 298 |
+
hx_post="/add-model",
|
| 299 |
+
hx_target="#models-container",
|
| 300 |
+
hx_swap="beforeend"
|
| 301 |
+
),
|
| 302 |
+
class_="mb-4"
|
| 303 |
+
),
|
| 304 |
+
Div(
|
| 305 |
+
checkbox.checkbox("tools", "Enable Tools", checked=row_data["tools"], name="tools"),
|
| 306 |
+
class_="mb-4"
|
| 307 |
+
),
|
| 308 |
+
form.FormField("Notes", "notes", value=row_data.get("notes", ""), placeholder="Enter any additional notes"),
|
| 309 |
+
Div(
|
| 310 |
+
button.button("Submit", variant="primary", type="submit"),
|
| 311 |
+
button.button("Cancel", variant="outline", type="button", class_="ml-2", onclick=f"toggleSheet('{sheet_id}')"),
|
| 312 |
+
class_="flex justify-end mt-4"
|
| 313 |
+
),
|
| 314 |
+
hx_post=f"/submit/{row_id}",
|
| 315 |
+
hx_swap="outerHTML",
|
| 316 |
+
hx_target="body",
|
| 317 |
+
class_="space-y-4"
|
| 318 |
+
),
|
| 319 |
+
class_="container mx-auto p-4 max-w-2xl"
|
| 320 |
+
)
|
| 321 |
+
)
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
result = sheet.Sheet(
|
| 325 |
+
sheet_id,
|
| 326 |
+
Div(),
|
| 327 |
+
edit_sheet_content,
|
| 328 |
+
width="80%",
|
| 329 |
+
max_width="800px"
|
| 330 |
+
).render()
|
| 331 |
+
return result
|
| 332 |
+
|
| 333 |
+
@app.get("/add-provider-sheet", response_class=HTMLResponse)
|
| 334 |
+
async def get_add_provider_sheet():
|
| 335 |
+
edit_sheet_content = sheet.SheetContent(
|
| 336 |
+
sheet.SheetHeader(
|
| 337 |
+
sheet.SheetTitle("Add New Provider"),
|
| 338 |
+
sheet.SheetDescription("Enter details for the new provider.")
|
| 339 |
+
),
|
| 340 |
+
sheet.SheetBody(
|
| 341 |
+
Div(
|
| 342 |
+
form.Form(
|
| 343 |
+
form.FormField("Provider", "provider", placeholder="Enter provider name", required=True),
|
| 344 |
+
form.FormField("Base URL", "base_url", placeholder="Enter base URL", required=True),
|
| 345 |
+
form.FormField("API Key", "api_key", type="text", placeholder="Enter API key"),
|
| 346 |
+
Div(
|
| 347 |
+
Div("Models", class_="text-lg font-semibold mb-2"),
|
| 348 |
+
Div(id="models-container"),
|
| 349 |
+
button.button(
|
| 350 |
+
"Add Model",
|
| 351 |
+
class_="mt-2",
|
| 352 |
+
hx_post="/add-model",
|
| 353 |
+
hx_target="#models-container",
|
| 354 |
+
hx_swap="beforeend"
|
| 355 |
+
),
|
| 356 |
+
class_="mb-4"
|
| 357 |
+
),
|
| 358 |
+
Div(
|
| 359 |
+
checkbox.checkbox("tools", "Enable Tools", name="tools"),
|
| 360 |
+
class_="mb-4"
|
| 361 |
+
),
|
| 362 |
+
form.FormField("Notes", "notes", placeholder="Enter any additional notes"),
|
| 363 |
+
Div(
|
| 364 |
+
button.button("Submit", variant="primary", type="submit"),
|
| 365 |
+
button.button("Cancel", variant="outline", class_="ml-2"),
|
| 366 |
+
class_="flex justify-end mt-4"
|
| 367 |
+
),
|
| 368 |
+
hx_post="/submit/new",
|
| 369 |
+
hx_swap="outerHTML",
|
| 370 |
+
hx_target="body",
|
| 371 |
+
class_="space-y-4"
|
| 372 |
+
),
|
| 373 |
+
class_="container mx-auto p-4 max-w-2xl"
|
| 374 |
+
)
|
| 375 |
+
)
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
result = sheet.Sheet(
|
| 379 |
+
"add-provider-sheet",
|
| 380 |
+
Div(),
|
| 381 |
+
edit_sheet_content,
|
| 382 |
+
width="80%",
|
| 383 |
+
max_width="800px"
|
| 384 |
+
).render()
|
| 385 |
+
return result
|
| 386 |
+
|
| 387 |
+
def get_row_data(row_id):
|
| 388 |
+
index = int(row_id)
|
| 389 |
+
# print(app.state.config["providers"])
|
| 390 |
+
return app.state.config["providers"][index]
|
| 391 |
+
|
| 392 |
+
def update_row_data(row_id, updated_data):
|
| 393 |
+
print(row_id, updated_data)
|
| 394 |
+
index = int(row_id)
|
| 395 |
+
app.state.config["providers"][index] = updated_data
|
| 396 |
+
with open("./api1.yaml", "w", encoding="utf-8") as f:
|
| 397 |
+
yaml.dump(app.state.config, f)
|
| 398 |
+
|
| 399 |
+
@app.post("/submit/{row_id}", response_class=HTMLResponse)
|
| 400 |
+
async def submit_form(
|
| 401 |
+
row_id: str,
|
| 402 |
+
request: Request,
|
| 403 |
+
provider: str = FastapiForm(...),
|
| 404 |
+
base_url: str = FastapiForm(...),
|
| 405 |
+
api_key: Optional[str] = FastapiForm(None),
|
| 406 |
+
tools: Optional[str] = FastapiForm(None),
|
| 407 |
+
notes: Optional[str] = FastapiForm(None),
|
| 408 |
+
x_api_key: str = Depends(get_api_key)
|
| 409 |
+
):
|
| 410 |
+
form_data = await request.form()
|
| 411 |
+
|
| 412 |
+
# 收集模型数据
|
| 413 |
+
models = []
|
| 414 |
+
for key, value in form_data.items():
|
| 415 |
+
if key.startswith("model_name_"):
|
| 416 |
+
model_id = key.split("_")[-1]
|
| 417 |
+
enabled = form_data.get(f"model_enabled_{model_id}") == "on"
|
| 418 |
+
rename = form_data.get(f"model_rename_{model_id}")
|
| 419 |
+
if value:
|
| 420 |
+
if rename:
|
| 421 |
+
models.append({value: rename})
|
| 422 |
+
else:
|
| 423 |
+
models.append(value)
|
| 424 |
+
|
| 425 |
+
updated_data = {
|
| 426 |
+
"provider": provider,
|
| 427 |
+
"base_url": base_url,
|
| 428 |
+
"api": api_key,
|
| 429 |
+
"model": models,
|
| 430 |
+
"tools": tools == "on",
|
| 431 |
+
"notes": notes,
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
print("updated_data", updated_data)
|
| 435 |
+
|
| 436 |
+
if row_id == "new":
|
| 437 |
+
# 添加新提供者
|
| 438 |
+
app.state.config["providers"].append(updated_data)
|
| 439 |
+
else:
|
| 440 |
+
# 更新现有提供者
|
| 441 |
+
update_row_data(row_id, updated_data)
|
| 442 |
+
|
| 443 |
+
# 保存更新后的配置
|
| 444 |
+
with open("./api1.yaml", "w", encoding="utf-8") as f:
|
| 445 |
+
yaml.dump(app.state.config, f)
|
| 446 |
+
|
| 447 |
+
return await root()
|
| 448 |
+
|
| 449 |
+
@app.post("/duplicate/{row_id}", response_class=HTMLResponse)
|
| 450 |
+
async def duplicate_row(row_id: str):
|
| 451 |
+
index = int(row_id)
|
| 452 |
+
original_data = app.state.config["providers"][index]
|
| 453 |
+
new_data = original_data.copy()
|
| 454 |
+
new_data["provider"] += "-copy"
|
| 455 |
+
app.state.config["providers"].insert(index + 1, new_data)
|
| 456 |
+
|
| 457 |
+
# 保存更新后的配置
|
| 458 |
+
with open("./api1.yaml", "w", encoding="utf-8") as f:
|
| 459 |
+
yaml.dump(app.state.config, f)
|
| 460 |
+
|
| 461 |
+
return await root()
|
| 462 |
+
|
| 463 |
+
@app.delete("/delete/{row_id}", response_class=HTMLResponse)
|
| 464 |
+
async def delete_row(row_id: str):
|
| 465 |
+
index = int(row_id)
|
| 466 |
+
del app.state.config["providers"][index]
|
| 467 |
+
|
| 468 |
+
# 保存更新后的配置
|
| 469 |
+
with open("./api1.yaml", "w", encoding="utf-8") as f:
|
| 470 |
+
yaml.dump(app.state.config, f)
|
| 471 |
+
|
| 472 |
+
return await root()
|
| 473 |
+
|
| 474 |
+
if __name__ == "__main__":
|
| 475 |
+
import uvicorn
|
| 476 |
+
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
|
utils.py
CHANGED
|
@@ -3,26 +3,74 @@ from fastapi import HTTPException
|
|
| 3 |
import httpx
|
| 4 |
|
| 5 |
from log_config import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
def update_config(config_data):
|
| 8 |
for index, provider in enumerate(config_data['providers']):
|
| 9 |
-
model_dict = {}
|
| 10 |
-
for model in provider['model']:
|
| 11 |
-
if type(model) == str:
|
| 12 |
-
model_dict[model] = model
|
| 13 |
-
if type(model) == dict:
|
| 14 |
-
model_dict.update({new: old for old, new in model.items()})
|
| 15 |
-
provider['model'] = model_dict
|
| 16 |
if provider.get('project_id'):
|
| 17 |
provider['base_url'] = 'https://aiplatform.googleapis.com/'
|
| 18 |
if provider.get('cf_account_id'):
|
| 19 |
provider['base_url'] = 'https://api.cloudflare.com/'
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
| 26 |
|
| 27 |
config_data['providers'][index] = provider
|
| 28 |
|
|
@@ -46,31 +94,24 @@ def update_config(config_data):
|
|
| 46 |
api_keys_db[index]['model'] = models
|
| 47 |
|
| 48 |
api_list = [item["api"] for item in api_keys_db]
|
| 49 |
-
# logger.info(json.dumps(config_data, indent=4, ensure_ascii=False
|
| 50 |
return config_data, api_keys_db, api_list
|
| 51 |
|
| 52 |
# 读取YAML配置文件
|
| 53 |
async def load_config(app=None):
|
| 54 |
-
import yaml
|
| 55 |
try:
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
#
|
| 67 |
-
|
| 68 |
-
# conf = None
|
| 69 |
-
if conf:
|
| 70 |
-
config, api_keys_db, api_list = update_config(conf)
|
| 71 |
-
else:
|
| 72 |
-
# logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
|
| 73 |
-
config, api_keys_db, api_list = [], [], []
|
| 74 |
except FileNotFoundError:
|
| 75 |
logger.error("'api.yaml' not found. Please check the file path.")
|
| 76 |
config, api_keys_db, api_list = [], [], []
|
|
@@ -228,7 +269,8 @@ def get_all_models(config):
|
|
| 228 |
unique_models = set()
|
| 229 |
|
| 230 |
for provider in config["providers"]:
|
| 231 |
-
|
|
|
|
| 232 |
if model not in unique_models:
|
| 233 |
unique_models.add(model)
|
| 234 |
model_info = {
|
|
@@ -260,35 +302,12 @@ def get_all_models(config):
|
|
| 260 |
# europe-west1
|
| 261 |
# europe-west4
|
| 262 |
|
| 263 |
-
def circular_list_encoder(obj):
|
| 264 |
-
if isinstance(obj, CircularList):
|
| 265 |
-
return obj.to_dict()
|
| 266 |
-
raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
|
| 267 |
-
|
| 268 |
-
from collections import deque
|
| 269 |
-
class CircularList:
|
| 270 |
-
def __init__(self, items):
|
| 271 |
-
self.queue = deque(items)
|
| 272 |
-
|
| 273 |
-
def next(self):
|
| 274 |
-
if not self.queue:
|
| 275 |
-
return None
|
| 276 |
-
item = self.queue.popleft()
|
| 277 |
-
self.queue.append(item)
|
| 278 |
-
return item
|
| 279 |
-
|
| 280 |
-
def to_dict(self):
|
| 281 |
-
return {
|
| 282 |
-
'queue': list(self.queue)
|
| 283 |
-
}
|
| 284 |
-
|
| 285 |
-
|
| 286 |
|
| 287 |
-
c35s =
|
| 288 |
-
c3s =
|
| 289 |
-
c3o =
|
| 290 |
-
c3h =
|
| 291 |
-
gem =
|
| 292 |
|
| 293 |
class BaseAPI:
|
| 294 |
def __init__(
|
|
|
|
| 3 |
import httpx
|
| 4 |
|
| 5 |
from log_config import logger
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
|
| 10 |
+
class ThreadSafeCircularList:
|
| 11 |
+
def __init__(self, items):
|
| 12 |
+
self.items = items
|
| 13 |
+
self.index = 0
|
| 14 |
+
self.lock = asyncio.Lock()
|
| 15 |
+
|
| 16 |
+
async def next(self):
|
| 17 |
+
async with self.lock:
|
| 18 |
+
item = self.items[self.index]
|
| 19 |
+
self.index = (self.index + 1) % len(self.items)
|
| 20 |
+
return item
|
| 21 |
+
|
| 22 |
+
def circular_list_encoder(obj):
|
| 23 |
+
if isinstance(obj, ThreadSafeCircularList):
|
| 24 |
+
return obj.to_dict()
|
| 25 |
+
raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
|
| 26 |
+
|
| 27 |
+
provider_api_circular_list = defaultdict(ThreadSafeCircularList)
|
| 28 |
+
|
| 29 |
+
def get_model_dict(provider):
|
| 30 |
+
model_dict = {}
|
| 31 |
+
for model in provider['model']:
|
| 32 |
+
if type(model) == str:
|
| 33 |
+
model_dict[model] = model
|
| 34 |
+
if type(model) == dict:
|
| 35 |
+
model_dict.update({new: old for old, new in model.items()})
|
| 36 |
+
return model_dict
|
| 37 |
+
|
| 38 |
+
def update_initial_model(api_url, api):
|
| 39 |
+
try:
|
| 40 |
+
endpoint = BaseAPI(api_url=api_url)
|
| 41 |
+
endpoint_models_url = endpoint.v1_models
|
| 42 |
+
response = httpx.get(
|
| 43 |
+
endpoint_models_url,
|
| 44 |
+
headers={"Authorization": f"Bearer {api}"},
|
| 45 |
+
)
|
| 46 |
+
models = response.json()
|
| 47 |
+
# print(models)
|
| 48 |
+
models_list = models["data"]
|
| 49 |
+
models_id = [model["id"] for model in models_list]
|
| 50 |
+
set_models = set()
|
| 51 |
+
for model_item in models_id:
|
| 52 |
+
set_models.add(model_item)
|
| 53 |
+
models_id = list(set_models)
|
| 54 |
+
# print(models_id)
|
| 55 |
+
return models_id
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print("error:", e)
|
| 58 |
+
return []
|
| 59 |
|
| 60 |
def update_config(config_data):
|
| 61 |
for index, provider in enumerate(config_data['providers']):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
if provider.get('project_id'):
|
| 63 |
provider['base_url'] = 'https://aiplatform.googleapis.com/'
|
| 64 |
if provider.get('cf_account_id'):
|
| 65 |
provider['base_url'] = 'https://api.cloudflare.com/'
|
| 66 |
|
| 67 |
+
provider_api_circular_list[provider['provider']] = ThreadSafeCircularList([provider.get('api', None)])
|
| 68 |
+
|
| 69 |
+
if not provider.get("model"):
|
| 70 |
+
provider["model"] = update_initial_model(provider['base_url'], provider['api'])
|
| 71 |
+
|
| 72 |
+
if provider.get("tools") == None:
|
| 73 |
+
provider["tools"] = True
|
| 74 |
|
| 75 |
config_data['providers'][index] = provider
|
| 76 |
|
|
|
|
| 94 |
api_keys_db[index]['model'] = models
|
| 95 |
|
| 96 |
api_list = [item["api"] for item in api_keys_db]
|
| 97 |
+
# logger.info(json.dumps(config_data, indent=4, ensure_ascii=False))
|
| 98 |
return config_data, api_keys_db, api_list
|
| 99 |
|
| 100 |
# 读取YAML配置文件
|
| 101 |
async def load_config(app=None):
|
|
|
|
| 102 |
try:
|
| 103 |
+
from ruamel.yaml import YAML
|
| 104 |
+
yaml = YAML()
|
| 105 |
+
yaml.preserve_quotes = True
|
| 106 |
+
yaml.indent(mapping=2, sequence=4, offset=2)
|
| 107 |
+
with open('api.yaml', 'r', encoding='utf-8') as file:
|
| 108 |
+
conf = yaml.load(file)
|
| 109 |
+
|
| 110 |
+
if conf:
|
| 111 |
+
config, api_keys_db, api_list = update_config(conf)
|
| 112 |
+
else:
|
| 113 |
+
# logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
|
| 114 |
+
config, api_keys_db, api_list = [], [], []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
except FileNotFoundError:
|
| 116 |
logger.error("'api.yaml' not found. Please check the file path.")
|
| 117 |
config, api_keys_db, api_list = [], [], []
|
|
|
|
| 269 |
unique_models = set()
|
| 270 |
|
| 271 |
for provider in config["providers"]:
|
| 272 |
+
model_dict = get_model_dict(provider)
|
| 273 |
+
for model in model_dict.keys():
|
| 274 |
if model not in unique_models:
|
| 275 |
unique_models.add(model)
|
| 276 |
model_info = {
|
|
|
|
| 302 |
# europe-west1
|
| 303 |
# europe-west4
|
| 304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
+
c35s = ThreadSafeCircularList(["us-east5", "europe-west1"])
|
| 307 |
+
c3s = ThreadSafeCircularList(["us-east5", "us-central1", "asia-southeast1"])
|
| 308 |
+
c3o = ThreadSafeCircularList(["us-east5"])
|
| 309 |
+
c3h = ThreadSafeCircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])
|
| 310 |
+
gem = ThreadSafeCircularList(["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])
|
| 311 |
|
| 312 |
class BaseAPI:
|
| 313 |
def __init__(
|