Add feature: support the Vertex Claude API using tool use functionality.
Files changed:
- main.py +49 -3
- request.py +58 -85
- response.py +51 -2
- test/provider_test.py +1 -1
- utils.py +1 -0
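At a high level, the commit teaches the proxy's OpenAI-compatible /v1/chat/completions endpoint to forward OpenAI-style tool definitions to Claude models hosted on Vertex AI and to stream tool_use calls back as SSE. A minimal client-side sketch of exercising that path is below; the base URL, API key, and model alias are illustrative placeholders, not values taken from this commit.

# Hypothetical client call against the proxy; only the request shape
# (an OpenAI-style "tools" list plus streaming) comes from this commit.
import httpx

request_body = {
    "model": "claude-3-5-sonnet",  # assumed alias mapped to a Vertex Claude model in the provider config
    "stream": True,
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}

with httpx.Client(base_url="http://localhost:8000") as client:
    with client.stream(
        "POST",
        "/v1/chat/completions",
        json=request_body,
        headers={"Authorization": "Bearer sk-..."},
    ) as response:
        for line in response.iter_lines():
            print(line)  # SSE chunks, including tool_call deltas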
main.py
CHANGED

@@ -5,7 +5,7 @@ import secrets
 from contextlib import asynccontextmanager
 
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi import FastAPI, HTTPException, Depends
+from fastapi import FastAPI, HTTPException, Depends, Request
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
@@ -40,6 +40,37 @@ async def lifespan(app: FastAPI):
 
 app = FastAPI(lifespan=lifespan)
 
+# from time import time
+# from collections import defaultdict
+# import asyncio
+
+# class StatsMiddleware:
+#     def __init__(self):
+#         self.request_counts = defaultdict(int)
+#         self.request_times = defaultdict(float)
+#         self.ip_counts = defaultdict(lambda: defaultdict(int))
+#         self.lock = asyncio.Lock()
+
+#     async def __call__(self, request: Request, call_next):
+#         start_time = time()
+#         response = await call_next(request)
+#         process_time = time() - start_time
+
+#         endpoint = f"{request.method} {request.url.path}"
+#         client_ip = request.client.host
+
+#         async with self.lock:
+#             self.request_counts[endpoint] += 1
+#             self.request_times[endpoint] += process_time
+#             self.ip_counts[endpoint][client_ip] += 1
+
+#         return response
+# # Create the StatsMiddleware instance
+# stats_middleware = StatsMiddleware()
+
+# # Register the StatsMiddleware
+# app.add_middleware(StatsMiddleware)
+
 # Configure the CORS middleware
 app.add_middleware(
     CORSMiddleware,
@@ -219,9 +250,24 @@ def generate_api_key():
     api_key = "sk-" + secrets.token_urlsafe(32)
     return JSONResponse(content={"api_key": api_key})
 
+# @app.get("/stats")
+# async def get_stats(token: str = Depends(verify_api_key)):
+#     async with stats_middleware.lock:
+#         return {
+#             "request_counts": dict(stats_middleware.request_counts),
+#             "average_request_times": {
+#                 endpoint: total_time / count
+#                 for endpoint, total_time in stats_middleware.request_times.items()
+#                 for count in [stats_middleware.request_counts[endpoint]]
+#             },
+#             "ip_counts": {
+#                 endpoint: dict(ips)
+#                 for endpoint, ips in stats_middleware.ip_counts.items()
+#             }
+#         }
+
 # async def on_fetch(request, env):
 #     import asgi
-
 #     return await asgi.fetch(app, request, env)
 
 if __name__ == '__main__':
@@ -232,5 +278,5 @@ if __name__ == '__main__':
         port=8000,
         reload=True,
         ws="none",
-        log_level="warning"
+        # log_level="warning"
     )
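A side note on the commented-out stats code above: Starlette's app.add_middleware expects a middleware class that it instantiates as cls(app, **options), so registering StatsMiddleware as written would fail if uncommented (its __init__ takes no app argument), and the /stats endpoint reads a separate instance that would never see traffic. A decorator-registered HTTP middleware is one working shape; a sketch under that assumption, not part of the commit:

# Sketch: the same per-endpoint counters, registered via FastAPI's
# @app.middleware("http") decorator, which accepts a plain async function.
import asyncio
from collections import defaultdict
from time import time

from fastapi import FastAPI, Request

app = FastAPI()
request_counts = defaultdict(int)
request_times = defaultdict(float)
lock = asyncio.Lock()

@app.middleware("http")
async def stats_middleware(request: Request, call_next):
    start_time = time()
    response = await call_next(request)
    async with lock:
        endpoint = f"{request.method} {request.url.path}"
        request_counts[endpoint] += 1
        request_times[endpoint] += time() - start_time
    return response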
request.py
CHANGED

@@ -363,7 +363,7 @@ async def get_vertex_gemini_payload(request, engine, provider):
 
 async def get_vertex_claude_payload(request, engine, provider):
     headers = {
-        'Content-Type': 'application/json'
+        'Content-Type': 'application/json',
     }
     if provider.get("client_email") and provider.get("private_key"):
         access_token = get_access_token(provider['client_email'], provider['private_key'])
@@ -386,12 +386,10 @@ async def get_vertex_claude_payload(request, engine, provider):
     url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
 
     messages = []
-
-    function_arguments = None
+    system_prompt = None
     for msg in request.messages:
-        if msg.role == "assistant":
-            msg.role = "model"
         tool_calls = None
+        tool_call_id = None
         if isinstance(msg.content, list):
             content = []
             for item in msg.content:
@@ -402,109 +400,84 @@ async def get_vertex_claude_payload(request, engine, provider):
                 image_message = await get_image_message(item.image_url.url, engine)
                 content.append(image_message)
         else:
-            content = [{"text": msg.content}]
+            content = msg.content
             tool_calls = msg.tool_calls
+            tool_call_id = msg.tool_call_id
 
         if tool_calls:
-            tool_call = tool_calls[0]
-            function_arguments = {
-                "functionCall": {
+            tool_calls_list = []
+            for tool_call in tool_calls:
+                tool_calls_list.append({
+                    "type": "tool_use",
+                    "id": tool_call.id,
                     "name": tool_call.function.name,
-                    "args": json.loads(tool_call.function.arguments)
-                }
-            }
-            messages.append(
-                {
-                    "role": "model",
-                    "parts": [function_arguments]
-                }
-            )
-        elif msg.role == "tool":
-            function_call_name = function_arguments["functionCall"]["name"]
-            messages.append(
-                {
-                    "role": "function",
-                    "parts": [{
-                        "functionResponse": {
-                            "name": function_call_name,
-                            "response": {
-                                "name": function_call_name,
-                                "content": {
-                                    "result": msg.content,
-                                }
-                            }
-                        }
-                    }]
-                }
-            )
+                    "input": json.loads(tool_call.function.arguments),
+                })
+            messages.append({"role": msg.role, "content": tool_calls_list})
+        elif tool_call_id:
+            messages.append({"role": "user", "content": [{
+                "type": "tool_result",
+                "tool_use_id": tool_call.id,
+                "content": content
+            }]})
         elif msg.role != "system":
-            messages.append({"role": msg.role, "parts": content})
+            messages.append({"role": msg.role, "content": content})
         elif msg.role == "system":
-            systemInstruction = {"parts": content}
+            system_prompt = content
 
+    conversation_len = len(messages) - 1
+    message_index = 0
+    while message_index < conversation_len:
+        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
+            if messages[message_index].get("content"):
+                if isinstance(messages[message_index]["content"], list):
+                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
+                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
+                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
+                    content_list.extend(messages[message_index + 1]["content"])
+                    messages[message_index]["content"] = content_list
+                else:
+                    messages[message_index]["content"] += messages[message_index + 1]["content"]
+            messages.pop(message_index + 1)
+            conversation_len = conversation_len - 1
+        else:
+            message_index = message_index + 1
 
+    model = provider['model'][request.model]
     payload = {
-        "contents": messages,
-        # "safetySettings": [
-        #     {
-        #         "category": "HARM_CATEGORY_HARASSMENT",
-        #         "threshold": "BLOCK_NONE"
-        #     },
-        #     {
-        #         "category": "HARM_CATEGORY_HATE_SPEECH",
-        #         "threshold": "BLOCK_NONE"
-        #     },
-        #     {
-        #         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-        #         "threshold": "BLOCK_NONE"
-        #     },
-        #     {
-        #         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-        #         "threshold": "BLOCK_NONE"
-        #     }
-        # ]
-        "generationConfig": {
-            "temperature": 0.5,
-            "max_output_tokens": 8192,
-            "top_k": 40,
-            "top_p": 0.95
-        },
+        "anthropic_version": "vertex-2023-10-16",
+        "messages": messages,
+        "system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
     }
-    if systemInstruction:
-        payload["system_instruction"] = systemInstruction
 
     miss_fields = [
         'model',
         'messages',
-        'stream',
-        'tool_choice',
-        'temperature',
-        'top_p',
-        'max_tokens',
         'presence_penalty',
         'frequency_penalty',
         'n',
         'user',
         'include_usage',
-        'logprobs',
-        'top_logprobs'
     ]
 
     for field, value in request.model_dump(exclude_unset=True).items():
         if field not in miss_fields and value is not None:
-            if field == "tools":
-                payload.update({
-                    "tools": [{
-                        "function_declarations": [tool["function"] for tool in value]
-                    }],
-                    "tool_config": {
-                        "function_calling_config": {
-                            "mode": "AUTO"
-                        }
-                    }
-                })
-            else:
-                payload[field] = value
+            payload[field] = value
+
+    if request.tools and provider.get("tools"):
+        tools = []
+        for tool in request.tools:
+            json_tool = await gpt2claude_tools_json(tool.dict()["function"])
+            tools.append(json_tool)
+        payload["tools"] = tools
+        if "tool_choice" in payload:
+            payload["tool_choice"] = {
+                "type": "auto"
+            }
+
+    if provider.get("tools") == False:
+        payload.pop("tools", None)
+        payload.pop("tool_choice", None)
 
     return url, headers, payload
 
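Two details in get_vertex_claude_payload deserve a gloss. First, Claude's Messages API rejects consecutive messages with the same role, which is what the while message_index < conversation_len loop compensates for by folding neighbors together. A standalone sketch of that merge (an illustration, not the repo's exact code):

# Fold consecutive same-role messages into one, mirroring the merge
# loop in the diff above. Claude's Messages API requires user and
# assistant turns to alternate.
def merge_same_role(messages: list[dict]) -> list[dict]:
    i = 0
    while i < len(messages) - 1:
        if messages[i]["role"] == messages[i + 1]["role"]:
            a, b = messages[i]["content"], messages[i + 1]["content"]
            if isinstance(a, list):
                a.extend(b)
            elif isinstance(a, str) and isinstance(b, list):
                messages[i]["content"] = [{"type": "text", "text": a}, *b]
            else:
                messages[i]["content"] = a + b
            messages.pop(i + 1)
        else:
            i += 1
    return messages

print(merge_same_role([
    {"role": "user", "content": "Hello"},
    {"role": "user", "content": " again"},
    {"role": "assistant", "content": "Hi!"},
]))
# [{'role': 'user', 'content': 'Hello again'}, {'role': 'assistant', 'content': 'Hi!'}]

Second, gpt2claude_tools_json is not shown in this diff; presumably it maps an OpenAI function definition (name/description/parameters) onto Claude's tool shape, whose JSON-schema field is named input_schema rather than parameters.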
response.py
CHANGED

@@ -84,6 +84,55 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
             sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
             yield sse_string
 
+async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
+    timestamp = datetime.timestamp(datetime.now())
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        if response.status_code != 200:
+            error_message = await response.aread()
+            error_str = error_message.decode('utf-8', errors='replace')
+            try:
+                error_json = json.loads(error_str)
+            except json.JSONDecodeError:
+                error_json = error_str
+            yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
+        buffer = ""
+        revicing_function_call = False
+        function_full_response = "{"
+        need_function_call = False
+        async for chunk in response.aiter_text():
+            buffer += chunk
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                logger.info(f"{line}")
+                if line and '\"text\": \"' in line:
+                    try:
+                        json_data = json.loads( "{" + line + "}")
+                        content = json_data.get('text', '')
+                        content = "\n".join(content.split("\\n"))
+                        sse_string = await generate_sse_response(timestamp, model, content)
+                        yield sse_string
+                    except json.JSONDecodeError:
+                        logger.error(f"Failed to parse JSON: {line}")
+
+                if line and ('\"type\": \"tool_use\"' in line or revicing_function_call):
+                    revicing_function_call = True
+                    need_function_call = True
+                    if ']' in line:
+                        revicing_function_call = False
+                        continue
+
+                    function_full_response += line
+
+        if need_function_call:
+            function_call = json.loads(function_full_response)
+            function_call_name = function_call["name"]
+            function_call_id = function_call["id"]
+            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
+            yield sse_string
+            function_full_response = json.dumps(function_call["input"])
+            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
+            yield sse_string
+
 async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects=5):
     redirect_count = 0
     while redirect_count < max_redirects:
@@ -202,10 +251,10 @@ async def fetch_response(client, url, headers, payload):
 
 async def fetch_response_stream(client, url, headers, payload, engine, model):
     try:
-        if engine == "gemini" or engine == "vertex":
+        if engine == "gemini" or (engine == "vertex" and "gemini" in model):
             async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
                 yield chunk
-        elif engine == "claude":
+        elif engine == "claude" or (engine == "vertex" and "claude" in model):
             async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
                 yield chunk
         elif engine == "gpt":
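The new fetch_vertex_claude_response_stream never parses the upstream body as one JSON document; it scans each buffered line for marker substrings ("text": " for content deltas, "type": "tool_use" for the start of a tool call) and re-assembles the tool call object by string concatenation, closing it when a line containing ] arrives. A toy offline reconstruction of that accumulation, assuming the upstream stream emits roughly one JSON key per line as the logged lines suggest:

# Toy replay of the tool_use accumulation logic on a canned stream.
# The line format is an assumption based on the line-oriented checks
# in the diff, not a captured Vertex response.
import json

canned_lines = [
    '"type": "tool_use",',
    '"id": "toolu_123",',
    '"name": "get_weather",',
    '"input": {"city": "Paris"}',
    '}',
    ']',
]

collecting = False
fragment = "{"  # same seed as function_full_response in the diff
for line in canned_lines:
    if '"type": "tool_use"' in line or collecting:
        collecting = True
        if ']' in line:  # closing bracket of the enclosing JSON array
            break
        fragment += line

call = json.loads(fragment)
print(call["name"], call["input"])  # get_weather {'city': 'Paris'}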
test/provider_test.py
CHANGED

@@ -80,7 +80,7 @@ def test_request_model(test_client, api_key, get_model):
 
     response = test_client.post("/v1/chat/completions", json=request_data, headers=headers)
     for line in response.iter_lines():
-        print(line)
+        print(line.lstrip("data: "))
     assert response.status_code == 200
 
 if __name__ == "__main__":
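One caveat about the test tweak: str.lstrip("data: ") strips a character set ({'d', 'a', 't', ':', ' '}), not the literal prefix, so an SSE payload that begins with one of those characters loses extra characters. On Python 3.9+, str.removeprefix does the exact-prefix strip:

# lstrip treats its argument as a set of characters to remove.
line = "data: tokens..."
print(line.lstrip("data: "))        # "okens..."  (leading 't' is in the set)
print(line.removeprefix("data: "))  # "tokens..."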
utils.py
CHANGED

@@ -80,6 +80,7 @@ async def error_handling_wrapper(generator, status_code=200):
     try:
         first_item = await generator.__anext__()
         first_item_str = first_item
+        # logger.info("first_item_str: %s", first_item_str)
         if isinstance(first_item_str, (bytes, bytearray)):
             first_item_str = first_item_str.decode("utf-8")
         if isinstance(first_item_str, str):