Files
live-voice-chat/server/main.py
T
2026-04-07 03:58:35 -04:00

88 lines
2.5 KiB
Python

import json
import logging
import os
from contextlib import asynccontextmanager
import numpy as np
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.params import Form
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from server.audio_utils import pcm_bytes_to_float32
from server.models import ModelManager
from server.pipeline import ConversationSession
logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# Parent of the directory that contains this module; both asset
# directories below are resolved relative to it.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
REFERENCE_DIR = os.path.join(_PROJECT_ROOT, "reference_audio")
STATIC_DIR = os.path.join(_PROJECT_ROOT, "static")

# One manager instance created at import time and shared by the handlers
# defined in this module.
model_mgr = ModelManager()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Wrap the application's lifetime with startup and shutdown steps.

    Before yielding, ``model_mgr.load_all()`` is called and progress is
    logged. The statement after the ``yield`` runs when the context exits
    normally and only logs a shutdown message.

    Args:
        app: The application instance supplied by the framework. It is not
            used inside this function.
    """
    log.info("Starting model loading...")
    # Plain (non-awaited) call: the coroutine does not yield control to the
    # event loop until load_all() returns.
    model_mgr.load_all()
    log.info("Server ready.")

    yield

    log.info("Shutting down.")
# Application instance, wired to the `lifespan` context manager above for
# its startup/shutdown steps.
app = FastAPI(lifespan=lifespan)
# Serve the contents of STATIC_DIR under the "/static" URL prefix.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
@app.get("/")
async def index():
    """Return ``index.html`` from ``STATIC_DIR`` as the response for ``/``."""
    index_path = os.path.join(STATIC_DIR, "index.html")
    return FileResponse(index_path)
@app.post("/api/set-voice")
async def set_voice(voice: str = Form(...), lang: str = Form("a")):
    """Change the TTS voice.

    Forwards both form fields to ``model_mgr.tts_engine.set_voice``. The
    call goes to the module-level ``model_mgr``, the same object that is
    handed to every ``ConversationSession`` in this module.
    NOTE(review): presumably the new voice therefore applies to all open
    sessions -- confirm against the TTS engine implementation.

    Args:
        voice: Required form field naming the voice to select.
        lang: Optional form field; defaults to ``"a"``.

    Returns:
        A dict with ``"status": "ok"`` and the requested ``voice`` echoed
        back.
    """
    tts = model_mgr.tts_engine
    tts.set_voice(voice, lang)

    response = {"status": "ok", "voice": voice}
    return response
@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
    """Drive one voice-chat conversation over a WebSocket connection.

    After accepting the connection, a ``ConversationSession`` is created
    with two callbacks that write JSON text frames and binary frames back
    to the client. The receive loop then dispatches incoming frames:

    * binary frames are converted with ``pcm_bytes_to_float32`` and passed
      to ``session.handle_audio_chunk``;
    * text frames are parsed as JSON, and an object whose ``"type"`` is
      ``"interrupt"`` triggers ``session.interrupt``. Malformed JSON,
      non-object JSON and other message types are ignored.

    The loop ends when the client disconnects or an error occurs;
    ``session.stop`` is always awaited on the way out.

    Args:
        ws: The WebSocket connection supplied by the framework.
    """
    await ws.accept()
    log.info("WebSocket client connected.")

    async def send_json(data: dict):
        """Serialize *data* to JSON and send it as a text frame."""
        await ws.send_text(json.dumps(data))

    async def send_bytes(data: bytes):
        """Send *data* as a binary frame."""
        await ws.send_bytes(data)

    session = ConversationSession(model_mgr, send_json, send_bytes)
    await session.start()
    try:
        while True:
            message = await ws.receive()

            # The low-level receive() hands back the disconnect event as a
            # normal message instead of raising WebSocketDisconnect. Calling
            # receive() again after that raises RuntimeError, so leave the
            # loop here rather than logging a clean close as an error.
            if message.get("type") == "websocket.disconnect":
                log.info("WebSocket client disconnected.")
                break

            # ASGI servers may include both "bytes" and "text" keys with one
            # of them set to None, so test the value, not key presence.
            pcm_data = message.get("bytes")
            if pcm_data is not None:
                chunk = pcm_bytes_to_float32(pcm_data)
                await session.handle_audio_chunk(chunk)
                continue

            text = message.get("text")
            if text is None:
                continue

            try:
                msg = json.loads(text)
            except json.JSONDecodeError:
                # Best-effort protocol: silently skip frames that are not JSON.
                continue

            # Valid JSON need not be an object (e.g. a list or a string);
            # only dicts can carry a "type" field.
            if isinstance(msg, dict) and msg.get("type") == "interrupt":
                await session.interrupt()
    except WebSocketDisconnect:
        # Still possible from the send callbacks or from session code.
        log.info("WebSocket client disconnected.")
    except Exception:
        # Top-level boundary for this connection: log with traceback and
        # fall through to cleanup.
        log.exception("WebSocket error")
    finally:
        await session.stop()