# Legal_test / app.py
# (Hugging Face Space file header — last commit: "Update app.py", 2865c3e verified)
from typing import List, Optional

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, pipeline
# Hugging Face model id of the legal-domain instruct model this service wraps.
MODEL_ID = "Equall/Saul-7B-Instruct-v1"

# Model weights are downloaded on first start, so this can take several minutes.
print("Loading model... this can take a while on first start.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
pipe = pipeline(
    "text-generation",
    model=MODEL_ID,
    tokenizer=tokenizer,
    device=-1,  # CPU only
    max_new_tokens=512,
    # Use EOS as pad token to silence the "pad_token_id not set" warning.
    pad_token_id=tokenizer.eos_token_id,
)
# FastAPI application exposing an OpenAI-compatible chat completion API.
app = FastAPI()
class ChatMessage(BaseModel):
    """One chat turn in OpenAI chat-completions format."""

    role: str  # "system" | "user" | "assistant"
    content: str
class ChatRequest(BaseModel):
    """Request body for /v1/chat/completions (OpenAI-style)."""

    model: Optional[str] = None  # ignored, OpenAI-style compat
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.0  # 0.0 means greedy decoding
    max_tokens: Optional[int] = 512
@app.get("/")
def root():
    """Health check: report that the service is up and which model it serves."""
    payload = {"status": "ok", "model": MODEL_ID}
    return payload
def build_prompt(raw_messages: List[dict]) -> str:
    """
    Normalize messages so they fit the template:
    - Collect system messages and prepend their text to the first user message.
    - Drop leading assistant messages.
    - Merge consecutive messages with the same role.
    - Ensure we end up with user/assistant/user/assistant/... only.
    """
    # Split the history into system text vs. conversational turns.
    system_chunks: List[str] = []
    turns: List[dict] = []
    for msg in raw_messages:
        role = msg.get("role")
        text = msg.get("content", "")
        if role == "system":
            if text:
                system_chunks.append(text)
        elif role in ("user", "assistant"):
            turns.append({"role": role, "content": text})
        # any other role is silently ignored

    # The chat template must start with a user turn; discard leading assistants.
    while turns and turns[0]["role"] != "user":
        turns.pop(0)

    # Collapse runs of same-role messages into single turns.
    merged: List[dict] = []
    for turn in turns:
        if merged and merged[-1]["role"] == turn["role"]:
            merged[-1]["content"] += "\n\n" + turn["content"]
        else:
            merged.append(turn)

    if not merged:
        raise ValueError("No user messages found after normalization.")

    # Fold the accumulated system text into the head of the conversation.
    if system_chunks:
        preamble = "\n\n".join(system_chunks)
        if merged[0]["role"] == "user":
            merged[0]["content"] = preamble + "\n\n" + merged[0]["content"]
        else:
            # Defensive: if the head somehow isn't a user turn, add a synthetic one.
            merged.insert(0, {"role": "user", "content": preamble})

    # Turns now alternate user/assistant; let the tokenizer's chat template
    # render the exact prompt format, including the generation cue.
    return tokenizer.apply_chat_template(
        merged,
        tokenize=False,
        add_generation_prompt=True,
    )
@app.post("/debug-echo")
async def debug_echo(request: Request):
    """
    Debug helper: log the raw request body and acknowledge receipt.

    The parameter must be the raw `fastapi.Request`: the previous annotation
    (`ChatRequest`) made FastAPI parse the body into a pydantic model, which
    has no `.body()` coroutine, so `await request.body()` always raised
    AttributeError. Accepting `Request` restores access to the raw bytes.
    """
    body = await request.body()
    print("DEBUG ECHO BODY:", body)
    return {"ok": True}
@app.post("/v1/chat/completions")
def chat(request: ChatRequest):
    """
    OpenAI-compatible chat completion endpoint.

    Normalizes the message history, renders the model's chat prompt, runs
    generation on CPU, and returns a minimal OpenAI-style response envelope.

    Raises:
        HTTPException(400): if the message history cannot be normalized
            (e.g. it contains no user messages).
    """
    try:
        messages = [m.dict() for m in request.messages]
        prompt = build_prompt(messages)
    except Exception as e:
        # Don't crash the app – return a 400 with explanation
        raise HTTPException(status_code=400, detail=f"Invalid message history: {e}")

    # Greedy decoding by default; only enable sampling (and pass a temperature)
    # when the caller asks for temperature > 0 — passing temperature=0.0 with
    # do_sample=False triggers transformers warnings (errors in newer versions).
    temperature = request.temperature or 0.0
    gen_kwargs = {
        "max_new_tokens": request.max_tokens or 512,
        "do_sample": temperature > 0,
        "top_p": 1.0,
        # Return only the completion: slicing the decoded full text by
        # len(prompt) is fragile because re-decoding the prompt tokens need
        # not reproduce the input string byte-for-byte.
        "return_full_text": False,
    }
    if temperature > 0:
        gen_kwargs["temperature"] = temperature

    outputs = pipe(prompt, **gen_kwargs)
    reply = outputs[0]["generated_text"].strip()

    return {
        "id": "chatcmpl-1",
        "object": "chat.completion",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": reply,
                },
                "finish_reason": "stop",
            }
        ],
    }