🐛 Fix a bug where the token count from the official Claude API was not passed through correctly.
Browse files- response.py +19 -9
response.py
CHANGED
|
@@ -5,7 +5,7 @@ from datetime import datetime
|
|
| 5 |
from log_config import logger
|
| 6 |
|
| 7 |
|
| 8 |
-
async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None,
|
| 9 |
sample_data = {
|
| 10 |
"id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
|
| 11 |
"object": "chat.completion.chunk",
|
|
@@ -29,6 +29,10 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
|
|
| 29 |
# sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
|
| 30 |
if role:
|
| 31 |
sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
json_data = json.dumps(sample_data, ensure_ascii=False)
|
| 33 |
|
| 34 |
# 构建SSE响应
|
|
@@ -68,7 +72,7 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
| 68 |
json_data = json.loads( "{" + line + "}")
|
| 69 |
content = json_data.get('text', '')
|
| 70 |
content = "\n".join(content.split("\\n"))
|
| 71 |
-
sse_string = await generate_sse_response(timestamp, model, content)
|
| 72 |
yield sse_string
|
| 73 |
except json.JSONDecodeError:
|
| 74 |
logger.error(f"无法解析JSON: {line}")
|
|
@@ -114,7 +118,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod
|
|
| 114 |
json_data = json.loads( "{" + line + "}")
|
| 115 |
content = json_data.get('text', '')
|
| 116 |
content = "\n".join(content.split("\\n"))
|
| 117 |
-
sse_string = await generate_sse_response(timestamp, model, content)
|
| 118 |
yield sse_string
|
| 119 |
except json.JSONDecodeError:
|
| 120 |
logger.error(f"无法解析JSON: {line}")
|
|
@@ -163,6 +167,7 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
|
|
| 163 |
yield error_message
|
| 164 |
return
|
| 165 |
buffer = ""
|
|
|
|
| 166 |
async for chunk in response.aiter_text():
|
| 167 |
# logger.info(f"chunk: {repr(chunk)}")
|
| 168 |
buffer += chunk
|
|
@@ -171,20 +176,25 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
|
|
| 171 |
# logger.info(line)
|
| 172 |
|
| 173 |
if line.startswith("data:"):
|
| 174 |
-
line = line
|
| 175 |
-
if line.startswith(" "):
|
| 176 |
-
line = line[1:]
|
| 177 |
resp: dict = json.loads(line)
|
| 178 |
message = resp.get("message")
|
| 179 |
if message:
|
| 180 |
-
tokens_use = resp.get("usage")
|
| 181 |
role = message.get("role")
|
| 182 |
if role:
|
| 183 |
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
|
| 184 |
yield sse_string
|
|
|
|
| 185 |
if tokens_use:
|
| 186 |
-
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
tool_use = resp.get("content_block")
|
| 189 |
tools_id = None
|
| 190 |
function_call_name = None
|
|
|
|
| 5 |
from log_config import logger
|
| 6 |
|
| 7 |
|
| 8 |
+
async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
|
| 9 |
sample_data = {
|
| 10 |
"id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
|
| 11 |
"object": "chat.completion.chunk",
|
|
|
|
| 29 |
# sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
|
| 30 |
if role:
|
| 31 |
sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
|
| 32 |
+
if total_tokens:
|
| 33 |
+
total_tokens = prompt_tokens + completion_tokens
|
| 34 |
+
sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens,"total_tokens": total_tokens}
|
| 35 |
+
sample_data["choices"] = []
|
| 36 |
json_data = json.dumps(sample_data, ensure_ascii=False)
|
| 37 |
|
| 38 |
# 构建SSE响应
|
|
|
|
| 72 |
json_data = json.loads( "{" + line + "}")
|
| 73 |
content = json_data.get('text', '')
|
| 74 |
content = "\n".join(content.split("\\n"))
|
| 75 |
+
sse_string = await generate_sse_response(timestamp, model, content=content)
|
| 76 |
yield sse_string
|
| 77 |
except json.JSONDecodeError:
|
| 78 |
logger.error(f"无法解析JSON: {line}")
|
|
|
|
| 118 |
json_data = json.loads( "{" + line + "}")
|
| 119 |
content = json_data.get('text', '')
|
| 120 |
content = "\n".join(content.split("\\n"))
|
| 121 |
+
sse_string = await generate_sse_response(timestamp, model, content=content)
|
| 122 |
yield sse_string
|
| 123 |
except json.JSONDecodeError:
|
| 124 |
logger.error(f"无法解析JSON: {line}")
|
|
|
|
| 167 |
yield error_message
|
| 168 |
return
|
| 169 |
buffer = ""
|
| 170 |
+
input_tokens = 0
|
| 171 |
async for chunk in response.aiter_text():
|
| 172 |
# logger.info(f"chunk: {repr(chunk)}")
|
| 173 |
buffer += chunk
|
|
|
|
| 176 |
# logger.info(line)
|
| 177 |
|
| 178 |
if line.startswith("data:"):
|
| 179 |
+
line = line.lstrip("data: ")
|
|
|
|
|
|
|
| 180 |
resp: dict = json.loads(line)
|
| 181 |
message = resp.get("message")
|
| 182 |
if message:
|
|
|
|
| 183 |
role = message.get("role")
|
| 184 |
if role:
|
| 185 |
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
|
| 186 |
yield sse_string
|
| 187 |
+
tokens_use = message.get("usage")
|
| 188 |
if tokens_use:
|
| 189 |
+
input_tokens = tokens_use.get("input_tokens", 0)
|
| 190 |
+
usage = resp.get("usage")
|
| 191 |
+
if usage:
|
| 192 |
+
output_tokens = usage.get("output_tokens", 0)
|
| 193 |
+
total_tokens = input_tokens + output_tokens
|
| 194 |
+
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
|
| 195 |
+
yield sse_string
|
| 196 |
+
# print("\n\rtotal_tokens", total_tokens)
|
| 197 |
+
|
| 198 |
tool_use = resp.get("content_block")
|
| 199 |
tools_id = None
|
| 200 |
function_call_name = None
|