updates for docker take 1
This commit is contained in:
+22
-6
@@ -9,6 +9,20 @@ from server.tts import TTSEngine
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_device():
    """Choose the torch device string that models should be loaded onto.

    Returns:
        ``"cuda:0"`` when ``torch.cuda.is_available()`` is true and a
        one-element tensor can be allocated on that device without raising
        ``RuntimeError``; ``"cpu"`` in every other case.

    A ``RuntimeError`` raised while probing CUDA is logged as a warning and
    is not propagated; the function then falls back to the CPU.
    """
    selected = "cpu"

    if torch.cuda.is_available():
        try:
            # Probe with a tiny allocation to confirm the device is usable,
            # not merely reported as available.
            torch.zeros(1, device="cuda:0")
            log.info("Using CUDA device")
            selected = "cuda:0"
        except RuntimeError as e:
            log.warning(f"CUDA available but error occurred: {e}. Falling back to CPU.")

    if selected == "cpu":
        log.info("Using CPU device")
    return selected
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""Loads and holds all models. Initialized once at server startup."""
|
||||
|
||||
@@ -37,28 +51,30 @@ class ModelManager:
|
||||
log.info("Loading Qwen3-ASR-0.6B (transformers backend)...")
|
||||
from qwen_asr import Qwen3ASRModel
|
||||
|
||||
device = get_device()
|
||||
asr_model = Qwen3ASRModel.from_pretrained(
|
||||
"Qwen/Qwen3-ASR-0.6B",
|
||||
dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
device_map=device,
|
||||
max_new_tokens=4096,
|
||||
)
|
||||
self.asr_engine = ASREngine(asr_model)
|
||||
log.info("Qwen3-ASR-0.6B loaded.")
|
||||
|
||||
def _load_llm(self):
|
||||
log.info("Loading Qwen3-0.6B-Instruct...")
|
||||
log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_name = "Qwen/Qwen3-0.6B"
|
||||
model_name = "Qwen/Qwen3.5-0.8B"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
device = get_device()
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
device_map=device,
|
||||
)
|
||||
self.llm_engine = LLMEngine(model, tokenizer)
|
||||
log.info("Qwen3-0.6B-Instruct loaded.")
|
||||
log.info("Qwen3-4B-GPTQ-Int4 loaded (~2.5GB VRAM).")
|
||||
|
||||
def _load_tts(self):
|
||||
log.info("Loading Kokoro TTS...")
|
||||
|
||||
Reference in New Issue
Block a user