"""FastAPI service exposing a chat endpoint backed by a transformer model.

Accepts either a text prompt or an uploaded WAV file (transcribed via
speech recognition) and returns a generated response, optionally spoken
through the voice interface.
"""

import os
import tempfile
from typing import Optional

import speech_recognition as sr  # was referenced as `sr` but never imported (NameError on audio upload)
import torch
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from model import (
    FullChatDataset,
    SimpleTransformerModel,
    VoiceInterface,
    generate_response,
)

app = FastAPI()

# CORS middleware — wide-open origins; tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize components once at import time; shared across all requests.
dataset = FullChatDataset()
model = SimpleTransformerModel(len(dataset.tokenizer))
voice_interface = VoiceInterface()


class ChatRequest(BaseModel):
    """JSON schema for a chat request.

    NOTE(review): currently unused — /chat/ consumes multipart form data
    instead. Kept for backward compatibility with any JSON clients.
    """

    prompt: str
    max_length: int = 100
    use_voice: bool = False


async def _transcribe_upload(audio_file: UploadFile) -> str:
    """Persist the uploaded audio to a temp file and transcribe it.

    Uses a uniquely-named temp file (safe under concurrent requests,
    unlike a fixed "temp_audio.wav" in the CWD) and always removes it,
    even when recognition fails.
    """
    contents = await audio_file.read()
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        tmp.write(contents)
        tmp.close()
        with sr.AudioFile(tmp.name) as source:
            audio = voice_interface.recognizer.record(source)
        return voice_interface.recognizer.recognize_google(audio)
    finally:
        os.remove(tmp.name)


@app.post("/chat/")
async def chat_endpoint(
    prompt: Optional[str] = Form(None),
    max_length: int = Form(100),
    use_voice: bool = Form(False),
    audio_file: Optional[UploadFile] = File(None),
):
    """Generate a chat response from a text prompt or an uploaded WAV file.

    Raises:
        HTTPException(400): neither a prompt nor transcribable audio given.
        HTTPException(500): any unexpected failure during transcription
            or generation.
    """
    try:
        # Voice input takes precedence: transcribe it into the prompt.
        if audio_file:
            prompt = await _transcribe_upload(audio_file)

        if not prompt:
            raise HTTPException(status_code=400, detail="No input provided")

        response = generate_response(
            model,
            dataset.tokenizer,
            prompt,
            max_length,
            voice_interface if use_voice else None,
        )
        return {"response": response}
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 400 above) straight through
        # instead of masking them as 500s in the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
async def read_root():
    """Health/landing endpoint."""
    return {"message": "CyberFuture Running"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)