🐛 Bug: Fix bugs caused by concurrent errors in multiple database write operations.
Browse files
main.py
CHANGED
|
@@ -282,40 +282,67 @@ def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> Decimal
|
|
| 282 |
# 返回精确到15位小数的结果
|
| 283 |
return total_cost.quantize(Decimal('0.000000000000001'))
|
| 284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
async def update_stats(current_info):
|
| 286 |
if DISABLE_DATABASE:
|
| 287 |
return
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
async def update_channel_stats(request_id, provider, model, api_key, success):
|
| 302 |
if DISABLE_DATABASE:
|
| 303 |
return
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
class LoggingStreamingResponse(Response):
|
| 321 |
def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
|
|
|
|
| 282 |
# 返回精确到15位小数的结果
|
| 283 |
return total_cost.quantize(Decimal('0.000000000000001'))
|
| 284 |
|
| 285 |
+
from asyncio import Semaphore
|
| 286 |
+
|
| 287 |
+
# 创建一个信号量来控制数据库访问
|
| 288 |
+
db_semaphore = Semaphore(1) # 限制同时只有1个写入操作
|
| 289 |
+
|
| 290 |
async def update_stats(current_info):
|
| 291 |
if DISABLE_DATABASE:
|
| 292 |
return
|
| 293 |
+
|
| 294 |
+
try:
|
| 295 |
+
# 等待获取数据库访问权限
|
| 296 |
+
async with db_semaphore:
|
| 297 |
+
async with async_session() as session:
|
| 298 |
+
async with session.begin():
|
| 299 |
+
try:
|
| 300 |
+
columns = [column.key for column in RequestStat.__table__.columns]
|
| 301 |
+
filtered_info = {k: v for k, v in current_info.items() if k in columns}
|
| 302 |
+
new_request_stat = RequestStat(**filtered_info)
|
| 303 |
+
session.add(new_request_stat)
|
| 304 |
+
await session.commit()
|
| 305 |
+
except Exception as e:
|
| 306 |
+
await session.rollback()
|
| 307 |
+
logger.error(f"Error updating stats: {str(e)}")
|
| 308 |
+
if is_debug:
|
| 309 |
+
import traceback
|
| 310 |
+
traceback.print_exc()
|
| 311 |
+
except Exception as e:
|
| 312 |
+
logger.error(f"Error acquiring database lock: {str(e)}")
|
| 313 |
+
if is_debug:
|
| 314 |
+
import traceback
|
| 315 |
+
traceback.print_exc()
|
| 316 |
|
| 317 |
async def update_channel_stats(request_id, provider, model, api_key, success):
|
| 318 |
if DISABLE_DATABASE:
|
| 319 |
return
|
| 320 |
+
|
| 321 |
+
try:
|
| 322 |
+
async with db_semaphore:
|
| 323 |
+
async with async_session() as session:
|
| 324 |
+
async with session.begin():
|
| 325 |
+
try:
|
| 326 |
+
channel_stat = ChannelStat(
|
| 327 |
+
request_id=request_id,
|
| 328 |
+
provider=provider,
|
| 329 |
+
model=model,
|
| 330 |
+
api_key=api_key,
|
| 331 |
+
success=success,
|
| 332 |
+
)
|
| 333 |
+
session.add(channel_stat)
|
| 334 |
+
await session.commit()
|
| 335 |
+
except Exception as e:
|
| 336 |
+
await session.rollback()
|
| 337 |
+
logger.error(f"Error updating channel stats: {str(e)}")
|
| 338 |
+
if is_debug:
|
| 339 |
+
import traceback
|
| 340 |
+
traceback.print_exc()
|
| 341 |
+
except Exception as e:
|
| 342 |
+
logger.error(f"Error acquiring database lock: {str(e)}")
|
| 343 |
+
if is_debug:
|
| 344 |
+
import traceback
|
| 345 |
+
traceback.print_exc()
|
| 346 |
|
| 347 |
class LoggingStreamingResponse(Response):
|
| 348 |
def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
|