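# File-browser header captured together with this source (kept as comments so the module parses):
# Files
# live-voice-chat/server/main.py
# T
# 2026-04-16 10:00:37 -04:00
# 209 lines
# 6.8 KiB
# Python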

import asyncio
import json
import logging
import os
import tempfile
from contextlib import asynccontextmanager

import numpy as np
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.params import Form
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles

from server.audio_utils import pcm_bytes_to_float32
from server.models import ModelManager
from server.pipeline import ConversationSession
from server.video import LoRASpec
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
log = logging.getLogger(__name__)

# Parent of the ``server`` package directory (this file lives in server/).
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))

REFERENCE_DIR = os.path.join(_PROJECT_ROOT, "reference_audio")
STATIC_DIR = os.path.join(_PROJECT_ROOT, "static")
AVATAR_DIR = os.path.join(_PROJECT_ROOT, "avatars")

# /api/set-avatar writes uploads into AVATAR_DIR, so create it up front.
os.makedirs(AVATAR_DIR, exist_ok=True)

# One shared manager for the whole process; ``lifespan`` calls load_all() on it.
model_mgr = ModelManager()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: load all models, then serve until shutdown.

    FastAPI enters this context before accepting requests and resumes it
    after the ``yield`` on shutdown.
    """
    log.info("Starting model loading...")
    # Called synchronously on the event loop; the server does not start
    # handling requests until this context yields.
    model_mgr.load_all()
    log.info("Server ready.")

    yield

    log.info("Shutting down.")
# ASGI application; model loading happens in ``lifespan``.
app = FastAPI(lifespan=lifespan)

# Expose everything under STATIC_DIR at the /static URL prefix.
static_files = StaticFiles(directory=STATIC_DIR)
app.mount("/static", static_files, name="static")
@app.get("/")
async def index():
    """Serve ``index.html`` from STATIC_DIR as the root page."""
    page = os.path.join(STATIC_DIR, "index.html")
    return FileResponse(page)
@app.post("/api/set-voice")
async def set_voice(voice: str = Form(...), lang: str = Form("a")):
    """Switch the TTS engine to a different voice.

    Form fields:
        voice: voice identifier passed to ``tts_engine.set_voice``.
        lang: passed through as the second argument (presumably a language
            code; defaults to ``"a"``).

    Returns:
        ``{"status": "ok", "voice": <voice>}``.
    """
    tts = model_mgr.tts_engine
    tts.set_voice(voice, lang)
    return dict(status="ok", voice=voice)
# --- Video / avatar endpoints ---------------------------------------------
def _require_video() -> "object":
    """Return ``model_mgr.video_engine`` or abort the current request.

    Raises:
        HTTPException: 404 when the manager has no video engine
            (video mode is disabled in the config).
    """
    engine = model_mgr.video_engine
    if engine is not None:
        return engine
    raise HTTPException(
        status_code=404,
        detail="Video engine disabled. Set config.video.enabled=true and restart.",
    )
@app.post("/api/set-avatar")
async def set_avatar(image: UploadFile):
    """Upload an avatar image and (re)generate the video engine's cached clips.

    The upload is saved as ``AVATAR_DIR/avatar<ext>``. The extension is taken
    from the client filename, lowercased, and defaults to ``.png``. The saved
    file is then handed to ``video_engine.set_avatar`` in a worker thread.

    Returns:
        dict with ``status``, the saved ``avatar_path``, the ``idle_clip_url``
        to fetch next, and the engine's current ``mode``.

    Raises:
        HTTPException: 404 if video is disabled, 400 if the upload is empty,
            500 if the engine fails to process the avatar.
    """
    ve = _require_video()

    # splitext() only inspects text after the last path separator, so the
    # suffix cannot inject directory components into ``dest``. Lowercasing
    # keeps re-uploads (".PNG" vs ".png") pointing at the same file.
    suffix = os.path.splitext(image.filename or "avatar.png")[1].lower() or ".png"
    dest = os.path.join(AVATAR_DIR, f"avatar{suffix}")

    data = await image.read()
    if not data:
        # Reject early instead of failing later inside the engine with a 500.
        raise HTTPException(status_code=400, detail="Empty avatar upload.")

    def _save() -> None:
        with open(dest, "wb") as f:
            f.write(data)

    # Disk I/O is blocking; keep it off the event loop like the engine call below.
    await asyncio.to_thread(_save)
    log.info("Avatar saved to %s", dest)

    try:
        await asyncio.to_thread(ve.set_avatar, dest)
    except Exception as e:
        log.exception("set_avatar failed")
        raise HTTPException(status_code=500, detail=f"Avatar setup failed: {e}") from e

    return {
        "status": "ok",
        "avatar_path": dest,
        "idle_clip_url": "/api/idle-clip",
        "mode": ve.cfg.mode,
    }
@app.get("/api/idle-clip")
async def idle_clip():
    """Serve the engine's cached idle-loop clip with media type ``video/mp4``.

    Raises:
        HTTPException: 404 if video is disabled, or if
            ``get_idle_clip()`` returns None (no avatar uploaded yet).
    """
    clip = _require_video().get_idle_clip()
    if clip is not None:
        return Response(content=clip, media_type="video/mp4")
    raise HTTPException(status_code=404, detail="No idle clip. Upload an avatar first.")
@app.post("/api/set-video-mode")
async def set_video_mode(mode: str = Form(...)):
    """Set ``video_engine.cfg.mode`` to 'off', 'library', or 'reflective'.

    Only the config value changes here; the engine itself stays loaded.
    How each mode is honoured (e.g. whether 'off' routes later turns
    through PCM streaming) depends on the code that reads ``cfg.mode``.
    NOTE(review): confirm that behavior in the pipeline.

    Clips already prebaked by ``set_avatar`` are not regenerated. Because of
    this, the response asks the client to re-upload the avatar when
    switching to 'library'.

    Raises:
        HTTPException: 404 if video is disabled, 400 for an unknown mode.
    """
    ve = _require_video()

    valid_modes = ("off", "library", "reflective")
    if mode not in valid_modes:
        raise HTTPException(
            status_code=400,
            detail="mode must be one of: off, library, reflective",
        )

    ve.cfg.mode = mode

    if mode == "library":
        note = "Re-upload avatar to re-bake library clips."
    else:
        note = ""
    return {"status": "ok", "mode": mode, "note": note}
def _parse_lora_specs(raw: object) -> "list[LoRASpec]":
    """Validate the ``loras`` request payload and build ``LoRASpec`` objects.

    Falsy entries and entry objects without a ``"path"`` key are skipped.
    Every spec is built with ``target="both"``.

    Raises:
        HTTPException: 400 when ``raw`` is not a list, an entry is not a JSON
            object, or an entry's ``weight`` is not a number.
    """
    if not isinstance(raw, list):
        raise HTTPException(status_code=400, detail='"loras" must be a list.')

    specs = []
    for index, entry in enumerate(raw):
        if not entry:
            continue
        if not isinstance(entry, dict):
            # Without this check a string entry containing "path" would reach
            # entry["path"] and fail with TypeError.
            raise HTTPException(status_code=400, detail=f"loras[{index}] must be an object.")
        if "path" not in entry:
            continue

        requested_target = str(entry.get("target", "both")).lower()
        if requested_target != "both":
            # Only "both" is passed to LoRASpec; log the override so it is not silent.
            log.warning(
                "LoRA %s: target %r not supported, using 'both'.",
                entry["path"],
                requested_target,
            )

        try:
            weight = float(entry.get("weight", 1.0))
        except (TypeError, ValueError):
            raise HTTPException(
                status_code=400,
                detail=f"loras[{index}].weight must be a number.",
            ) from None

        specs.append(
            LoRASpec(
                path=str(entry["path"]),
                weight=weight,
                target="both",
                name=entry.get("name"),
            )
        )
    return specs


@app.post("/api/reload-loras")
async def reload_loras(body: dict):
    """Hot-reload the LoRA stack. Body: ``{"loras": [{"path","weight","target","name"}]}``.

    If an avatar is already set, the idle clip is regenerated, since the new
    LoRAs change the base style.

    Raises:
        HTTPException: 404 if video is disabled, 400 for a malformed payload,
            500 if loading the LoRAs or regenerating the avatar fails.
    """
    ve = _require_video()
    specs = _parse_lora_specs(body.get("loras") or [])

    try:
        await asyncio.to_thread(ve.load_loras, specs)
        if ve.avatar_path:
            log.info("Regenerating idle clip after LoRA reload.")
            await asyncio.to_thread(ve.set_avatar, ve.avatar_path)
    except Exception as e:
        log.exception("reload_loras failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"status": "ok", "lora_count": len(specs), "idle_clip_url": "/api/idle-clip"}
@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
    """Run one conversation session over a WebSocket.

    Protocol handled here:
      * Binary frames carry PCM audio. They are converted with
        ``pcm_bytes_to_float32`` and fed to ``session.handle_audio_chunk``.
      * Text frames carry JSON. ``{"type": "interrupt", "last_chunk_id": ...}``
        calls ``session.interrupt``. Anything else, including invalid JSON or
        non-object JSON, is ignored.

    Right after the session starts, the server sends a ``video_mode`` message.
    Once the session has started, it is always stopped when the handler
    exits, whether on disconnect or on error.
    """
    await ws.accept()
    log.info("WebSocket client connected.")

    async def send_json(data: dict):
        await ws.send_text(json.dumps(data))

    async def send_bytes(data: bytes):
        await ws.send_bytes(data)

    session = ConversationSession(model_mgr, send_json, send_bytes)
    await session.start()
    try:
        # Tell the client whether video mode is active, so it knows whether to
        # suppress PCM playback and wait for speaking_clip messages instead.
        # This runs inside the try so an early disconnect still reaches
        # session.stop().
        ve = model_mgr.video_engine
        await send_json({
            "type": "video_mode",
            "enabled": ve is not None,
            "ready": ve.is_ready() if ve is not None else False,
            "mode": ve.cfg.mode if ve is not None else "off",
            "idle_clip_url": "/api/idle-clip" if (ve is not None and ve.get_idle_clip()) else None,
        })

        while True:
            message = await ws.receive()
            # ws.receive() returns the raw ASGI message: a client disconnect
            # arrives as a message (it does not raise), and calling receive()
            # again afterwards raises RuntimeError.
            if message["type"] == "websocket.disconnect":
                log.info("WebSocket client disconnected.")
                break

            # Some ASGI servers send both keys with one set to None, so test
            # the values rather than key presence.
            pcm_data = message.get("bytes")
            if pcm_data is not None:
                chunk = pcm_bytes_to_float32(pcm_data)
                await session.handle_audio_chunk(chunk)
                continue

            text = message.get("text")
            if text is None:
                continue
            try:
                msg = json.loads(text)
            except json.JSONDecodeError:
                continue
            if isinstance(msg, dict) and msg.get("type") == "interrupt":
                await session.interrupt(last_chunk_id=msg.get("last_chunk_id"))
    except WebSocketDisconnect:
        # Raised e.g. when a send hits a socket the client already closed.
        log.info("WebSocket client disconnected.")
    except Exception:
        log.exception("WebSocket error")
    finally:
        await session.stop()