✨ Feature: Add support for v1/moderations endpoint.
Browse files- main.py +16 -5
- models.py +5 -0
- request.py +19 -1
- response.py +0 -1
- utils.py +1 -0
main.py
CHANGED
|
@@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
|
|
| 12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 13 |
from fastapi.exceptions import RequestValidationError
|
| 14 |
|
| 15 |
-
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest
|
| 16 |
from request import get_payload
|
| 17 |
from response import fetch_response, fetch_response_stream
|
| 18 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
|
@@ -191,7 +191,7 @@ app.add_middleware(
|
|
| 191 |
app.add_middleware(StatsMiddleware)
|
| 192 |
|
| 193 |
# 在 process_request 函数中更新成功和失败计数
|
| 194 |
-
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], provider: Dict, endpoint=None, token=None):
|
| 195 |
url = provider['base_url']
|
| 196 |
parsed_url = urlparse(url)
|
| 197 |
# print("parsed_url", parsed_url)
|
|
@@ -237,6 +237,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
| 237 |
engine = "whisper"
|
| 238 |
request.stream = False
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
if provider.get("engine"):
|
| 241 |
engine = provider["engine"]
|
| 242 |
|
|
@@ -363,7 +367,7 @@ class ModelRequestHandler:
|
|
| 363 |
print(json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
| 364 |
return provider_list
|
| 365 |
|
| 366 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], token: str, endpoint=None):
|
| 367 |
config = app.state.config
|
| 368 |
# api_keys_db = app.state.api_keys_db
|
| 369 |
api_list = app.state.api_list
|
|
@@ -406,7 +410,7 @@ class ModelRequestHandler:
|
|
| 406 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
| 407 |
|
| 408 |
# 在 try_all_providers 函数中处理失败的情况
|
| 409 |
-
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
| 410 |
status_code = 500
|
| 411 |
error_message = None
|
| 412 |
num_providers = len(providers)
|
|
@@ -533,7 +537,7 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
|
|
| 533 |
return token
|
| 534 |
|
| 535 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
| 536 |
-
async def request_model(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], token: str = Depends(verify_api_key)):
|
| 537 |
# logger.info(f"Request received: {request}")
|
| 538 |
return await model_handler.request_model(request, token)
|
| 539 |
|
|
@@ -556,6 +560,13 @@ async def images_generations(
|
|
| 556 |
):
|
| 557 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
| 558 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 559 |
from fastapi import UploadFile, File, Form, HTTPException
|
| 560 |
import io
|
| 561 |
@app.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
|
|
|
|
| 12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 13 |
from fastapi.exceptions import RequestValidationError
|
| 14 |
|
| 15 |
+
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest
|
| 16 |
from request import get_payload
|
| 17 |
from response import fetch_response, fetch_response_stream
|
| 18 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
|
|
|
| 191 |
app.add_middleware(StatsMiddleware)
|
| 192 |
|
| 193 |
# 在 process_request 函数中更新成功和失败计数
|
| 194 |
+
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
|
| 195 |
url = provider['base_url']
|
| 196 |
parsed_url = urlparse(url)
|
| 197 |
# print("parsed_url", parsed_url)
|
|
|
|
| 237 |
engine = "whisper"
|
| 238 |
request.stream = False
|
| 239 |
|
| 240 |
+
if endpoint == "/v1/moderations":
|
| 241 |
+
engine = "moderation"
|
| 242 |
+
request.stream = False
|
| 243 |
+
|
| 244 |
if provider.get("engine"):
|
| 245 |
engine = provider["engine"]
|
| 246 |
|
|
|
|
| 367 |
print(json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
| 368 |
return provider_list
|
| 369 |
|
| 370 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
|
| 371 |
config = app.state.config
|
| 372 |
# api_keys_db = app.state.api_keys_db
|
| 373 |
api_list = app.state.api_list
|
|
|
|
| 410 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
| 411 |
|
| 412 |
# 在 try_all_providers 函数中处理失败的情况
|
| 413 |
+
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
| 414 |
status_code = 500
|
| 415 |
error_message = None
|
| 416 |
num_providers = len(providers)
|
|
|
|
| 537 |
return token
|
| 538 |
|
| 539 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
| 540 |
+
async def request_model(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str = Depends(verify_api_key)):
|
| 541 |
# logger.info(f"Request received: {request}")
|
| 542 |
return await model_handler.request_model(request, token)
|
| 543 |
|
|
|
|
| 560 |
):
|
| 561 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
| 562 |
|
| 563 |
+
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
| 564 |
+
async def images_generations(
|
| 565 |
+
request: ModerationRequest,
|
| 566 |
+
token: str = Depends(verify_api_key)
|
| 567 |
+
):
|
| 568 |
+
return await model_handler.request_model(request, token, endpoint="/v1/moderations")
|
| 569 |
+
|
| 570 |
from fastapi import UploadFile, File, Form, HTTPException
|
| 571 |
import io
|
| 572 |
@app.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
|
models.py
CHANGED
|
@@ -21,6 +21,11 @@ class AudioTranscriptionRequest(BaseModel):
|
|
| 21 |
class Config:
|
| 22 |
arbitrary_types_allowed = True
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
class FunctionParameter(BaseModel):
|
| 25 |
type: str
|
| 26 |
properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
|
|
|
|
| 21 |
class Config:
|
| 22 |
arbitrary_types_allowed = True
|
| 23 |
|
| 24 |
+
class ModerationRequest(BaseModel):
|
| 25 |
+
input: str
|
| 26 |
+
model: Optional[str] = "text-moderation-latest"
|
| 27 |
+
stream: bool = False
|
| 28 |
+
|
| 29 |
class FunctionParameter(BaseModel):
|
| 30 |
type: str
|
| 31 |
properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
|
request.py
CHANGED
|
@@ -1043,7 +1043,7 @@ async def get_dalle_payload(request, engine, provider):
|
|
| 1043 |
async def get_whisper_payload(request, engine, provider):
|
| 1044 |
model = provider['model'][request.model]
|
| 1045 |
headers = {
|
| 1046 |
-
"Content-Type": "
|
| 1047 |
}
|
| 1048 |
if provider.get("api"):
|
| 1049 |
headers['Authorization'] = f"Bearer {provider['api'].next()}"
|
|
@@ -1066,6 +1066,22 @@ async def get_whisper_payload(request, engine, provider):
|
|
| 1066 |
|
| 1067 |
return url, headers, payload
|
| 1068 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1069 |
async def get_payload(request: RequestModel, engine, provider):
|
| 1070 |
if engine == "gemini":
|
| 1071 |
return await get_gemini_payload(request, engine, provider)
|
|
@@ -1089,5 +1105,7 @@ async def get_payload(request: RequestModel, engine, provider):
|
|
| 1089 |
return await get_dalle_payload(request, engine, provider)
|
| 1090 |
elif engine == "whisper":
|
| 1091 |
return await get_whisper_payload(request, engine, provider)
|
|
|
|
|
|
|
| 1092 |
else:
|
| 1093 |
raise ValueError("Unknown payload")
|
|
|
|
| 1043 |
async def get_whisper_payload(request, engine, provider):
|
| 1044 |
model = provider['model'][request.model]
|
| 1045 |
headers = {
|
| 1046 |
+
"Content-Type": "multipart/form-data",
|
| 1047 |
}
|
| 1048 |
if provider.get("api"):
|
| 1049 |
headers['Authorization'] = f"Bearer {provider['api'].next()}"
|
|
|
|
| 1066 |
|
| 1067 |
return url, headers, payload
|
| 1068 |
|
| 1069 |
+
async def get_moderation_payload(request, engine, provider):
|
| 1070 |
+
model = provider['model'][request.model]
|
| 1071 |
+
headers = {
|
| 1072 |
+
"Content-Type": "application/json",
|
| 1073 |
+
}
|
| 1074 |
+
if provider.get("api"):
|
| 1075 |
+
headers['Authorization'] = f"Bearer {provider['api'].next()}"
|
| 1076 |
+
url = provider['base_url']
|
| 1077 |
+
url = BaseAPI(url).moderations
|
| 1078 |
+
|
| 1079 |
+
payload = {
|
| 1080 |
+
"input": request.input,
|
| 1081 |
+
}
|
| 1082 |
+
|
| 1083 |
+
return url, headers, payload
|
| 1084 |
+
|
| 1085 |
async def get_payload(request: RequestModel, engine, provider):
|
| 1086 |
if engine == "gemini":
|
| 1087 |
return await get_gemini_payload(request, engine, provider)
|
|
|
|
| 1105 |
return await get_dalle_payload(request, engine, provider)
|
| 1106 |
elif engine == "whisper":
|
| 1107 |
return await get_whisper_payload(request, engine, provider)
|
| 1108 |
+
elif engine == "moderation":
|
| 1109 |
+
return await get_moderation_payload(request, engine, provider)
|
| 1110 |
else:
|
| 1111 |
raise ValueError("Unknown payload")
|
response.py
CHANGED
|
@@ -273,7 +273,6 @@ async def fetch_response(client, url, headers, payload):
|
|
| 273 |
response = None
|
| 274 |
if payload.get("file"):
|
| 275 |
file = payload.pop("file")
|
| 276 |
-
headers.pop("Content-Type")
|
| 277 |
response = await client.post(url, headers=headers, data=payload, files={"file": file})
|
| 278 |
else:
|
| 279 |
response = await client.post(url, headers=headers, json=payload)
|
|
|
|
| 273 |
response = None
|
| 274 |
if payload.get("file"):
|
| 275 |
file = payload.pop("file")
|
|
|
|
| 276 |
response = await client.post(url, headers=headers, data=payload, files={"file": file})
|
| 277 |
else:
|
| 278 |
response = await client.post(url, headers=headers, json=payload)
|
utils.py
CHANGED
|
@@ -308,6 +308,7 @@ class BaseAPI:
|
|
| 308 |
self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
|
| 309 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
| 310 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
|
|
|
| 311 |
|
| 312 |
def safe_get(data, *keys):
|
| 313 |
for key in keys:
|
|
|
|
| 308 |
self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
|
| 309 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
| 310 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
| 311 |
+
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
| 312 |
|
| 313 |
def safe_get(data, *keys):
|
| 314 |
for key in keys:
|