Text Generation
Transformers
English
legal
chat
transformer
SkillForge45 commited on
Commit
247bc76
·
verified ·
1 Parent(s): 8dbb41f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -0
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
5
+ import torch
6
+
7
+ app = FastAPI()
8
+
9
+
10
+ app.add_middleware(
11
+ CORSMiddleware,
12
+ allow_origins=["*"],
13
+ allow_methods=["*"],
14
+ allow_headers=["*"],
15
+ )
16
+
17
+
18
+ model_name = "SkillForge45/CyberFuture-A1"
19
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
20
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
21
+
22
+
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ model.to(device)
25
+
26
+ class ChatRequest(BaseModel):
27
+ prompt: str
28
+ max_length: int = 100
29
+
30
+ @app.post("/chat/")
31
+ async def chat(request: ChatRequest):
32
+ try:
33
+ inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
34
+ outputs = model.generate(
35
+ **inputs,
36
+ max_length=request.max_length,
37
+ temperature=0.7,
38
+ do_sample=True
39
+ )
40
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
41
+ return {"response": response}
42
+ except Exception as e:
43
+ raise HTTPException(status_code=500, detail=str(e))
44
+
45
+ if __name__ == "__main__":
46
+ import uvicorn
47
+ uvicorn.run(app, host="0.0.0.0", port=8000)