Enhance video handling and performance optimizations

- Added environment variables to prevent CPU thread pools from busy-waiting.
- Deferred loading of video models until first use to reduce VRAM footprint.
- Implemented streaming of speaking clips for improved responsiveness.
- Introduced a queue for managing speaking clips to handle multiple requests smoothly.
- Updated video playback logic to ensure proper handling of clip generation.
This commit is contained in:
2026-04-24 00:36:18 -04:00
parent 129df7d1fa
commit 44a10667c2
7 changed files with 234 additions and 69 deletions
+29 -3
View File
@@ -80,11 +80,28 @@ class LLMEngine:
f"processing {input_len - cached_len} new tokens"
)
with torch.no_grad():
outputs = self.model.generate(
# Guard: if the cache claims to have seen >= input tokens, it's
# stale (can happen after barge-in races or tokenizer mismatches).
# An invalid cache causes an empty cache_position in transformers,
# which raises IndexError inside model.generate().
if past_kv is not None:
cache_seq_len = (
past_kv.get_seq_length()
if hasattr(past_kv, "get_seq_length")
else cached_len
)
if cache_seq_len >= input_len:
log.warning(
f"KV-cache stale (cache_seq={cache_seq_len} >= input={input_len}), discarding."
)
past_kv = None
cached_len = 0
def _do_generate(pkv):
return self.model.generate(
input_ids=input_ids,
attention_mask=inputs.get("attention_mask"),
past_key_values=past_kv,
past_key_values=pkv,
max_new_tokens=max_new_tokens,
temperature=0.7,
top_p=0.9,
@@ -94,6 +111,15 @@ class LLMEngine:
use_cache=True,
)
with torch.no_grad():
try:
outputs = _do_generate(past_kv)
except IndexError:
log.warning("KV-cache caused IndexError during generate; retrying without cache.")
past_kv = None
cached_len = 0
outputs = _do_generate(None)
# Decode only the generated tokens (skip prompt)
new_ids = outputs.sequences[0][input_len:]
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
+5 -3
View File
@@ -118,11 +118,13 @@ class ModelManager:
log.info("Video engine disabled (config.video.enabled=false). Skipping load.")
return
log.info("Loading avatar video engine...")
log.info("Video engine configured (models load on first avatar upload).")
cfg = VideoConfig.from_dict(video_cfg_raw)
self.video_engine = VideoEngine(cfg)
self.video_engine.load_models()
log.info("Avatar video engine loaded (mode=%s).", cfg.mode)
# load_models() is intentionally deferred: Wan2.2 + MuseTalk consume
# ~6.5 GB VRAM at idle, which causes WDDM preemption latency on the
# Windows host even with no connected clients. Models are loaded on
# demand when set_avatar() is first called.
def create_vad(self) -> StreamingVAD:
"""Create a new StreamingVAD instance for a client session."""
+49 -22
View File
@@ -238,36 +238,63 @@ class ConversationSession:
tts_thread.join(timeout=2.0)
# Video mode: render the speaking clip now that TTS is done.
# Video mode: stream speaking clips as they're generated (one per audio segment).
if use_video and audio_buffer and not self.cancel_event.is_set():
try:
full_audio = np.concatenate(audio_buffer).astype(np.float32)
sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000)
log.info(
"Video: rendering speaking clip (audio=%ds, mode=%s)",
int(len(full_audio) / sample_rate), video_engine.cfg.mode,
"Video: rendering speaking clips (audio=%.1fs, mode=%s)",
len(full_audio) / sample_rate, video_engine.cfg.mode,
)
mp4_bytes = await asyncio.to_thread(
video_engine.generate_speaking_clip,
full_audio,
sample_rate,
response,
)
if self.cancel_event.is_set():
log.info("Video clip discarded (cancelled during render).")
else:
duration_ms = int(len(full_audio) / sample_rate * 1000)
await self.send_json({
"type": "speaking_clip",
"chunk_id": 0,
"duration_ms": duration_ms,
"text": response,
"size_bytes": len(mp4_bytes),
})
await self.send_bytes(mp4_bytes)
clip_queue: queue.Queue = queue.Queue()
def _video_worker():
try:
for clip_data in video_engine.generate_speaking_clips_streaming(
full_audio, sample_rate, response
):
if self.cancel_event.is_set():
break
clip_queue.put(clip_data)
except Exception:
log.exception("Video clip generation failed")
finally:
clip_queue.put(_SENTINEL)
video_thread = threading.Thread(target=_video_worker, daemon=True)
video_thread.start()
is_first_clip = True
while not self.cancel_event.is_set():
try:
item = await asyncio.to_thread(clip_queue.get, timeout=120.0)
except Exception:
log.warning("Timed out waiting for video clip.")
break
if item is _SENTINEL:
break
if self.cancel_event.is_set():
break
mp4_bytes, duration_ms = item
try:
await self.send_json({
"type": "speaking_clip",
"chunk_id": 0,
"duration_ms": duration_ms,
"text": response if is_first_clip else "",
"size_bytes": len(mp4_bytes),
})
await self.send_bytes(mp4_bytes)
is_first_clip = False
except Exception:
log.warning("Failed to send video clip, client disconnected.")
self.cancel_event.set()
break
except Exception:
log.exception("Video speaking-clip render failed; falling back silently.")
# Best-effort: tell the client nothing was spoken visually.
try:
await self.send_json({
"type": "response_text",
+79 -9
View File
@@ -11,6 +11,7 @@ from __future__ import annotations
import logging
import threading
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Literal
@@ -287,9 +288,12 @@ class VideoEngine:
- Library mode: also pre-generate ``library.base_clip_count``
speaking base clips.
- Reflective mode: idle loop only.
Lazily calls load_models() on first invocation so that Wan2.2's VRAM
footprint doesn't exist until video is actually used.
"""
if self._wan22 is None:
raise RuntimeError("set_avatar called before load_models()")
self.load_models()
with self._lock:
log.info("Setting avatar: %s", image_path)
@@ -383,8 +387,11 @@ class VideoEngine:
def _pick_library_frames(
self, audio_f32: np.ndarray, sample_rate: int
) -> np.ndarray:
"""Round-robin pick from the pre-baked library, clipped or looped
to roughly the audio's duration so there's no long freeze frame."""
"""Round-robin pick from the pre-baked library, clipped to the segment duration.
Does not loop frames — callers that need longer coverage should split
the audio into segments and call this once per segment.
"""
if not self.speaking_base_frames:
raise RuntimeError(
"Library mode has no pre-baked base clips. "
@@ -398,12 +405,75 @@ class VideoEngine:
target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
if target_frames <= 0:
return frames
if target_frames <= len(frames):
return frames[:target_frames]
# Loop (with a mirror tail to soften the seam) to cover longer audio.
loops = target_frames // len(frames) + 1
extended = np.concatenate([frames] * loops, axis=0)
return extended[:target_frames]
return frames[:min(target_frames, len(frames))]
def generate_speaking_clips_streaming(
self,
audio_f32: np.ndarray,
sample_rate: int,
reply_text: str,
) -> Iterator[tuple[bytes, int]]:
"""Generate one MP4 per clip-length audio segment, yielding each when ready.
Splits ``audio_f32`` into segments of ``reflective_clip_seconds`` (or
``library_base_clip_seconds`` for library mode) and generates + lip-syncs
one clip per segment. Yields ``(mp4_bytes, duration_ms)`` tuples so the
caller can stream each clip to the client as soon as it's ready rather
than waiting for the full response.
"""
if not self.is_ready():
raise RuntimeError(
"generate_speaking_clips_streaming: engine not ready "
"(avatar set? models loaded?)"
)
assert self._wan22 is not None
if len(audio_f32) == 0:
return
clip_sec = (
self.cfg.library_base_clip_seconds
if self.cfg.mode == "library"
else self.cfg.reflective_clip_seconds
)
clip_samples = int(clip_sec * sample_rate)
segments = [
audio_f32[i : i + clip_samples]
for i in range(0, len(audio_f32), clip_samples)
]
for seg_audio in segments:
if self.cfg.mode == "library":
base_frames = self._pick_library_frames(seg_audio, sample_rate)
else:
prompt = self._derive_prompt(reply_text)
log.info("Reflective prompt (clip segment): %s", prompt[:80])
base_frames = self._wan22.generate_i2v(
image_path=self.avatar_path or "",
prompt=prompt,
seconds=self.cfg.reflective_clip_seconds,
seed=None,
)
if self._musetalk is not None:
synced_frames = self._musetalk.lip_sync(
frames=base_frames,
audio=seg_audio,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
else:
synced_frames = base_frames
from server.video_models.muxer import frames_and_audio_to_mp4
mp4_bytes = frames_and_audio_to_mp4(
frames=synced_frames,
audio=seg_audio,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
duration_ms = int(len(seg_audio) / sample_rate * 1000)
yield mp4_bytes, duration_ms
def _derive_prompt(self, reply_text: str) -> str:
"""Template-based prompt builder for reflective mode.