Files
live-voice-chat/server/models.py
T
2026-04-08 10:17:20 -04:00

104 lines
3.4 KiB
Python

import logging
import torch
from server.vad import StreamingVAD
from server.asr import ASREngine
from server.llm import LLMEngine
from server.tts import TTSEngine
# Module-level logger used by get_device() and ModelManager below.
log = logging.getLogger(__name__)
def get_device():
    """Pick the torch device string that models should be placed on.

    Returns ``"cuda:0"`` only when PyTorch reports CUDA as available AND a
    one-element test tensor can actually be allocated there. In every other
    case it returns ``"cpu"``. A RuntimeError from the test allocation is
    logged as a warning instead of being raised.

    Returns:
        ``"cuda:0"`` or ``"cpu"``.
    """
    if not torch.cuda.is_available():
        log.info("Using CPU device")
        return "cpu"

    try:
        # Smoke-test the device: is_available() can be True while the first
        # real allocation still fails (hence the fallback below).
        torch.zeros(1, device="cuda:0")
    except RuntimeError as err:
        log.warning(f"CUDA available but error occurred: {err}. Falling back to CPU.")
        log.info("Using CPU device")
        return "cpu"

    log.info("Using CUDA device")
    return "cuda:0"
class ModelManager:
    """Loads and holds all models. Initialized once at server startup.

    Call :meth:`load_all` (from the main process) before using the engines
    or calling :meth:`create_vad`. Until then every model attribute is None.
    """

    # Hugging Face checkpoint loaded when config["llm"]["backend"] is not
    # "lmstudio" (the default backend is "local").
    LOCAL_LLM_MODEL = "Qwen/Qwen3.5-0.8B"

    def __init__(self):
        # Set by _load_vad() to a SileroVADOnnx instance.
        self.vad_model = None
        self.asr_engine: ASREngine | None = None
        self.llm_engine: LLMEngine | None = None
        self.tts_engine: TTSEngine | None = None

    def load_all(self):
        """Load all models sequentially. Call from the main process."""
        self._load_vad()
        self._load_asr()
        self._load_llm()
        self._load_tts()
        log.info("All models loaded successfully.")

    def _load_vad(self):
        """Fetch the Silero VAD ONNX file from the Hugging Face Hub and wrap it."""
        log.info("Loading Silero VAD (ONNX)...")
        from huggingface_hub import hf_hub_download
        from server.vad import SileroVADOnnx

        model_path = hf_hub_download(
            repo_id="onnx-community/silero-vad", filename="onnx/model.onnx"
        )
        self.vad_model = SileroVADOnnx(model_path)
        log.info("Silero VAD loaded (ONNX, CPU).")

    def _load_asr(self):
        """Load Qwen3-ASR-0.6B in bfloat16 on the device chosen by get_device()."""
        log.info("Loading Qwen3-ASR-0.6B (transformers backend)...")
        from qwen_asr import Qwen3ASRModel

        device = get_device()
        asr_model = Qwen3ASRModel.from_pretrained(
            "Qwen/Qwen3-ASR-0.6B",
            dtype=torch.bfloat16,
            device_map=device,
            max_new_tokens=4096,
        )
        self.asr_engine = ASREngine(asr_model)
        log.info("Qwen3-ASR-0.6B loaded.")

    def _load_llm(self):
        """Create the LLM engine for the backend named in ``config["llm"]["backend"]``.

        ``"lmstudio"`` builds an :class:`LMStudioEngine` pointed at an LM Studio
        server. Any other value (default ``"local"``) loads
        :attr:`LOCAL_LLM_MODEL` with transformers and wraps it in
        :class:`LLMEngine`.
        """
        from server.config import config

        # `or {}` also covers a section that is present but null (e.g. an
        # empty YAML mapping), where `.get(key, {})` would return None.
        llm_cfg = config.get("llm") or {}
        backend = llm_cfg.get("backend", "local")

        if backend == "lmstudio":
            from server.llm import LMStudioEngine

            lms = llm_cfg.get("lmstudio") or {}
            url = lms.get("url", "http://host.docker.internal:1234")
            # `or ""` normalizes a null model to "", which is logged as
            # "server default" (presumably LMStudioEngine then lets the
            # server choose; confirm in server.llm).
            model = lms.get("model", "") or ""
            log.info(
                "Using LM Studio backend at %s (model=%s)",
                url,
                model or "server default",
            )
            self.llm_engine = LMStudioEngine(url, model)
            return

        model_name = self.LOCAL_LLM_MODEL
        # Report the checkpoint that is actually loaded.
        log.info("Loading local LLM %s...", model_name)
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        device = get_device()
        # NOTE(review): no dtype is passed, so weights load in transformers'
        # default dtype. Consider matching the ASR's bfloat16 if VRAM is tight.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
        )
        self.llm_engine = LLMEngine(model, tokenizer)
        log.info("Local LLM %s loaded.", model_name)

    def _load_tts(self):
        """Construct the Kokoro TTS engine."""
        log.info("Loading Kokoro TTS...")
        self.tts_engine = TTSEngine()
        log.info("Kokoro TTS loaded.")

    def create_vad(self) -> StreamingVAD:
        """Create a new StreamingVAD instance for a client session.

        All sessions share the single loaded ``vad_model``. Call this after
        :meth:`load_all`; before that, ``vad_model`` is None and is passed
        through to StreamingVAD unchanged.
        """
        return StreamingVAD(self.vad_model)