209 lines
6.8 KiB
Python
209 lines
6.8 KiB
Python
import asyncio
import json
import logging
import math
import os
import shutil
import tempfile
from contextlib import asynccontextmanager

import numpy as np
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.params import Form
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles

from server.audio_utils import pcm_bytes_to_float32
from server.models import ModelManager
from server.pipeline import ConversationSession
from server.video import LoRASpec
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)

# Project root: the parent of the directory containing this module. All data
# directories below are resolved relative to it.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))

REFERENCE_DIR = os.path.join(_PROJECT_ROOT, "reference_audio")
STATIC_DIR = os.path.join(_PROJECT_ROOT, "static")
AVATAR_DIR = os.path.join(_PROJECT_ROOT, "avatars")

# Uploaded avatars are written into AVATAR_DIR, so create it at import time.
os.makedirs(AVATAR_DIR, exist_ok=True)
|
|
|
|
# Single shared ModelManager used by every HTTP endpoint and WebSocket session
# in this module; ``lifespan`` calls ``load_all()`` on it at startup.
model_mgr = ModelManager()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load models at startup, log at shutdown.

    ``model_mgr.load_all()`` is called synchronously, so startup does not
    complete (and the app does not begin serving) until it returns.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            unused here).
    """
    log.info("Starting model loading...")
    model_mgr.load_all()
    log.info("Server ready.")
    try:
        yield
    finally:
        # Runs on a normal shutdown and also when the lifespan context is
        # torn down by an exception thrown in at ``yield``.
        log.info("Shutting down.")
|
|
|
|
|
|
# FastAPI application; ``lifespan`` performs model loading at startup.
app = FastAPI(lifespan=lifespan)
|
|
# Serve the files in STATIC_DIR under the ``/static`` URL prefix.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
|
|
|
|
|
@app.get("/")
async def index():
    """Serve the web client's entry page, ``index.html`` from STATIC_DIR."""
    page_path = os.path.join(STATIC_DIR, "index.html")
    return FileResponse(page_path)
|
|
|
|
|
|
@app.post("/api/set-voice")
async def set_voice(voice: str = Form(...), lang: str = Form("a")):
    """Change the TTS voice.

    Args:
        voice: Voice identifier forwarded to ``tts_engine.set_voice``.
        lang: Language code forwarded with the voice (default ``"a"``); its
            meaning is defined by the TTS engine, which is not visible here.

    Returns:
        ``{"status": "ok", "voice": voice}`` on success.

    Raises:
        HTTPException: 500 with the engine's error message if the voice
            change fails. This matches the error handling of the avatar and
            LoRA endpoints instead of surfacing a bare "Internal Server Error".
    """
    try:
        model_mgr.tts_engine.set_voice(voice, lang)
    except Exception as e:
        log.exception("set_voice failed")
        raise HTTPException(status_code=500, detail=f"Voice change failed: {e}") from e
    return {"status": "ok", "voice": voice}
|
|
|
|
|
|
# --- Video / avatar endpoints ---------------------------------------------
|
|
|
|
def _require_video() -> "object":
    """Return ``model_mgr.video_engine``, failing the request if it is absent.

    Raises:
        HTTPException: 404 when the video engine is ``None`` (video mode is
            not enabled in the configuration).
    """
    engine = model_mgr.video_engine
    if engine is not None:
        return engine
    raise HTTPException(
        status_code=404,
        detail="Video engine disabled. Set config.video.enabled=true and restart.",
    )
|
|
|
|
|
|
# Longest extension (without the dot) accepted from an uploaded file name.
_AVATAR_SUFFIX_MAX_LEN = 10


def _avatar_suffix(filename: "str | None") -> str:
    """Return a safe file extension (with dot) for an uploaded avatar.

    The extension comes from the client-supplied *filename*, which is
    untrusted. Anything other than a short ASCII alphanumeric extension
    (missing, a lone ".", punctuation, control characters such as NUL, or an
    overly long one) falls back to ``".png"``.
    """
    suffix = os.path.splitext(filename or "avatar.png")[1]
    ext = suffix[1:]
    if ext and len(ext) <= _AVATAR_SUFFIX_MAX_LEN and ext.isascii() and ext.isalnum():
        return suffix
    return ".png"


def _write_upload(src, dest: str) -> None:
    """Copy the readable binary file object *src* to the path *dest* in chunks."""
    with open(dest, "wb") as f:
        shutil.copyfileobj(src, f)


