🐛 Bug: Fix the bug of tool use request body format error
Browse files- .gitignore +3 -1
- models.py +9 -4
- request.py +30 -26
- test/test_nostream.py +2 -3
.gitignore
CHANGED
|
@@ -5,4 +5,6 @@ __pycache__
|
|
| 5 |
.vscode
|
| 6 |
node_modules
|
| 7 |
.wrangler
|
| 8 |
-
.pytest_cache
|
|
|
|
|
|
|
|
|
| 5 |
.vscode
|
| 6 |
node_modules
|
| 7 |
.wrangler
|
| 8 |
+
.pytest_cache
|
| 9 |
+
*.jpg
|
| 10 |
+
*.json
|
models.py
CHANGED
|
@@ -10,16 +10,14 @@ class ImageGenerationRequest(BaseModel):
|
|
| 10 |
|
| 11 |
class FunctionParameter(BaseModel):
|
| 12 |
type: str
|
| 13 |
-
properties: Dict[str, Dict[str, str]]
|
| 14 |
required: List[str]
|
| 15 |
|
| 16 |
-
# 定义 Function 模型
|
| 17 |
class Function(BaseModel):
|
| 18 |
name: str
|
| 19 |
description: str
|
| 20 |
parameters: Optional[FunctionParameter] = Field(default=None, exclude=None)
|
| 21 |
|
| 22 |
-
# 定义 Tool 模型
|
| 23 |
class Tool(BaseModel):
|
| 24 |
type: str
|
| 25 |
function: Function
|
|
@@ -58,6 +56,13 @@ class Message(BaseModel):
|
|
| 58 |
class Config:
|
| 59 |
extra = "allow" # 允许额外的字段
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
class RequestModel(BaseModel):
|
| 62 |
model: str
|
| 63 |
messages: List[Message]
|
|
@@ -72,5 +77,5 @@ class RequestModel(BaseModel):
|
|
| 72 |
frequency_penalty: Optional[float] = 0.0
|
| 73 |
n: Optional[int] = 1
|
| 74 |
user: Optional[str] = None
|
| 75 |
-
tool_choice: Optional[str] = None
|
| 76 |
tools: Optional[List[Tool]] = None
|
|
|
|
| 10 |
|
| 11 |
class FunctionParameter(BaseModel):
|
| 12 |
type: str
|
| 13 |
+
properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
|
| 14 |
required: List[str]
|
| 15 |
|
|
|
|
| 16 |
class Function(BaseModel):
|
| 17 |
name: str
|
| 18 |
description: str
|
| 19 |
parameters: Optional[FunctionParameter] = Field(default=None, exclude=None)
|
| 20 |
|
|
|
|
| 21 |
class Tool(BaseModel):
|
| 22 |
type: str
|
| 23 |
function: Function
|
|
|
|
| 56 |
class Config:
|
| 57 |
extra = "allow" # 允许额外的字段
|
| 58 |
|
| 59 |
+
class FunctionChoice(BaseModel):
|
| 60 |
+
name: str
|
| 61 |
+
|
| 62 |
+
class ToolChoice(BaseModel):
|
| 63 |
+
type: str
|
| 64 |
+
function: Optional[FunctionChoice] = None
|
| 65 |
+
|
| 66 |
class RequestModel(BaseModel):
|
| 67 |
model: str
|
| 68 |
messages: List[Message]
|
|
|
|
| 77 |
frequency_penalty: Optional[float] = 0.0
|
| 78 |
n: Optional[int] = 1
|
| 79 |
user: Optional[str] = None
|
| 80 |
+
tool_choice: Optional[Union[str, ToolChoice]] = None
|
| 81 |
tools: Optional[List[Tool]] = None
|
request.py
CHANGED
|
@@ -474,19 +474,21 @@ async def get_vertex_claude_payload(request, engine, provider):
|
|
| 474 |
tools.append(json_tool)
|
| 475 |
payload["tools"] = tools
|
| 476 |
if "tool_choice" in payload:
|
| 477 |
-
if payload["tool_choice"]
|
| 478 |
-
payload["tool_choice"]
|
| 479 |
-
"
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
|
|
|
|
|
|
| 490 |
|
| 491 |
if provider.get("tools") == False:
|
| 492 |
payload.pop("tools", None)
|
|
@@ -746,19 +748,21 @@ async def get_claude_payload(request, engine, provider):
|
|
| 746 |
tools.append(json_tool)
|
| 747 |
payload["tools"] = tools
|
| 748 |
if "tool_choice" in payload:
|
| 749 |
-
if payload["tool_choice"]
|
| 750 |
-
payload["tool_choice"]
|
| 751 |
-
"
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
|
|
|
|
|
|
| 762 |
|
| 763 |
if provider.get("tools") == False:
|
| 764 |
payload.pop("tools", None)
|
|
|
|
| 474 |
tools.append(json_tool)
|
| 475 |
payload["tools"] = tools
|
| 476 |
if "tool_choice" in payload:
|
| 477 |
+
if isinstance(payload["tool_choice"], dict):
|
| 478 |
+
if payload["tool_choice"]["type"] == "function":
|
| 479 |
+
payload["tool_choice"] = {
|
| 480 |
+
"type": "tool",
|
| 481 |
+
"name": payload["tool_choice"]["function"]["name"]
|
| 482 |
+
}
|
| 483 |
+
if isinstance(payload["tool_choice"], str):
|
| 484 |
+
if payload["tool_choice"] == "auto":
|
| 485 |
+
payload["tool_choice"] = {
|
| 486 |
+
"type": "auto"
|
| 487 |
+
}
|
| 488 |
+
if payload["tool_choice"] == "none":
|
| 489 |
+
payload["tool_choice"] = {
|
| 490 |
+
"type": "any"
|
| 491 |
+
}
|
| 492 |
|
| 493 |
if provider.get("tools") == False:
|
| 494 |
payload.pop("tools", None)
|
|
|
|
| 748 |
tools.append(json_tool)
|
| 749 |
payload["tools"] = tools
|
| 750 |
if "tool_choice" in payload:
|
| 751 |
+
if isinstance(payload["tool_choice"], dict):
|
| 752 |
+
if payload["tool_choice"]["type"] == "function":
|
| 753 |
+
payload["tool_choice"] = {
|
| 754 |
+
"type": "tool",
|
| 755 |
+
"name": payload["tool_choice"]["function"]["name"]
|
| 756 |
+
}
|
| 757 |
+
if isinstance(payload["tool_choice"], str):
|
| 758 |
+
if payload["tool_choice"] == "auto":
|
| 759 |
+
payload["tool_choice"] = {
|
| 760 |
+
"type": "auto"
|
| 761 |
+
}
|
| 762 |
+
if payload["tool_choice"] == "none":
|
| 763 |
+
payload["tool_choice"] = {
|
| 764 |
+
"type": "any"
|
| 765 |
+
}
|
| 766 |
|
| 767 |
if provider.get("tools") == False:
|
| 768 |
payload.pop("tools", None)
|
test/test_nostream.py
CHANGED
|
@@ -45,7 +45,6 @@ def get_model_response(image_base64):
|
|
| 45 |
]
|
| 46 |
|
| 47 |
payload = {
|
| 48 |
-
|
| 49 |
"model": "claude-3-5-sonnet",
|
| 50 |
"messages": [
|
| 51 |
{
|
|
@@ -64,7 +63,7 @@ def get_model_response(image_base64):
|
|
| 64 |
]
|
| 65 |
}
|
| 66 |
],
|
| 67 |
-
"stream": True,
|
| 68 |
"tools": tools,
|
| 69 |
"tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
|
| 70 |
"max_tokens": 300
|
|
@@ -117,5 +116,5 @@ def main(image_path):
|
|
| 117 |
print("\n無法解析回應。")
|
| 118 |
|
| 119 |
if __name__ == "__main__":
|
| 120 |
-
image_path = "
|
| 121 |
main(image_path)
|
|
|
|
| 45 |
]
|
| 46 |
|
| 47 |
payload = {
|
|
|
|
| 48 |
"model": "claude-3-5-sonnet",
|
| 49 |
"messages": [
|
| 50 |
{
|
|
|
|
| 63 |
]
|
| 64 |
}
|
| 65 |
],
|
| 66 |
+
# "stream": True,
|
| 67 |
"tools": tools,
|
| 68 |
"tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
|
| 69 |
"max_tokens": 300
|
|
|
|
| 116 |
print("\n無法解析回應。")
|
| 117 |
|
| 118 |
if __name__ == "__main__":
|
| 119 |
+
image_path = "1.jpg" # 替換為您的圖像路徑
|
| 120 |
main(image_path)
|