131 lines
4.6 KiB
Python
131 lines
4.6 KiB
Python
import logging
|
|
import torch
|
|
|
|
from server.vad import StreamingVAD
|
|
from server.asr import ASREngine
|
|
from server.llm import LLMEngine
|
|
from server.tts import TTSEngine
|
|
from server.video import VideoConfig, VideoEngine
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def get_device():
    """Choose the torch device string that models should be loaded onto.

    Returns:
        "cuda:0" when torch reports CUDA as available and a one-element
        test allocation on that device succeeds; "cpu" in every other case.
    """
    cuda_usable = False
    if torch.cuda.is_available():
        try:
            # Probe with a tiny allocation: CUDA can be reported as available
            # and still raise when it is actually used.
            torch.zeros(1, device="cuda:0")
        except RuntimeError as e:
            log.warning(f"CUDA available but error occurred: {e}. Falling back to CPU.")
        else:
            cuda_usable = True

    if cuda_usable:
        log.info("Using CUDA device")
        return "cuda:0"

    log.info("Using CPU device")
    return "cpu"
|
|
|
|
|
|
class ModelManager:
    """Loads and holds all models. Initialized once at server startup.

    Every engine attribute starts as ``None`` and is filled in by the
    matching ``_load_*`` method; ``load_all`` runs those methods in order.
    """

    def __init__(self):
        # Set by _load_vad(); passed to each StreamingVAD made in create_vad().
        self.vad_model = None
        self.asr_engine: ASREngine | None = None
        self.llm_engine: LLMEngine | None = None
        self.tts_engine: TTSEngine | None = None
        # Remains None when the video section of the config is disabled.
        self.video_engine: VideoEngine | None = None

    def load_all(self):
        """Load all models sequentially. Call from the main process."""
        self._load_vad()
        self._load_asr()
        self._load_llm()
        self._load_tts()
        self._load_video()
        log.info("All models loaded successfully.")

    def _load_vad(self):
        """Fetch the Silero VAD ONNX file via ``hf_hub_download`` and wrap it."""
        log.info("Loading Silero VAD (ONNX)...")
        from huggingface_hub import hf_hub_download
        from server.vad import SileroVADOnnx

        model_path = hf_hub_download(
            repo_id="onnx-community/silero-vad", filename="onnx/model.onnx"
        )
        self.vad_model = SileroVADOnnx(model_path)
        log.info("Silero VAD loaded (ONNX, CPU).")

    def _load_asr(self):
        """Load Qwen3-ASR-0.6B onto the device from ``get_device`` and wrap it in ``ASREngine``."""
        log.info("Loading Qwen3-ASR-0.6B (transformers backend)...")
        from qwen_asr import Qwen3ASRModel

        device = get_device()
        asr_model = Qwen3ASRModel.from_pretrained(
            "Qwen/Qwen3-ASR-0.6B",
            dtype=torch.bfloat16,
            device_map=device,
            max_new_tokens=4096,
        )
        self.asr_engine = ASREngine(asr_model)
        log.info("Qwen3-ASR-0.6B loaded.")

    def _load_llm(self):
        """Create the LLM engine selected by the ``llm.backend`` config value.

        ``"lmstudio"`` builds an ``LMStudioEngine`` from the ``llm.lmstudio``
        settings. Any other value (the default is ``"local"``) loads a
        transformers causal LM onto the device from ``get_device`` and wraps
        it in ``LLMEngine``.
        """
        from server.config import config

        # ``or {}`` tolerates a section that is present but null, the same way
        # _load_video reads its own section.
        llm_config = config.get("llm", {}) or {}
        backend = llm_config.get("backend", "local")
        system_prompt = llm_config.get("system_prompt", "You are a helpful assistant.")

        if backend == "lmstudio":
            from server.llm import LMStudioEngine

            lms = llm_config.get("lmstudio", {}) or {}
            url = lms.get("url", "http://host.docker.internal:1234")
            # A null or missing model name becomes "" and is logged as the server default.
            model = lms.get("model", "") or ""
            log.info(
                "Using LM Studio backend at %s (model=%s)",
                url,
                model or "server default",
            )
            self.llm_engine = LMStudioEngine(url, model, system_prompt)
            return

        # Alternative checkpoint kept for reference: "Qwen/Qwen3.5-0.8B"
        model_name = "dphn/Dolphin-X1-8B-FP8"
        # Log the checkpoint that is really loaded; the old messages named a
        # different model and were misleading.
        log.info("Loading %s...", model_name)
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        device = get_device()
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
        )
        max_cache_tokens = llm_config.get("max_cache_tokens", 4096)
        self.llm_engine = LLMEngine(model, tokenizer, system_prompt, max_cache_tokens)
        log.info("%s loaded.", model_name)

    def _load_tts(self):
        """Instantiate the TTS engine with its default constructor."""
        log.info("Loading Kokoro TTS...")
        self.tts_engine = TTSEngine()
        log.info("Kokoro TTS loaded.")

    def _load_video(self):
        """Load the avatar video stack iff config.video.enabled is true.

        Leaves ``video_engine`` as None when disabled so existing voice flow
        is untouched. Later phases replace this stub with actual Wan2.2 +
        MuseTalk loading inside ``VideoEngine``.
        """
        from server.config import config

        video_cfg_raw = config.get("video", {}) or {}
        if not video_cfg_raw.get("enabled", False):
            log.info("Video engine disabled (config.video.enabled=false). Skipping load.")
            return

        log.info("Loading avatar video engine...")
        cfg = VideoConfig.from_dict(video_cfg_raw)
        self.video_engine = VideoEngine(cfg)
        # LoRAs are optional; only call into the engine when some are configured.
        if cfg.loras:
            self.video_engine.load_loras(cfg.loras)
        log.info("Avatar video engine loaded (mode=%s).", cfg.mode)

    def create_vad(self) -> StreamingVAD:
        """Create a new StreamingVAD instance for a client session.

        Passes the shared ``vad_model``, which is None until ``_load_vad``
        has run.
        """
        return StreamingVAD(self.vad_model)
|