Files
live-voice-chat/server/models.py
T
2026-04-12 04:11:52 -04:00

131 lines
4.6 KiB
Python

import logging
import torch
from server.vad import StreamingVAD
from server.asr import ASREngine
from server.llm import LLMEngine
from server.tts import TTSEngine
from server.video import VideoConfig, VideoEngine
# Module-level logger, named after this module, shared by every loader below.
log = logging.getLogger(__name__)
def get_device() -> str:
    """Get the best available device (CUDA if available and working, otherwise CPU).

    CUDA is selected only when ``torch.cuda.is_available()`` is true *and* a
    one-element test tensor can actually be allocated on ``cuda:0``. If that
    allocation raises ``RuntimeError`` the error is logged as a warning and
    the function falls back to the CPU.

    Returns:
        ``"cuda:0"`` when the CUDA probe succeeds, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        try:
            # Probe with a tiny allocation: availability alone does not prove
            # that the device can actually be used.
            torch.zeros(1, device="cuda:0")
        except RuntimeError as e:
            log.warning(
                "CUDA available but error occurred: %s. Falling back to CPU.", e
            )
        else:
            log.info("Using CUDA device")
            return "cuda:0"

    log.info("Using CPU device")
    return "cpu"
class ModelManager:
    """Loads and holds all models. Initialized once at server startup.

    Every model slot starts as ``None`` and is filled in by the matching
    ``_load_*`` method; :meth:`load_all` runs all of them in order. Heavy
    third-party packages are imported inside the loaders, so importing this
    class does not pull them in.
    """

    def __init__(self):
        """Create an empty manager; nothing is loaded until :meth:`load_all`."""
        self.vad_model = None
        self.asr_engine: ASREngine | None = None
        self.llm_engine: LLMEngine | None = None
        self.tts_engine: TTSEngine | None = None
        self.video_engine: VideoEngine | None = None

    def load_all(self):
        """Load all models sequentially. Call from the main process."""
        self._load_vad()
        self._load_asr()
        self._load_llm()
        self._load_tts()
        self._load_video()
        log.info("All models loaded successfully.")

    def _load_vad(self):
        """Download the Silero VAD ONNX file and wrap it in ``SileroVADOnnx``."""
        log.info("Loading Silero VAD (ONNX)...")
        from huggingface_hub import hf_hub_download
        from server.vad import SileroVADOnnx

        model_path = hf_hub_download(
            repo_id="onnx-community/silero-vad", filename="onnx/model.onnx"
        )
        self.vad_model = SileroVADOnnx(model_path)
        log.info("Silero VAD loaded (ONNX, CPU).")

    def _load_asr(self):
        """Load Qwen3-ASR-0.6B in bfloat16 on the device from ``get_device()``."""
        log.info("Loading Qwen3-ASR-0.6B (transformers backend)...")
        from qwen_asr import Qwen3ASRModel

        device = get_device()
        asr_model = Qwen3ASRModel.from_pretrained(
            "Qwen/Qwen3-ASR-0.6B",
            dtype=torch.bfloat16,
            device_map=device,
            max_new_tokens=4096,
        )
        self.asr_engine = ASREngine(asr_model)
        log.info("Qwen3-ASR-0.6B loaded.")

    def _load_llm(self):
        """Create the LLM engine selected by ``config["llm"]["backend"]``.

        ``"lmstudio"`` builds an ``LMStudioEngine`` from the ``lmstudio``
        sub-section (``url``, ``model``). Any other value, including the
        default ``"local"``, loads a model through ``transformers`` and wraps
        it in ``LLMEngine`` together with ``max_cache_tokens``.
        """
        from server.config import config

        # ``or {}`` tolerates a section that is present but null, the same
        # way _load_video treats its own section.
        llm_config = config.get("llm", {}) or {}
        backend = llm_config.get("backend", "local")
        system_prompt = llm_config.get("system_prompt", "You are a helpful assistant.")

        if backend == "lmstudio":
            from server.llm import LMStudioEngine

            lms = llm_config.get("lmstudio", {}) or {}
            url = lms.get("url", "http://host.docker.internal:1234")
            # An empty model name is passed through; it is logged as
            # "server default".
            model = lms.get("model", "") or ""
            log.info(
                "Using LM Studio backend at %s (model=%s)",
                url,
                model or "server default",
            )
            self.llm_engine = LMStudioEngine(url, model, system_prompt)
            return

        model_name = "dphn/Dolphin-X1-8B-FP8"
        # Log the model that is really being loaded so the startup log
        # cannot drift from the code when the checkpoint is swapped.
        log.info("Loading local LLM %s...", model_name)
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        device = get_device()
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
        )
        max_cache_tokens = llm_config.get("max_cache_tokens", 4096)
        self.llm_engine = LLMEngine(model, tokenizer, system_prompt, max_cache_tokens)
        log.info("Local LLM %s loaded on %s.", model_name, device)

    def _load_tts(self):
        """Instantiate the TTS engine (Kokoro, per the log messages)."""
        log.info("Loading Kokoro TTS...")
        self.tts_engine = TTSEngine()
        log.info("Kokoro TTS loaded.")

    def _load_video(self):
        """Load the avatar video stack iff config.video.enabled is true.

        Leaves ``video_engine`` as None when disabled so existing voice flow
        is untouched. Later phases replace this stub with actual Wan2.2 +
        MuseTalk loading inside ``VideoEngine``.
        """
        from server.config import config

        video_cfg_raw = config.get("video", {}) or {}
        if not video_cfg_raw.get("enabled", False):
            log.info("Video engine disabled (config.video.enabled=false). Skipping load.")
            return

        log.info("Loading avatar video engine...")
        cfg = VideoConfig.from_dict(video_cfg_raw)
        self.video_engine = VideoEngine(cfg)
        # LoRAs are only loaded when the parsed config lists any.
        if cfg.loras:
            self.video_engine.load_loras(cfg.loras)
        log.info("Avatar video engine loaded (mode=%s).", cfg.mode)

    def create_vad(self) -> "StreamingVAD":
        """Create a new StreamingVAD instance for a client session.

        ``self.vad_model`` is passed through as-is, so it is still ``None``
        if ``_load_vad`` has not run yet.
        """
        return StreamingVAD(self.vad_model)