✨ Feature: Add support for embeddings model
Browse files
main.py
CHANGED
|
@@ -16,7 +16,7 @@ from starlette.responses import StreamingResponse as StarletteStreamingResponse
|
|
| 16 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 17 |
from fastapi.exceptions import RequestValidationError
|
| 18 |
|
| 19 |
-
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
| 20 |
from request import get_payload
|
| 21 |
from response import fetch_response, fetch_response_stream
|
| 22 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
|
|
@@ -478,7 +478,7 @@ async def ensure_config(request: Request, call_next):
|
|
| 478 |
return await call_next(request)
|
| 479 |
|
| 480 |
# 在 process_request 函数中更新成功和失败计数
|
| 481 |
-
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
|
| 482 |
url = provider['base_url']
|
| 483 |
parsed_url = urlparse(url)
|
| 484 |
# print("parsed_url", parsed_url)
|
|
@@ -529,6 +529,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
| 529 |
engine = "moderation"
|
| 530 |
request.stream = False
|
| 531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
if provider.get("engine"):
|
| 533 |
engine = provider["engine"]
|
| 534 |
|
|
@@ -700,7 +704,7 @@ class ModelRequestHandler:
|
|
| 700 |
# print("provider_list", provider_list)
|
| 701 |
return provider_list
|
| 702 |
|
| 703 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
|
| 704 |
config = app.state.config
|
| 705 |
api_list = app.state.api_list
|
| 706 |
api_index = api_list.index(token)
|
|
@@ -904,6 +908,13 @@ async def images_generations(
|
|
| 904 |
):
|
| 905 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
| 906 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 907 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
| 908 |
async def moderations(
|
| 909 |
request: ModerationRequest,
|
|
|
|
| 16 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 17 |
from fastapi.exceptions import RequestValidationError
|
| 18 |
|
| 19 |
+
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest, EmbeddingRequest
|
| 20 |
from request import get_payload
|
| 21 |
from response import fetch_response, fetch_response_stream
|
| 22 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
|
|
|
|
| 478 |
return await call_next(request)
|
| 479 |
|
| 480 |
# 在 process_request 函数中更新成功和失败计数
|
| 481 |
+
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None, token=None):
|
| 482 |
url = provider['base_url']
|
| 483 |
parsed_url = urlparse(url)
|
| 484 |
# print("parsed_url", parsed_url)
|
|
|
|
| 529 |
engine = "moderation"
|
| 530 |
request.stream = False
|
| 531 |
|
| 532 |
+
if endpoint == "/v1/embeddings":
|
| 533 |
+
engine = "embedding"
|
| 534 |
+
request.stream = False
|
| 535 |
+
|
| 536 |
if provider.get("engine"):
|
| 537 |
engine = provider["engine"]
|
| 538 |
|
|
|
|
| 704 |
# print("provider_list", provider_list)
|
| 705 |
return provider_list
|
| 706 |
|
| 707 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
|
| 708 |
config = app.state.config
|
| 709 |
api_list = app.state.api_list
|
| 710 |
api_index = api_list.index(token)
|
|
|
|
| 908 |
):
|
| 909 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
| 910 |
|
| 911 |
+
@app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
|
| 912 |
+
async def embeddings(
|
| 913 |
+
request: EmbeddingRequest,
|
| 914 |
+
token: str = Depends(verify_api_key)
|
| 915 |
+
):
|
| 916 |
+
return await model_handler.request_model(request, token, endpoint="/v1/embeddings")
|
| 917 |
+
|
| 918 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
| 919 |
async def moderations(
|
| 920 |
request: ModerationRequest,
|
models.py
CHANGED
|
@@ -111,6 +111,12 @@ class ImageGenerationRequest(BaseRequest):
|
|
| 111 |
size: Optional[str] = "1024x1024"
|
| 112 |
stream: bool = False
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
class AudioTranscriptionRequest(BaseRequest):
|
| 115 |
file: Tuple[str, IOBase, str]
|
| 116 |
model: str
|
|
@@ -129,7 +135,7 @@ class ModerationRequest(BaseRequest):
|
|
| 129 |
stream: bool = False
|
| 130 |
|
| 131 |
class UnifiedRequest(BaseModel):
|
| 132 |
-
data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest]
|
| 133 |
|
| 134 |
@model_validator(mode='before')
|
| 135 |
@classmethod
|
|
@@ -147,6 +153,9 @@ class UnifiedRequest(BaseModel):
|
|
| 147 |
elif "input" in values:
|
| 148 |
values["data"] = ModerationRequest(**values)
|
| 149 |
values["data"].request_type = "moderation"
|
|
|
|
|
|
|
|
|
|
| 150 |
else:
|
| 151 |
raise ValueError("无法确定请求类型")
|
| 152 |
return values
|
|
|
|
| 111 |
size: Optional[str] = "1024x1024"
|
| 112 |
stream: bool = False
|
| 113 |
|
| 114 |
+
class EmbeddingRequest(BaseRequest):
|
| 115 |
+
input: str
|
| 116 |
+
model: str
|
| 117 |
+
encoding_format: Optional[str] = "float"
|
| 118 |
+
stream: bool = False
|
| 119 |
+
|
| 120 |
class AudioTranscriptionRequest(BaseRequest):
|
| 121 |
file: Tuple[str, IOBase, str]
|
| 122 |
model: str
|
|
|
|
| 135 |
stream: bool = False
|
| 136 |
|
| 137 |
class UnifiedRequest(BaseModel):
|
| 138 |
+
data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest]
|
| 139 |
|
| 140 |
@model_validator(mode='before')
|
| 141 |
@classmethod
|
|
|
|
| 153 |
elif "input" in values:
|
| 154 |
values["data"] = ModerationRequest(**values)
|
| 155 |
values["data"].request_type = "moderation"
|
| 156 |
+
elif "input" in values:
|
| 157 |
+
values["data"] = EmbeddingRequest(**values)
|
| 158 |
+
values["data"].request_type = "embedding"
|
| 159 |
else:
|
| 160 |
raise ValueError("无法确定请求类型")
|
| 161 |
return values
|
request.py
CHANGED
|
@@ -1125,6 +1125,27 @@ async def get_moderation_payload(request, engine, provider):
|
|
| 1125 |
|
| 1126 |
return url, headers, payload
|
| 1127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1128 |
async def get_payload(request: RequestModel, engine, provider):
|
| 1129 |
if engine == "gemini":
|
| 1130 |
return await get_gemini_payload(request, engine, provider)
|
|
@@ -1150,5 +1171,7 @@ async def get_payload(request: RequestModel, engine, provider):
|
|
| 1150 |
return await get_whisper_payload(request, engine, provider)
|
| 1151 |
elif engine == "moderation":
|
| 1152 |
return await get_moderation_payload(request, engine, provider)
|
|
|
|
|
|
|
| 1153 |
else:
|
| 1154 |
raise ValueError("Unknown payload")
|
|
|
|
| 1125 |
|
| 1126 |
return url, headers, payload
|
| 1127 |
|
| 1128 |
+
async def get_embedding_payload(request, engine, provider):
|
| 1129 |
+
model_dict = get_model_dict(provider)
|
| 1130 |
+
model = model_dict[request.model]
|
| 1131 |
+
headers = {
|
| 1132 |
+
"Content-Type": "application/json",
|
| 1133 |
+
}
|
| 1134 |
+
if provider.get("api"):
|
| 1135 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
| 1136 |
+
url = provider['base_url']
|
| 1137 |
+
url = BaseAPI(url).embeddings
|
| 1138 |
+
|
| 1139 |
+
payload = {
|
| 1140 |
+
"input": request.input,
|
| 1141 |
+
"model": model,
|
| 1142 |
+
}
|
| 1143 |
+
|
| 1144 |
+
if request.encoding_format:
|
| 1145 |
+
payload["encoding_format"] = request.encoding_format
|
| 1146 |
+
|
| 1147 |
+
return url, headers, payload
|
| 1148 |
+
|
| 1149 |
async def get_payload(request: RequestModel, engine, provider):
|
| 1150 |
if engine == "gemini":
|
| 1151 |
return await get_gemini_payload(request, engine, provider)
|
|
|
|
| 1171 |
return await get_whisper_payload(request, engine, provider)
|
| 1172 |
elif engine == "moderation":
|
| 1173 |
return await get_moderation_payload(request, engine, provider)
|
| 1174 |
+
elif engine == "embedding":
|
| 1175 |
+
return await get_embedding_payload(request, engine, provider)
|
| 1176 |
else:
|
| 1177 |
raise ValueError("Unknown payload")
|
utils.py
CHANGED
|
@@ -377,6 +377,7 @@ class BaseAPI:
|
|
| 377 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
| 378 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
| 379 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
|
|
|
| 380 |
|
| 381 |
def safe_get(data, *keys, default=None):
|
| 382 |
for key in keys:
|
|
|
|
| 377 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
| 378 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
| 379 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
| 380 |
+
self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/embeddings",) + ("",) * 3)
|
| 381 |
|
| 382 |
def safe_get(data, *keys, default=None):
|
| 383 |
for key in keys:
|