initial commit
This commit is contained in:
@@ -0,0 +1,70 @@
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from server.vad import StreamingVAD
|
||||
from server.asr import ASREngine
|
||||
from server.llm import LLMEngine
|
||||
from server.tts import TTSEngine
|
||||
|
||||
# Module-level logger named after this module; ModelManager uses it to
# report model-loading progress.
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelManager:
    """Loads and holds all models. Initialized once at server startup.

    Every attribute starts as ``None`` and is filled in by :meth:`load_all`
    (or by the individual ``_load_*`` helpers it calls).

    Attributes:
        vad_model: Object returned by ``silero_vad.load_silero_vad()``. It is
            passed to every ``StreamingVAD`` built by :meth:`create_vad`.
        asr_engine: ``ASREngine`` wrapping the Qwen3-ASR model.
        llm_engine: ``LLMEngine`` wrapping the causal LM and its tokenizer.
        tts_engine: ``TTSEngine`` instance.
    """

    def __init__(self):
        """Create an empty manager; nothing is loaded until ``load_all()``."""
        self.vad_model = None
        self.asr_engine: ASREngine | None = None
        self.llm_engine: LLMEngine | None = None
        self.tts_engine: TTSEngine | None = None

    def load_all(self):
        """Load all models sequentially. Call from the main process.

        The order is VAD, ASR, LLM, TTS. Nothing is caught here, so an
        exception raised by any loader propagates to the caller and the
        remaining loaders are not run.
        """
        self._load_vad()
        self._load_asr()
        self._load_llm()
        self._load_tts()
        log.info("All models loaded successfully.")

    def _load_vad(self):
        """Load the Silero VAD model and store it on ``self.vad_model``."""
        log.info("Loading Silero VAD...")
        # Imported inside the method so the dependency is only required
        # when this loader actually runs.
        from silero_vad import load_silero_vad

        self.vad_model = load_silero_vad()
        log.info("Silero VAD loaded (CPU).")

    def _load_asr(self):
        """Load Qwen3-ASR-0.6B and wrap it in an ``ASREngine``.

        The model is requested in bfloat16 on ``cuda:0`` with
        ``max_new_tokens=4096``.
        """
        log.info("Loading Qwen3-ASR-0.6B (transformers backend)...")
        # Imported inside the method so the dependency is only required
        # when this loader actually runs.
        from qwen_asr import Qwen3ASRModel

        asr_model = Qwen3ASRModel.from_pretrained(
            "Qwen/Qwen3-ASR-0.6B",
            dtype=torch.bfloat16,
            device_map="cuda:0",
            max_new_tokens=4096,
        )
        self.asr_engine = ASREngine(asr_model)
        log.info("Qwen3-ASR-0.6B loaded.")

    def _load_llm(self):
        """Load the Qwen3-0.6B causal LM plus tokenizer into an ``LLMEngine``.

        The model is requested in bfloat16 on ``cuda:0``.
        """
        # NOTE(review): the log label says "-Instruct" while the checkpoint
        # id below is "Qwen/Qwen3-0.6B" -- confirm which name is intended.
        log.info("Loading Qwen3-0.6B-Instruct...")
        # Imported inside the method so the dependency is only required
        # when this loader actually runs.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_name = "Qwen/Qwen3-0.6B"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(review): the ASR loader passes `dtype=` while this call passes
        # `torch_dtype=`; confirm which keyword the installed transformers
        # version expects.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="cuda:0",
        )
        self.llm_engine = LLMEngine(model, tokenizer)
        log.info("Qwen3-0.6B-Instruct loaded.")

    def _load_tts(self):
        """Instantiate the ``TTSEngine`` and store it on ``self.tts_engine``."""
        log.info("Loading Kokoro TTS...")
        self.tts_engine = TTSEngine()
        log.info("Kokoro TTS loaded.")

    def create_vad(self) -> "StreamingVAD":
        """Create a new StreamingVAD instance for a client session.

        Returns:
            A new ``StreamingVAD`` constructed from this manager's
            ``vad_model``.

        Raises:
            RuntimeError: If the VAD model has not been loaded yet, i.e.
                ``vad_model`` is still the ``None`` placeholder set in
                ``__init__``.
        """
        # Fail fast with a clear message instead of handing None to
        # StreamingVAD and failing somewhere less obvious later.
        if self.vad_model is None:
            raise RuntimeError(
                "VAD model is not loaded; call load_all() before create_vad()."
            )
        return StreamingVAD(self.vad_model)
|
||||
Reference in New Issue
Block a user