🐛 Bug: 1. Fix the bug that causes an error when Claude uploads a PNG image.
Browse files2. Fix the bug where fields are not automatically added when the database does not have specific fields.
✨ Feature: Add user message ethics review support.
📖 Docs: Update documentation
- README_CN.md +2 -0
- main.py +100 -13
- models.py +12 -1
- request.py +36 -7
- utils.py +2 -2
README_CN.md
CHANGED
|
@@ -105,10 +105,12 @@ api_keys:
|
|
| 105 |
model:
|
| 106 |
- anthropic/claude-3-5-sonnet # 可以使用的模型名称,仅可以使用名为 anthropic 提供商提供的 claude-3-5-sonnet 模型。其他提供商的 claude-3-5-sonnet 模型不可以使用。这种写法不会匹配到other-provider提供的名为anthropic/claude-3-5-sonnet的模型。
|
| 107 |
- <anthropic/claude-3-5-sonnet> # 通过在模型名两侧加上尖括号,这样就不会去名为anthropic的渠道下去寻找claude-3-5-sonnet模型,而是将整个 anthropic/claude-3-5-sonnet 作为模型名称。这种写法可以匹配到other-provider提供的名为 anthropic/claude-3-5-sonnet 的模型。但不会匹配到anthropic下面的claude-3-5-sonnet模型。
|
|
|
|
| 108 |
preferences:
|
| 109 |
USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
|
| 110 |
AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
|
| 111 |
RATE_LIMIT: 2/min # 支持限流,每分钟最多请求次数,可以设置为整数,如 2/min,2 次每分钟、5/hour,5 次每小时、10/day,10 次每天,10/month,10 次每月,10/year,10 次每年。默认60/min,选填
|
|
|
|
| 112 |
|
| 113 |
# 渠道级加权负载均衡配置示例
|
| 114 |
- api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
|
|
|
|
| 105 |
model:
|
| 106 |
- anthropic/claude-3-5-sonnet # 可以使用的模型名称,仅可以使用名为 anthropic 提供商提供的 claude-3-5-sonnet 模型。其他提供商的 claude-3-5-sonnet 模型不可以使用。这种写法不会匹配到other-provider提供的名为anthropic/claude-3-5-sonnet的模型。
|
| 107 |
- <anthropic/claude-3-5-sonnet> # 通过在模型名两侧加上尖括号,这样就不会去名为anthropic的渠道下去寻找claude-3-5-sonnet模型,而是将整个 anthropic/claude-3-5-sonnet 作为模型名称。这种写法可以匹配到other-provider提供的名为 anthropic/claude-3-5-sonnet 的模型。但不会匹配到anthropic下面的claude-3-5-sonnet模型。
|
| 108 |
+
- openai-test/text-moderation-latest # 当开启消息道德审查后,可以使用名为 openai-test 渠道下的 text-moderation-latest 模型进行道德审查。
|
| 109 |
preferences:
|
| 110 |
USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
|
| 111 |
AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
|
| 112 |
RATE_LIMIT: 2/min # 支持限流,每分钟最多请求次数,可以设置为整数,如 2/min,2 次每分钟、5/hour,5 次每小时、10/day,10 次每天,10/month,10 次每月,10/year,10 次每年。默认60/min,选填
|
| 113 |
+
ENABLE_MODERATION: true # 是否开启消息道德审查,true 为开启,false 为不开启,默认为 false,当开启后,会对用户的消息进行道德审查,如果发现不当的消息,会返回错误信息。
|
| 114 |
|
| 115 |
# 渠道级加权负载均衡配置示例
|
| 116 |
- api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
|
main.py
CHANGED
|
@@ -24,10 +24,50 @@ from urllib.parse import urlparse
|
|
| 24 |
import os
|
| 25 |
is_debug = bool(os.getenv("DEBUG", False))
|
| 26 |
|
|
|
|
|
|
|
|
|
|
| 27 |
async def create_tables():
|
| 28 |
async with engine.begin() as conn:
|
| 29 |
await conn.run_sync(Base.metadata.create_all)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
@asynccontextmanager
|
| 32 |
async def lifespan(app: FastAPI):
|
| 33 |
# 启动时的代码
|
|
@@ -79,7 +119,7 @@ async def parse_request_body(request: Request):
|
|
| 79 |
|
| 80 |
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
| 81 |
from sqlalchemy.orm import declarative_base, sessionmaker
|
| 82 |
-
from sqlalchemy import Column, Integer, String, Float, DateTime, select, Boolean
|
| 83 |
from sqlalchemy.sql import func
|
| 84 |
|
| 85 |
# 定义数据库模型
|
|
@@ -93,6 +133,8 @@ class RequestStat(Base):
|
|
| 93 |
token = Column(String)
|
| 94 |
total_time = Column(Float)
|
| 95 |
model = Column(String)
|
|
|
|
|
|
|
| 96 |
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
| 97 |
|
| 98 |
class ChannelStat(Base):
|
|
@@ -113,6 +155,7 @@ data_dir = os.path.dirname(db_path)
|
|
| 113 |
os.makedirs(data_dir, exist_ok=True)
|
| 114 |
|
| 115 |
# 创建异步引擎和会话
|
|
|
|
| 116 |
engine = create_async_engine('sqlite+aiosqlite:///' + db_path, echo=is_debug)
|
| 117 |
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
| 118 |
|
|
@@ -132,37 +175,76 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 132 |
start_time = time()
|
| 133 |
|
| 134 |
request.state.parsed_body = await parse_request_body(request)
|
|
|
|
|
|
|
| 135 |
|
| 136 |
model = "unknown"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
if request.state.parsed_body:
|
| 138 |
try:
|
| 139 |
request_model = RequestModel(**request.state.parsed_body)
|
| 140 |
model = request_model.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
except RequestValidationError:
|
| 142 |
pass
|
| 143 |
except Exception as e:
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
response = await call_next(request)
|
| 147 |
process_time = time() - start_time
|
| 148 |
|
| 149 |
-
endpoint = f"{request.method} {request.url.path}"
|
| 150 |
-
client_ip = request.client.host
|
| 151 |
-
|
| 152 |
# 异步更新数据库
|
| 153 |
-
await self.update_stats(endpoint, process_time, client_ip, model, token)
|
| 154 |
|
| 155 |
return response
|
| 156 |
|
| 157 |
-
async def update_stats(self, endpoint, process_time, client_ip, model, token):
|
| 158 |
async with self.db as session:
|
| 159 |
-
# 为每个请求创建一条新的记录
|
| 160 |
new_request_stat = RequestStat(
|
| 161 |
endpoint=endpoint,
|
| 162 |
ip=client_ip,
|
| 163 |
token=token,
|
| 164 |
total_time=process_time,
|
| 165 |
-
model=model
|
|
|
|
|
|
|
| 166 |
)
|
| 167 |
session.add(new_request_stat)
|
| 168 |
await session.commit()
|
|
@@ -179,6 +261,14 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 179 |
session.add(channel_stat)
|
| 180 |
await session.commit()
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
# 配置 CORS 中间件
|
| 183 |
app.add_middleware(
|
| 184 |
CORSMiddleware,
|
|
@@ -561,7 +651,7 @@ async def images_generations(
|
|
| 561 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
| 562 |
|
| 563 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
| 564 |
-
async def
|
| 565 |
request: ModerationRequest,
|
| 566 |
token: str = Depends(verify_api_key)
|
| 567 |
):
|
|
@@ -601,9 +691,6 @@ def generate_api_key():
|
|
| 601 |
return JSONResponse(content={"api_key": api_key})
|
| 602 |
|
| 603 |
# 在 /stats 路由中返回成功和失败百分比
|
| 604 |
-
from collections import defaultdict
|
| 605 |
-
from sqlalchemy import func
|
| 606 |
-
|
| 607 |
from collections import defaultdict
|
| 608 |
from sqlalchemy import func, desc, case
|
| 609 |
|
|
|
|
| 24 |
import os
|
| 25 |
is_debug = bool(os.getenv("DEBUG", False))
|
| 26 |
|
| 27 |
+
from sqlalchemy import inspect, text
|
| 28 |
+
from sqlalchemy.sql import sqltypes
|
| 29 |
+
|
| 30 |
async def create_tables():
|
| 31 |
async with engine.begin() as conn:
|
| 32 |
await conn.run_sync(Base.metadata.create_all)
|
| 33 |
|
| 34 |
+
# 检查并添加缺失的列
|
| 35 |
+
def check_and_add_columns(connection):
|
| 36 |
+
inspector = inspect(connection)
|
| 37 |
+
for table in [RequestStat, ChannelStat]:
|
| 38 |
+
table_name = table.__tablename__
|
| 39 |
+
existing_columns = {col['name']: col['type'] for col in inspector.get_columns(table_name)}
|
| 40 |
+
|
| 41 |
+
for column_name, column in table.__table__.columns.items():
|
| 42 |
+
if column_name not in existing_columns:
|
| 43 |
+
col_type = _map_sa_type_to_sql_type(column.type)
|
| 44 |
+
default = _get_default_sql(column.default)
|
| 45 |
+
connection.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {col_type}{default}"))
|
| 46 |
+
|
| 47 |
+
await conn.run_sync(check_and_add_columns)
|
| 48 |
+
|
| 49 |
+
def _map_sa_type_to_sql_type(sa_type):
|
| 50 |
+
type_map = {
|
| 51 |
+
sqltypes.Integer: "INTEGER",
|
| 52 |
+
sqltypes.String: "TEXT",
|
| 53 |
+
sqltypes.Float: "REAL",
|
| 54 |
+
sqltypes.Boolean: "BOOLEAN",
|
| 55 |
+
sqltypes.DateTime: "DATETIME",
|
| 56 |
+
sqltypes.Text: "TEXT"
|
| 57 |
+
}
|
| 58 |
+
return type_map.get(type(sa_type), "TEXT")
|
| 59 |
+
|
| 60 |
+
def _get_default_sql(default):
|
| 61 |
+
if default is None:
|
| 62 |
+
return ""
|
| 63 |
+
if isinstance(default.arg, bool):
|
| 64 |
+
return f" DEFAULT {str(default.arg).upper()}"
|
| 65 |
+
if isinstance(default.arg, (int, float)):
|
| 66 |
+
return f" DEFAULT {default.arg}"
|
| 67 |
+
if isinstance(default.arg, str):
|
| 68 |
+
return f" DEFAULT '{default.arg}'"
|
| 69 |
+
return ""
|
| 70 |
+
|
| 71 |
@asynccontextmanager
|
| 72 |
async def lifespan(app: FastAPI):
|
| 73 |
# 启动时的代码
|
|
|
|
| 119 |
|
| 120 |
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
| 121 |
from sqlalchemy.orm import declarative_base, sessionmaker
|
| 122 |
+
from sqlalchemy import Column, Integer, String, Float, DateTime, select, Boolean, Text
|
| 123 |
from sqlalchemy.sql import func
|
| 124 |
|
| 125 |
# 定义数据库模型
|
|
|
|
| 133 |
token = Column(String)
|
| 134 |
total_time = Column(Float)
|
| 135 |
model = Column(String)
|
| 136 |
+
is_flagged = Column(Boolean, default=False)
|
| 137 |
+
moderated_content = Column(Text)
|
| 138 |
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
| 139 |
|
| 140 |
class ChannelStat(Base):
|
|
|
|
| 155 |
os.makedirs(data_dir, exist_ok=True)
|
| 156 |
|
| 157 |
# 创建异步引擎和会话
|
| 158 |
+
# engine = create_async_engine('sqlite+aiosqlite:///' + db_path, echo=False)
|
| 159 |
engine = create_async_engine('sqlite+aiosqlite:///' + db_path, echo=is_debug)
|
| 160 |
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
| 161 |
|
|
|
|
| 175 |
start_time = time()
|
| 176 |
|
| 177 |
request.state.parsed_body = await parse_request_body(request)
|
| 178 |
+
endpoint = f"{request.method} {request.url.path}"
|
| 179 |
+
client_ip = request.client.host
|
| 180 |
|
| 181 |
model = "unknown"
|
| 182 |
+
enable_moderation = False # 默认不开启道德审查
|
| 183 |
+
is_flagged = False
|
| 184 |
+
moderated_content = ""
|
| 185 |
+
|
| 186 |
+
config = app.state.config
|
| 187 |
+
api_list = app.state.api_list
|
| 188 |
+
|
| 189 |
+
# 根据token决定是否启用道德审查
|
| 190 |
+
if token:
|
| 191 |
+
try:
|
| 192 |
+
api_index = api_list.index(token)
|
| 193 |
+
enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
|
| 194 |
+
except ValueError:
|
| 195 |
+
# token不在api_list中,使用默认值(不开启)
|
| 196 |
+
pass
|
| 197 |
+
else:
|
| 198 |
+
# 如果token为None,检查全局设置
|
| 199 |
+
enable_moderation = config.get('ENABLE_MODERATION', False)
|
| 200 |
+
|
| 201 |
if request.state.parsed_body:
|
| 202 |
try:
|
| 203 |
request_model = RequestModel(**request.state.parsed_body)
|
| 204 |
model = request_model.model
|
| 205 |
+
moderated_content = request_model.get_last_text_message()
|
| 206 |
+
|
| 207 |
+
if enable_moderation and moderated_content:
|
| 208 |
+
moderation_response = await self.moderate_content(moderated_content, token)
|
| 209 |
+
moderation_result = moderation_response.body
|
| 210 |
+
moderation_data = json.loads(moderation_result)
|
| 211 |
+
is_flagged = moderation_data.get('results', [{}])[0].get('flagged', False)
|
| 212 |
+
|
| 213 |
+
if is_flagged:
|
| 214 |
+
logger.error(f"Content did not pass the moral check: %s", moderated_content)
|
| 215 |
+
process_time = time() - start_time
|
| 216 |
+
await self.update_stats(endpoint, process_time, client_ip, model, token, is_flagged, moderated_content)
|
| 217 |
+
return JSONResponse(
|
| 218 |
+
status_code=400,
|
| 219 |
+
content={"error": "Content did not pass the moral check, please modify and try again."}
|
| 220 |
+
)
|
| 221 |
except RequestValidationError:
|
| 222 |
pass
|
| 223 |
except Exception as e:
|
| 224 |
+
if is_debug:
|
| 225 |
+
import traceback
|
| 226 |
+
traceback.print_exc()
|
| 227 |
+
|
| 228 |
+
logger.error(f"处理请求或进行道德检查时出错: {str(e)}")
|
| 229 |
|
| 230 |
response = await call_next(request)
|
| 231 |
process_time = time() - start_time
|
| 232 |
|
|
|
|
|
|
|
|
|
|
| 233 |
# 异步更新数据库
|
| 234 |
+
await self.update_stats(endpoint, process_time, client_ip, model, token, is_flagged, moderated_content)
|
| 235 |
|
| 236 |
return response
|
| 237 |
|
| 238 |
+
async def update_stats(self, endpoint, process_time, client_ip, model, token, is_flagged, moderated_content):
|
| 239 |
async with self.db as session:
|
|
|
|
| 240 |
new_request_stat = RequestStat(
|
| 241 |
endpoint=endpoint,
|
| 242 |
ip=client_ip,
|
| 243 |
token=token,
|
| 244 |
total_time=process_time,
|
| 245 |
+
model=model,
|
| 246 |
+
is_flagged=is_flagged,
|
| 247 |
+
moderated_content=moderated_content
|
| 248 |
)
|
| 249 |
session.add(new_request_stat)
|
| 250 |
await session.commit()
|
|
|
|
| 261 |
session.add(channel_stat)
|
| 262 |
await session.commit()
|
| 263 |
|
| 264 |
+
async def moderate_content(self, content, token):
|
| 265 |
+
moderation_request = ModerationRequest(input=content)
|
| 266 |
+
|
| 267 |
+
# 直接调用 moderations 函数
|
| 268 |
+
response = await moderations(moderation_request, token)
|
| 269 |
+
|
| 270 |
+
return response
|
| 271 |
+
|
| 272 |
# 配置 CORS 中间件
|
| 273 |
app.add_middleware(
|
| 274 |
CORSMiddleware,
|
|
|
|
| 651 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
| 652 |
|
| 653 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
| 654 |
+
async def moderations(
|
| 655 |
request: ModerationRequest,
|
| 656 |
token: str = Depends(verify_api_key)
|
| 657 |
):
|
|
|
|
| 691 |
return JSONResponse(content={"api_key": api_key})
|
| 692 |
|
| 693 |
# 在 /stats 路由中返回成功和失败百分比
|
|
|
|
|
|
|
|
|
|
| 694 |
from collections import defaultdict
|
| 695 |
from sqlalchemy import func, desc, case
|
| 696 |
|
models.py
CHANGED
|
@@ -96,4 +96,15 @@ class RequestModel(BaseModel):
|
|
| 96 |
n: Optional[int] = 1
|
| 97 |
user: Optional[str] = None
|
| 98 |
tool_choice: Optional[Union[str, ToolChoice]] = None
|
| 99 |
-
tools: Optional[List[Tool]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
n: Optional[int] = 1
|
| 97 |
user: Optional[str] = None
|
| 98 |
tool_choice: Optional[Union[str, ToolChoice]] = None
|
| 99 |
+
tools: Optional[List[Tool]] = None
|
| 100 |
+
|
| 101 |
+
def get_last_text_message(self) -> Optional[str]:
|
| 102 |
+
for message in reversed(self.messages):
|
| 103 |
+
if message.content:
|
| 104 |
+
if isinstance(message.content, str):
|
| 105 |
+
return message.content
|
| 106 |
+
elif isinstance(message.content, list):
|
| 107 |
+
for item in reversed(message.content):
|
| 108 |
+
if item.type == "text" and item.text:
|
| 109 |
+
return item.text
|
| 110 |
+
return ""
|
request.py
CHANGED
|
@@ -8,9 +8,20 @@ import urllib.parse
|
|
| 8 |
from models import RequestModel
|
| 9 |
from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
|
| 10 |
|
|
|
|
|
|
|
| 11 |
def encode_image(image_path):
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
async def get_doc_from_url(url):
|
| 16 |
filename = urllib.parse.unquote(url.split("/")[-1])
|
|
@@ -37,12 +48,28 @@ async def get_encode_image(image_url):
|
|
| 37 |
filename = await get_doc_from_url(image_url)
|
| 38 |
image_path = os.getcwd() + "/" + filename
|
| 39 |
base64_image = encode_image(image_path)
|
| 40 |
-
if filename.endswith(".png"):
|
| 41 |
-
prompt = f"data:image/png;base64,{base64_image}"
|
| 42 |
-
else:
|
| 43 |
-
prompt = f"data:image/jpeg;base64,{base64_image}"
|
| 44 |
os.remove(image_path)
|
| 45 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
async def get_image_message(base64_image, engine = None):
|
| 48 |
if base64_image.startswith("http"):
|
|
@@ -59,6 +86,8 @@ async def get_image_message(base64_image, engine = None):
|
|
| 59 |
}
|
| 60 |
}
|
| 61 |
if "claude" == engine or "vertex-claude" == engine:
|
|
|
|
|
|
|
| 62 |
return {
|
| 63 |
"type": "image",
|
| 64 |
"source": {
|
|
|
|
| 8 |
from models import RequestModel
|
| 9 |
from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
|
| 10 |
|
| 11 |
+
import imghdr
|
| 12 |
+
|
| 13 |
def encode_image(image_path):
|
| 14 |
+
with open(image_path, "rb") as image_file:
|
| 15 |
+
file_content = image_file.read()
|
| 16 |
+
file_type = imghdr.what(None, file_content)
|
| 17 |
+
base64_encoded = base64.b64encode(file_content).decode('utf-8')
|
| 18 |
+
|
| 19 |
+
if file_type == 'png':
|
| 20 |
+
return f"data:image/png;base64,{base64_encoded}"
|
| 21 |
+
elif file_type in ['jpeg', 'jpg']:
|
| 22 |
+
return f"data:image/jpeg;base64,{base64_encoded}"
|
| 23 |
+
else:
|
| 24 |
+
raise ValueError(f"不支持的图片格式: {file_type}")
|
| 25 |
|
| 26 |
async def get_doc_from_url(url):
|
| 27 |
filename = urllib.parse.unquote(url.split("/")[-1])
|
|
|
|
| 48 |
filename = await get_doc_from_url(image_url)
|
| 49 |
image_path = os.getcwd() + "/" + filename
|
| 50 |
base64_image = encode_image(image_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
os.remove(image_path)
|
| 52 |
+
return base64_image
|
| 53 |
+
|
| 54 |
+
from PIL import Image
|
| 55 |
+
import io
|
| 56 |
+
def validate_image(image_data, image_type):
|
| 57 |
+
try:
|
| 58 |
+
decoded_image = base64.b64decode(image_data)
|
| 59 |
+
image = Image.open(io.BytesIO(decoded_image))
|
| 60 |
+
|
| 61 |
+
# 检查图片格式是否与声明的类型匹配
|
| 62 |
+
# print("image.format", image.format)
|
| 63 |
+
if image_type == "image/png" and image.format != "PNG":
|
| 64 |
+
raise ValueError("Image is not a valid PNG")
|
| 65 |
+
elif image_type == "image/jpeg" and image.format not in ["JPEG", "JPG"]:
|
| 66 |
+
raise ValueError("Image is not a valid JPEG")
|
| 67 |
+
|
| 68 |
+
# 如果没有异常,则图片有效
|
| 69 |
+
return True
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print(f"Image validation failed: {str(e)}")
|
| 72 |
+
return False
|
| 73 |
|
| 74 |
async def get_image_message(base64_image, engine = None):
|
| 75 |
if base64_image.startswith("http"):
|
|
|
|
| 86 |
}
|
| 87 |
}
|
| 88 |
if "claude" == engine or "vertex-claude" == engine:
|
| 89 |
+
# if not validate_image(base64_image.split(",")[1], image_type):
|
| 90 |
+
# raise ValueError(f"Invalid image format. Expected {image_type}")
|
| 91 |
return {
|
| 92 |
"type": "image",
|
| 93 |
"source": {
|
utils.py
CHANGED
|
@@ -310,10 +310,10 @@ class BaseAPI:
|
|
| 310 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
| 311 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
| 312 |
|
| 313 |
-
def safe_get(data, *keys):
|
| 314 |
for key in keys:
|
| 315 |
try:
|
| 316 |
data = data[key] if isinstance(data, (dict, list)) else data.get(key)
|
| 317 |
except (KeyError, IndexError, AttributeError, TypeError):
|
| 318 |
-
return
|
| 319 |
return data
|
|
|
|
| 310 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
| 311 |
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
| 312 |
|
| 313 |
+
def safe_get(data, *keys, default=None):
|
| 314 |
for key in keys:
|
| 315 |
try:
|
| 316 |
data = data[key] if isinstance(data, (dict, list)) else data.get(key)
|
| 317 |
except (KeyError, IndexError, AttributeError, TypeError):
|
| 318 |
+
return default
|
| 319 |
return data
|