@app.post("/api/set-avatar")
async def set_avatar(image: UploadFile):
    """Upload an avatar image and (re)generate cached clips.

    The image is saved as ``AVATAR_DIR/avatar<suffix>``, overwriting any
    previous file with the same suffix. It is then handed to
    ``ve.set_avatar`` in a worker thread.

    Returns:
        A dict with ``status``, the saved ``avatar_path``, the
        ``idle_clip_url`` and the engine's current ``mode``.

    Raises:
        HTTPException: 404 if the video engine is disabled; 500 if
            ``ve.set_avatar`` fails.
    """
    ve = _require_video()

    dest = os.path.join(AVATAR_DIR, f"avatar{_avatar_suffix(image.filename)}")

    # Stream the upload to disk in a worker thread: this avoids buffering the
    # whole file in memory and keeps blocking disk I/O off the event loop.
    await image.seek(0)
    await asyncio.to_thread(_write_upload, image.file, dest)
    log.info("Avatar saved to %s", dest)

    try:
        await asyncio.to_thread(ve.set_avatar, dest)
    except Exception as e:
        log.exception("set_avatar failed")
        raise HTTPException(status_code=500, detail=f"Avatar setup failed: {e}") from e

    return {
        "status": "ok",
        "avatar_path": dest,
        "idle_clip_url": "/api/idle-clip",
        "mode": ve.cfg.mode,
    }
|
|
|
|
|
|
@app.get("/api/idle-clip")
async def idle_clip():
    """Serve the video engine's cached idle-loop clip as ``video/mp4``.

    Raises:
        HTTPException: 404 if video is disabled, or if
            ``get_idle_clip()`` returns ``None`` (no avatar uploaded yet).
    """
    clip = _require_video().get_idle_clip()
    if clip is not None:
        return Response(content=clip, media_type="video/mp4")
    raise HTTPException(status_code=404, detail="No idle clip. Upload an avatar first.")
|
|
|
|
|
|
@app.post("/api/set-video-mode")
async def set_video_mode(mode: str = Form(...)):
    """Set the video engine's mode to 'off', 'library', or 'reflective'.

    Only ``ve.cfg.mode`` is changed; the engine itself stays loaded. How
    'off' alters later turns (e.g. falling back to PCM streaming) is decided
    by code that reads ``cfg.mode`` elsewhere. TODO confirm in the pipeline.

    Returns:
        ``{"status": "ok", "mode": mode, "note": ...}``. The note asks for an
        avatar re-upload when switching to 'library' and is empty otherwise.

    Raises:
        HTTPException: 404 if video is disabled; 400 for an unknown mode.
    """
    engine = _require_video()

    valid_modes = ("off", "library", "reflective")
    if mode not in valid_modes:
        raise HTTPException(
            status_code=400,
            detail="mode must be one of: off, library, reflective",
        )

    # Library vs. reflective changes how set_avatar prebakes clips, so a fresh
    # avatar upload is needed afterwards to re-bake.
    engine.cfg.mode = mode

    if mode == "library":
        note = "Re-upload avatar to re-bake library clips."
    else:
        note = ""
    return {"status": "ok", "mode": mode, "note": note}
|
|
|
|
|
|
def _parse_lora_specs(raw: object) -> "list[LoRASpec]":
    """Convert the request's ``loras`` array into ``LoRASpec`` objects.

    Parsing is best-effort in the same way as before: falsy entries and
    objects without a ``"path"`` key are skipped, and an unknown ``target``
    falls back to ``"both"``. Structurally invalid input is rejected with
    400 instead of escaping as an unhandled 500.

    Raises:
        HTTPException: 400 if *raw* is not a list, if an entry is not a JSON
            object, or if a ``weight`` is not a finite number.
    """
    if not isinstance(raw, list):
        raise HTTPException(status_code=400, detail="'loras' must be a list")

    specs: list[LoRASpec] = []
    for idx, entry in enumerate(raw):
        if not entry:
            continue
        if not isinstance(entry, dict):
            raise HTTPException(status_code=400, detail=f"loras[{idx}] must be an object")
        if "path" not in entry:
            continue

        target = str(entry.get("target", "both")).lower()
        if target not in ("high_noise", "low_noise", "both"):
            target = "both"

        try:
            weight = float(entry.get("weight", 1.0))
        except (TypeError, ValueError) as e:
            raise HTTPException(
                status_code=400, detail=f"loras[{idx}].weight must be a number"
            ) from e
        if not math.isfinite(weight):
            raise HTTPException(
                status_code=400, detail=f"loras[{idx}].weight must be finite"
            )

        specs.append(
            LoRASpec(
                # NOTE(review): ``path`` is client-controlled and is loaded
                # server-side by ``ve.load_loras``. If that loader deserializes
                # pickle-based checkpoints (e.g. torch.load), this endpoint
                # can run arbitrary code. Restrict paths or formats if the
                # server is reachable by untrusted clients.
                path=str(entry["path"]),
                weight=weight,
                target=target,  # type: ignore[arg-type]
                name=entry.get("name"),
            )
        )
    return specs


