"""FastAPI entry point: static front-end, voice selection, and the chat WebSocket.

The module builds a single shared ``ModelManager``, loads its models during
application startup, and exposes:

* ``GET /``               -- serves ``index.html`` from the static directory.
* ``POST /api/set-voice`` -- switches the TTS voice on the shared manager.
* ``WS /ws/chat``         -- one ``ConversationSession`` per connection; binary
  frames carry audio, text frames carry JSON control messages.
"""

import json
import logging
import os
from contextlib import asynccontextmanager
from typing import Optional

import numpy as np
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.params import Form
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from server.audio_utils import pcm_bytes_to_float32
from server.models import ModelManager
from server.pipeline import ConversationSession

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
log = logging.getLogger(__name__)

# Directory that contains the package this file lives in (two levels up from
# this file); both resource directories below are resolved relative to it.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
# Not referenced anywhere in this module; kept because other modules may
# import it -- TODO confirm before removing.
REFERENCE_DIR = os.path.join(_PROJECT_ROOT, "reference_audio")
STATIC_DIR = os.path.join(_PROJECT_ROOT, "static")

# Single manager instance shared by every request and WebSocket session.
model_mgr = ModelManager()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models before the app starts serving; log on shutdown.

    ``model_mgr.load_all()`` is called synchronously, so startup does not
    complete until it returns.
    """
    log.info("Starting model loading...")
    model_mgr.load_all()
    log.info("Server ready.")
    yield
    log.info("Shutting down.")


app = FastAPI(lifespan=lifespan)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


@app.get("/")
async def index():
    """Serve the front-end entry page (``index.html`` in ``STATIC_DIR``)."""
    return FileResponse(os.path.join(STATIC_DIR, "index.html"))


@app.post("/api/set-voice")
async def set_voice(voice: str = Form(...), lang: str = Form("a")):
    """Change the TTS voice.

    Args:
        voice: Required form field; forwarded to ``tts_engine.set_voice``.
        lang: Optional form field, defaults to ``"a"``; forwarded unchanged.
            NOTE(review): the meaning of the language codes is defined by the
            TTS engine, not visible here.

    Returns:
        ``{"status": "ok", "voice": voice}`` echoing the requested voice.
    """
    model_mgr.tts_engine.set_voice(voice, lang)
    return {"status": "ok", "voice": voice}


def _parse_control_message(raw: str) -> Optional[dict]:
    """Decode a text frame into a JSON object, or ``None`` if it is unusable.

    Both invalid JSON and valid JSON that is not an object (list, string,
    number, ...) yield ``None`` so the caller can skip the frame instead of
    failing on ``.get``.
    """
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
    """Run one conversation session over a WebSocket connection.

    Binary frames are converted with ``pcm_bytes_to_float32`` and passed to
    ``session.handle_audio_chunk``. Text frames are parsed as JSON; an object
    with ``"type": "interrupt"`` triggers ``session.interrupt`` with its
    optional ``last_chunk_id``. Malformed or unrecognised text frames are
    ignored. ``session.stop()`` always runs once the receive loop ends.
    """
    await ws.accept()
    log.info("WebSocket client connected.")

    async def send_json(data: dict):
        await ws.send_text(json.dumps(data))

    async def send_bytes(data: bytes):
        await ws.send_bytes(data)

    session = ConversationSession(model_mgr, send_json, send_bytes)
    await session.start()

    try:
        while True:
            message = await ws.receive()

            # The low-level ``receive()`` reports a closed connection as a
            # message instead of raising WebSocketDisconnect; calling it again
            # afterwards raises RuntimeError. Leave the loop here so a normal
            # disconnect is not logged as an error.
            if message.get("type") == "websocket.disconnect":
                log.info("WebSocket client disconnected.")
                break

            # ASGI permits both keys to be present with one set to None, so
            # branch on the value rather than on key presence.
            pcm_data = message.get("bytes")
            if pcm_data is not None:
                chunk = pcm_bytes_to_float32(pcm_data)
                await session.handle_audio_chunk(chunk)
                continue

            text = message.get("text")
            if text is None:
                continue

            msg = _parse_control_message(text)
            if msg is None:
                continue

            if msg.get("type") == "interrupt":
                await session.interrupt(
                    last_chunk_id=msg.get("last_chunk_id")
                )
    except WebSocketDisconnect:
        # Still reachable: code called from the loop may raise this itself.
        log.info("WebSocket client disconnected.")
    except Exception:
        log.exception("WebSocket error")
    finally:
        await session.stop()