90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
|
|
import numpy as np
|
|
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
|
|
from fastapi.params import Form
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
from server.audio_utils import pcm_bytes_to_float32
|
|
from server.models import ModelManager
|
|
from server.pipeline import ConversationSession
|
|
|
|
# Root logging configuration for the whole server process: INFO and above,
# prefixed with timestamp, logger name and level.
logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
    level=logging.INFO,
)

# Logger for this module; output goes through the root handlers set up above.
log = logging.getLogger(__name__)
|
|
|
|
# Project root = parent of the directory that contains this module.
# abspath() makes the result independent of the process's current working
# directory even when __file__ was recorded as a relative path (e.g. a script
# started as `python main.py`); otherwise the derived directories would be
# resolved relative to whatever the cwd happens to be.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Directory holding reference audio clips (not referenced elsewhere in this view).
REFERENCE_DIR = os.path.join(_PROJECT_ROOT, "reference_audio")

# Directory served under /static and containing index.html.
STATIC_DIR = os.path.join(_PROJECT_ROOT, "static")
|
|
|
|
# Single process-wide model holder shared by the lifespan hook, the HTTP
# routes and every WebSocket session. Models are loaded later, in lifespan(),
# not at construction time here.
model_mgr: ModelManager = ModelManager()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook passed to ``FastAPI(lifespan=...)``.

    Everything before ``yield`` runs once at startup, and everything after it
    runs once at shutdown. ``app`` is supplied by the framework and is not
    used here.
    """
    # load_all() is called synchronously. Blocking here is acceptable because
    # the server does not accept traffic until startup has finished.
    log.info("Starting model loading...")
    model_mgr.load_all()
    log.info("Server ready.")

    yield  # application serves requests while suspended at this point

    log.info("Shutting down.")
|
|
|
|
|
|
# ASGI application; model loading is tied to its startup via `lifespan`.
app = FastAPI(lifespan=lifespan)

# Serve every file under STATIC_DIR at /static/<path>.
app.mount(
    "/static",
    app=StaticFiles(directory=STATIC_DIR),
    name="static",
)
|
|
|
|
|
|
@app.get("/")
async def index():
    """Return ``index.html`` from the static directory for the root URL."""
    page_path = os.path.join(STATIC_DIR, "index.html")
    return FileResponse(page_path)
|
|
|
|
|
|
@app.post("/api/set-voice")
async def set_voice(voice: str = Form(...), lang: str = Form("a")):
    """Change the TTS voice.

    Both values are read from form fields: ``voice`` is required and ``lang``
    defaults to ``"a"``. Its meaning is defined by the TTS engine and cannot
    be determined from this file. The response echoes the chosen voice.
    """
    # NOTE(review): set_voice() is called synchronously on the event loop; if
    # it loads voice data from disk it will block other requests — confirm.
    engine = model_mgr.tts_engine
    engine.set_voice(voice, lang)
    return dict(status="ok", voice=voice)
|
|
|
|
|
|
@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
    """Drive one conversation session over a WebSocket.

    Incoming frames:
      * binary: converted with ``pcm_bytes_to_float32`` (presumably raw PCM
        audio; confirm the sample format) and passed to
        ``session.handle_audio_chunk``.
      * text: parsed as a JSON object. ``{"type": "interrupt",
        "last_chunk_id": ...}`` calls ``session.interrupt``. Malformed JSON,
        non-object JSON and unknown types are ignored.

    The session sends replies back through the ``send_json``/``send_bytes``
    callbacks defined below. ``session.stop()`` always runs once the receive
    loop ends, whether through a client disconnect or an error.
    """
    await ws.accept()
    log.info("WebSocket client connected.")

    async def send_json(data: dict):
        await ws.send_text(json.dumps(data))

    async def send_bytes(data: bytes):
        await ws.send_bytes(data)

    session = ConversationSession(model_mgr, send_json, send_bytes)
    await session.start()

    try:
        while True:
            message = await ws.receive()

            # WebSocket.receive() *returns* the disconnect message instead of
            # raising. Calling receive() again afterwards raises RuntimeError,
            # so convert it here into the exception handled below.
            if message.get("type") == "websocket.disconnect":
                raise WebSocketDisconnect(message.get("code", 1000))

            # Per the ASGI spec, "bytes"/"text" may be present but set to None,
            # so test the values rather than key membership.
            pcm_data = message.get("bytes")
            if pcm_data is not None:
                chunk = pcm_bytes_to_float32(pcm_data)
                await session.handle_audio_chunk(chunk)
                continue

            text = message.get("text")
            if text is None:
                continue

            try:
                msg = json.loads(text)
            except json.JSONDecodeError:
                continue

            # Valid JSON that is not an object (list, number, string) would
            # otherwise raise AttributeError on .get() and kill the session.
            if not isinstance(msg, dict):
                continue

            if msg.get("type") == "interrupt":
                await session.interrupt(
                    last_chunk_id=msg.get("last_chunk_id")
                )

    except WebSocketDisconnect:
        log.info("WebSocket client disconnected.")
    except Exception:
        log.exception("WebSocket error")
    finally:
        await session.stop()
|