@app.post("/api/reload-loras")
async def reload_loras(body: dict):
    """Hot-reload LoRA stack. Body: ``{"loras": [{"path","weight","target","name"}]}``.

    Regenerates the idle clip if an avatar is already set, since the new
    LoRAs change the base style.

    Returns:
        ``{"status": "ok", "lora_count": <specs loaded>, "idle_clip_url": ...}``.

    Raises:
        HTTPException: 404 if video is disabled; 400 for a malformed body;
            500 if loading the LoRAs or regenerating the avatar clips fails.
    """
    ve = _require_video()

    specs = _parse_lora_specs(body.get("loras") or [])

    try:
        await asyncio.to_thread(ve.load_loras, specs)
        if ve.avatar_path:
            log.info("Regenerating idle clip after LoRA reload.")
            await asyncio.to_thread(ve.set_avatar, ve.avatar_path)
    except Exception as e:
        log.exception("reload_loras failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"status": "ok", "lora_count": len(specs), "idle_clip_url": "/api/idle-clip"}
|
|
|
|
|
|
def _video_mode_status() -> dict:
    """Build the ``video_mode`` message that tells a client about video state.

    The client uses it to decide whether to suppress PCM playback and wait
    for speaking-clip messages instead.
    """
    ve = model_mgr.video_engine
    if ve is None:
        return {
            "type": "video_mode",
            "enabled": False,
            "ready": False,
            "mode": "off",
            "idle_clip_url": None,
        }
    return {
        "type": "video_mode",
        "enabled": True,
        "ready": ve.is_ready(),
        "mode": ve.cfg.mode,
        "idle_clip_url": "/api/idle-clip" if ve.get_idle_clip() else None,
    }


@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
    """Run one conversation session over a WebSocket.

    Binary frames are decoded with ``pcm_bytes_to_float32`` and fed to the
    session as audio. Text frames are parsed as JSON; only
    ``{"type": "interrupt", "last_chunk_id": ...}`` is acted on, and
    malformed or non-object JSON is ignored. The session is always stopped
    when the loop ends.
    """
    await ws.accept()
    log.info("WebSocket client connected.")

    async def send_json(data: dict):
        await ws.send_text(json.dumps(data))

    async def send_bytes(data: bytes):
        await ws.send_bytes(data)

    session = ConversationSession(model_mgr, send_json, send_bytes)
    await session.start()

    try:
        # Inside the try so that the session is stopped even if the client
        # is already gone and this first send fails.
        await send_json(_video_mode_status())

        while True:
            message = await ws.receive()

            # ``ws.receive()`` *returns* the disconnect event rather than
            # raising. Calling it again afterwards raises RuntimeError, which
            # would be misreported as a WebSocket error.
            if message.get("type") == "websocket.disconnect":
                raise WebSocketDisconnect(code=message.get("code", 1000))

            # ASGI allows the unused payload key to be present with None.
            pcm_data = message.get("bytes")
            if pcm_data is not None:
                chunk = pcm_bytes_to_float32(pcm_data)
                await session.handle_audio_chunk(chunk)
                continue

            text = message.get("text")
            if text is None:
                continue
            try:
                msg = json.loads(text)
            except json.JSONDecodeError:
                continue

            if isinstance(msg, dict) and msg.get("type") == "interrupt":
                await session.interrupt(
                    last_chunk_id=msg.get("last_chunk_id")
                )

    except WebSocketDisconnect:
        log.info("WebSocket client disconnected.")
    except Exception:
        log.exception("WebSocket error")
    finally:
        await session.stop()
|