Compare commits

...

1 Commits

Author SHA1 Message Date
bhetherman 44a10667c2 Enhance video handling and performance optimizations
- Added environment variables to prevent CPU thread pools from busy-waiting.
- Deferred loading of video models until first use to reduce VRAM footprint.
- Implemented streaming of speaking clips for improved responsiveness.
- Introduced a queue for managing speaking clips to handle multiple requests smoothly.
- Updated video playback logic to ensure proper handling of clip generation.
2026-04-24 00:36:18 -04:00
7 changed files with 234 additions and 69 deletions
+10 -1
View File
@@ -7,6 +7,12 @@ ENV HF_HOME=/cache/huggingface
# LoRA directory — users drop .safetensors files here and reference them # LoRA directory — users drop .safetensors files here and reference them
# from config.yml::video.loras. Bind-mounted via docker-compose. # from config.yml::video.loras. Bind-mounted via docker-compose.
ENV LORA_DIR=/cache/loras ENV LORA_DIR=/cache/loras
# Prevent PyTorch/OpenMP/MKL thread pools from spin-waiting when idle.
# Without this, loading large models (ASR, LLM, Wan2.2) causes all CPU cores
# to busy-loop even with no connected clients, slowing the whole system.
ENV OMP_WAIT_POLICY=PASSIVE
ENV MKL_WAIT_POLICY=PASSIVE
ENV TOKENIZERS_PARALLELISM=false
RUN apt-get update && apt-get install -y \ RUN apt-get update && apt-get install -y \
python3.11 \ python3.11 \
@@ -50,8 +56,11 @@ RUN python3.11 -m spacy download en_core_web_sm
# LightX2V (Wan2.2-Lightning inference framework) — installed from source # LightX2V (Wan2.2-Lightning inference framework) — installed from source
# since there is no stable PyPI release yet. # since there is no stable PyPI release yet.
RUN python3.11 -m pip install --no-cache-dir \ RUN python3.11 -m pip install --no-cache-dir \
"git+https://github.com/ModelTC/LightX2V.git" || \ "git+https://github.com/ModelTC/LightX2V.git@6db002f2755036b02bd0900bf9b41958bbfb4137" || \
echo "LightX2V install failed — config.video.enabled must stay false until fixed" echo "LightX2V install failed — config.video.enabled must stay false until fixed"
# ^ Pinned to 2026-04-14: last commit before WorldMirrorRunner was added to
# pipeline.py (which requires flash_attn + matplotlib) and before the
# dummy_model NameError regression in vae_2_2.py.
# #
# sgl-kernel (fp8 T5 encoder acceleration). The PyPI wheel lacks SM120 # sgl-kernel (fp8 T5 encoder acceleration). The PyPI wheel lacks SM120
# (Blackwell) CUTLASS kernels; use SGLang's cu128 wheel index instead. # (Blackwell) CUTLASS kernels; use SGLang's cu128 wheel index instead.
+10
View File
@@ -1,5 +1,15 @@
import os
import torch
import uvicorn import uvicorn
# Cap CPU thread pools so PyTorch/OpenMP don't spin-wait on every core at idle.
# Models run on GPU; the CPU thread pool is only needed for small ops.
os.environ.setdefault("OMP_WAIT_POLICY", "PASSIVE")
os.environ.setdefault("MKL_WAIT_POLICY", "PASSIVE")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
torch.set_num_threads(2)
torch.set_num_interop_threads(2)
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run( uvicorn.run(
"server.main:app", "server.main:app",
+29 -3
View File
@@ -80,11 +80,28 @@ class LLMEngine:
f"processing {input_len - cached_len} new tokens" f"processing {input_len - cached_len} new tokens"
) )
with torch.no_grad(): # Guard: if the cache claims to have seen >= input tokens, it's
outputs = self.model.generate( # stale (can happen after barge-in races or tokenizer mismatches).
# An invalid cache causes an empty cache_position in transformers,
# which raises IndexError inside model.generate().
if past_kv is not None:
cache_seq_len = (
past_kv.get_seq_length()
if hasattr(past_kv, "get_seq_length")
else cached_len
)
if cache_seq_len >= input_len:
log.warning(
f"KV-cache stale (cache_seq={cache_seq_len} >= input={input_len}), discarding."
)
past_kv = None
cached_len = 0
def _do_generate(pkv):
return self.model.generate(
input_ids=input_ids, input_ids=input_ids,
attention_mask=inputs.get("attention_mask"), attention_mask=inputs.get("attention_mask"),
past_key_values=past_kv, past_key_values=pkv,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
temperature=0.7, temperature=0.7,
top_p=0.9, top_p=0.9,
@@ -94,6 +111,15 @@ class LLMEngine:
use_cache=True, use_cache=True,
) )
with torch.no_grad():
try:
outputs = _do_generate(past_kv)
except IndexError:
log.warning("KV-cache caused IndexError during generate; retrying without cache.")
past_kv = None
cached_len = 0
outputs = _do_generate(None)
# Decode only the generated tokens (skip prompt) # Decode only the generated tokens (skip prompt)
new_ids = outputs.sequences[0][input_len:] new_ids = outputs.sequences[0][input_len:]
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip() response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
+5 -3
View File
@@ -118,11 +118,13 @@ class ModelManager:
log.info("Video engine disabled (config.video.enabled=false). Skipping load.") log.info("Video engine disabled (config.video.enabled=false). Skipping load.")
return return
log.info("Loading avatar video engine...") log.info("Video engine configured (models load on first avatar upload).")
cfg = VideoConfig.from_dict(video_cfg_raw) cfg = VideoConfig.from_dict(video_cfg_raw)
self.video_engine = VideoEngine(cfg) self.video_engine = VideoEngine(cfg)
self.video_engine.load_models() # load_models() is intentionally deferred: Wan2.2 + MuseTalk consume
log.info("Avatar video engine loaded (mode=%s).", cfg.mode) # ~6.5 GB VRAM at idle, which causes WDDM preemption latency on the
# Windows host even with no connected clients. Models are loaded on
# demand when set_avatar() is first called.
def create_vad(self) -> StreamingVAD: def create_vad(self) -> StreamingVAD:
"""Create a new StreamingVAD instance for a client session.""" """Create a new StreamingVAD instance for a client session."""
+49 -22
View File
@@ -238,36 +238,63 @@ class ConversationSession:
tts_thread.join(timeout=2.0) tts_thread.join(timeout=2.0)
# Video mode: render the speaking clip now that TTS is done. # Video mode: stream speaking clips as they're generated (one per audio segment).
if use_video and audio_buffer and not self.cancel_event.is_set(): if use_video and audio_buffer and not self.cancel_event.is_set():
try: try:
full_audio = np.concatenate(audio_buffer).astype(np.float32) full_audio = np.concatenate(audio_buffer).astype(np.float32)
sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000) sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000)
log.info( log.info(
"Video: rendering speaking clip (audio=%ds, mode=%s)", "Video: rendering speaking clips (audio=%.1fs, mode=%s)",
int(len(full_audio) / sample_rate), video_engine.cfg.mode, len(full_audio) / sample_rate, video_engine.cfg.mode,
) )
mp4_bytes = await asyncio.to_thread(
video_engine.generate_speaking_clip, clip_queue: queue.Queue = queue.Queue()
full_audio,
sample_rate, def _video_worker():
response, try:
) for clip_data in video_engine.generate_speaking_clips_streaming(
if self.cancel_event.is_set(): full_audio, sample_rate, response
log.info("Video clip discarded (cancelled during render).") ):
else: if self.cancel_event.is_set():
duration_ms = int(len(full_audio) / sample_rate * 1000) break
await self.send_json({ clip_queue.put(clip_data)
"type": "speaking_clip", except Exception:
"chunk_id": 0, log.exception("Video clip generation failed")
"duration_ms": duration_ms, finally:
"text": response, clip_queue.put(_SENTINEL)
"size_bytes": len(mp4_bytes),
}) video_thread = threading.Thread(target=_video_worker, daemon=True)
await self.send_bytes(mp4_bytes) video_thread.start()
is_first_clip = True
while not self.cancel_event.is_set():
try:
item = await asyncio.to_thread(clip_queue.get, timeout=120.0)
except Exception:
log.warning("Timed out waiting for video clip.")
break
if item is _SENTINEL:
break
if self.cancel_event.is_set():
break
mp4_bytes, duration_ms = item
try:
await self.send_json({
"type": "speaking_clip",
"chunk_id": 0,
"duration_ms": duration_ms,
"text": response if is_first_clip else "",
"size_bytes": len(mp4_bytes),
})
await self.send_bytes(mp4_bytes)
is_first_clip = False
except Exception:
log.warning("Failed to send video clip, client disconnected.")
self.cancel_event.set()
break
except Exception: except Exception:
log.exception("Video speaking-clip render failed; falling back silently.") log.exception("Video speaking-clip render failed; falling back silently.")
# Best-effort: tell the client nothing was spoken visually.
try: try:
await self.send_json({ await self.send_json({
"type": "response_text", "type": "response_text",
+79 -9
View File
@@ -11,6 +11,7 @@ from __future__ import annotations
import logging import logging
import threading import threading
from collections.abc import Iterator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal from typing import Literal
@@ -287,9 +288,12 @@ class VideoEngine:
- Library mode: also pre-generate ``library.base_clip_count`` - Library mode: also pre-generate ``library.base_clip_count``
speaking base clips. speaking base clips.
- Reflective mode: idle loop only. - Reflective mode: idle loop only.
Lazily calls load_models() on first invocation so that Wan2.2's VRAM
footprint doesn't exist until video is actually used.
""" """
if self._wan22 is None: if self._wan22 is None:
raise RuntimeError("set_avatar called before load_models()") self.load_models()
with self._lock: with self._lock:
log.info("Setting avatar: %s", image_path) log.info("Setting avatar: %s", image_path)
@@ -383,8 +387,11 @@ class VideoEngine:
def _pick_library_frames( def _pick_library_frames(
self, audio_f32: np.ndarray, sample_rate: int self, audio_f32: np.ndarray, sample_rate: int
) -> np.ndarray: ) -> np.ndarray:
"""Round-robin pick from the pre-baked library, clipped or looped """Round-robin pick from the pre-baked library, clipped to the segment duration.
to roughly the audio's duration so there's no long freeze frame."""
Does not loop frames — callers that need longer coverage should split
the audio into segments and call this once per segment.
"""
if not self.speaking_base_frames: if not self.speaking_base_frames:
raise RuntimeError( raise RuntimeError(
"Library mode has no pre-baked base clips. " "Library mode has no pre-baked base clips. "
@@ -398,12 +405,75 @@ class VideoEngine:
target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps)) target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
if target_frames <= 0: if target_frames <= 0:
return frames return frames
if target_frames <= len(frames): return frames[:min(target_frames, len(frames))]
return frames[:target_frames]
# Loop (with a mirror tail to soften the seam) to cover longer audio. def generate_speaking_clips_streaming(
loops = target_frames // len(frames) + 1 self,
extended = np.concatenate([frames] * loops, axis=0) audio_f32: np.ndarray,
return extended[:target_frames] sample_rate: int,
reply_text: str,
) -> Iterator[tuple[bytes, int]]:
"""Generate one MP4 per clip-length audio segment, yielding each when ready.
Splits ``audio_f32`` into segments of ``reflective_clip_seconds`` (or
``library_base_clip_seconds`` for library mode) and generates + lip-syncs
one clip per segment. Yields ``(mp4_bytes, duration_ms)`` tuples so the
caller can stream each clip to the client as soon as it's ready rather
than waiting for the full response.
"""
if not self.is_ready():
raise RuntimeError(
"generate_speaking_clips_streaming: engine not ready "
"(avatar set? models loaded?)"
)
assert self._wan22 is not None
if len(audio_f32) == 0:
return
clip_sec = (
self.cfg.library_base_clip_seconds
if self.cfg.mode == "library"
else self.cfg.reflective_clip_seconds
)
clip_samples = int(clip_sec * sample_rate)
segments = [
audio_f32[i : i + clip_samples]
for i in range(0, len(audio_f32), clip_samples)
]
for seg_audio in segments:
if self.cfg.mode == "library":
base_frames = self._pick_library_frames(seg_audio, sample_rate)
else:
prompt = self._derive_prompt(reply_text)
log.info("Reflective prompt (clip segment): %s", prompt[:80])
base_frames = self._wan22.generate_i2v(
image_path=self.avatar_path or "",
prompt=prompt,
seconds=self.cfg.reflective_clip_seconds,
seed=None,
)
if self._musetalk is not None:
synced_frames = self._musetalk.lip_sync(
frames=base_frames,
audio=seg_audio,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
else:
synced_frames = base_frames
from server.video_models.muxer import frames_and_audio_to_mp4
mp4_bytes = frames_and_audio_to_mp4(
frames=synced_frames,
audio=seg_audio,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
duration_ms = int(len(seg_audio) / sample_rate * 1000)
yield mp4_bytes, duration_ms
def _derive_prompt(self, reply_text: str) -> str: def _derive_prompt(self, reply_text: str) -> str:
"""Template-based prompt builder for reflective mode. """Template-based prompt builder for reflective mode.
+52 -31
View File
@@ -24,6 +24,8 @@ let videoModeName = "off"; // "off" | "library" | "reflective"
let idleClipUrl = null; // URL string (server-served) or null let idleClipUrl = null; // URL string (server-served) or null
let pendingSpeakingClipMeta = null; // {chunk_id, duration_ms, text} waiting for MP4 binary let pendingSpeakingClipMeta = null; // {chunk_id, duration_ms, text} waiting for MP4 binary
let currentSpeakingClipBlobUrl = null; let currentSpeakingClipBlobUrl = null;
let speakingClipQueue = []; // [{blobUrl, meta}] clips waiting to play
let currentClipGeneration = 0; // incremented each clip start; guards stale onended handlers
const chatArea = document.getElementById("chat-area"); const chatArea = document.getElementById("chat-area");
const statusBadge = document.getElementById("status-badge"); const statusBadge = document.getElementById("status-badge");
@@ -131,67 +133,86 @@ function refreshStage() {
if (videoModeEnabled && idleClipUrl) { if (videoModeEnabled && idleClipUrl) {
stageEl.classList.add("active"); stageEl.classList.add("active");
if (avatarVideo.src !== location.origin + idleClipUrl) { if (avatarVideo.src !== location.origin + idleClipUrl) {
avatarVideo.src = idleClipUrl; _returnToIdle();
avatarVideo.loop = true;
avatarVideo.muted = true;
avatarVideo.play().catch(() => {});
} }
} else { } else {
stageEl.classList.remove("active"); stageEl.classList.remove("active");
} }
} }
function _returnToIdle() {
if (!idleClipUrl) return;
avatarVideo.onended = null;
avatarVideo.loop = false;
avatarVideo.muted = true;
avatarVideo.src = idleClipUrl;
avatarVideo.play().catch(() => {});
}
function playSpeakingClip(arrayBuffer, meta) { function playSpeakingClip(arrayBuffer, meta) {
// Replace the idle loop with the speaking clip.
stopSpeakingClip();
const blob = new Blob([arrayBuffer], { type: "video/mp4" }); const blob = new Blob([arrayBuffer], { type: "video/mp4" });
currentSpeakingClipBlobUrl = URL.createObjectURL(blob); const blobUrl = URL.createObjectURL(blob);
if (currentSpeakingClipBlobUrl !== null) {
// A clip is already playing — queue this one.
speakingClipQueue.push({ blobUrl, meta });
} else {
_startSpeakingClip(blobUrl, meta);
}
}
function _startSpeakingClip(blobUrl, meta) {
const gen = ++currentClipGeneration;
if (currentSpeakingClipBlobUrl) {
URL.revokeObjectURL(currentSpeakingClipBlobUrl);
}
currentSpeakingClipBlobUrl = blobUrl;
avatarVideo.loop = false; avatarVideo.loop = false;
avatarVideo.muted = false; avatarVideo.muted = false;
avatarVideo.src = currentSpeakingClipBlobUrl; avatarVideo.src = blobUrl;
// Show the full reply text now — the MP4 plays it in one shot so there's
// no per-chunk sync to do.
if (meta && meta.text) { if (meta && meta.text) {
appendAssistantText(meta.text); appendAssistantText(meta.text);
} }
isPlaying = true; isPlaying = true;
avatarVideo.onended = () => { avatarVideo.onended = () => {
isPlaying = false; if (currentClipGeneration !== gen) return; // stale handler from a replaced clip
finalizeAssistantMessage(false); URL.revokeObjectURL(currentSpeakingClipBlobUrl);
// Return to idle loop. currentSpeakingClipBlobUrl = null;
if (idleClipUrl) {
avatarVideo.loop = true; const next = speakingClipQueue.shift();
avatarVideo.muted = true; if (next) {
avatarVideo.src = idleClipUrl; _startSpeakingClip(next.blobUrl, next.meta);
avatarVideo.play().catch(() => {}); } else {
} isPlaying = false;
if (currentSpeakingClipBlobUrl) { finalizeAssistantMessage(false);
URL.revokeObjectURL(currentSpeakingClipBlobUrl); _returnToIdle();
currentSpeakingClipBlobUrl = null;
} }
}; };
avatarVideo.play().catch((e) => { avatarVideo.play().catch((e) => {
console.error("speaking clip play failed:", e); console.error("speaking clip play failed:", e);
}); });
} }
function stopSpeakingClip() { function stopSpeakingClip() {
// Discard any queued clips.
for (const { blobUrl } of speakingClipQueue) {
URL.revokeObjectURL(blobUrl);
}
speakingClipQueue = [];
currentClipGeneration++; // invalidate any in-flight onended handlers
if (!currentSpeakingClipBlobUrl) return; if (!currentSpeakingClipBlobUrl) return;
try { try { avatarVideo.pause(); } catch (_) {}
avatarVideo.pause(); avatarVideo.onended = null;
} catch (_) {}
URL.revokeObjectURL(currentSpeakingClipBlobUrl); URL.revokeObjectURL(currentSpeakingClipBlobUrl);
currentSpeakingClipBlobUrl = null; currentSpeakingClipBlobUrl = null;
if (idleClipUrl) {
avatarVideo.loop = true;
avatarVideo.muted = true;
avatarVideo.src = idleClipUrl;
avatarVideo.play().catch(() => {});
}
isPlaying = false; isPlaying = false;
_returnToIdle();
} }
async function uploadAvatar() { async function uploadAvatar() {