add lm studio option
This commit is contained in:
@@ -0,0 +1,8 @@
|
|||||||
|
# LLM backend: "local" or "lmstudio"
|
||||||
|
llm:
|
||||||
|
backend: local # change to "lmstudio" to use LM Studio instead
|
||||||
|
|
||||||
|
# Settings used only when backend = "lmstudio"
|
||||||
|
lmstudio:
|
||||||
|
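url: http://host.docker.internal:1234 # host.docker.internal resolves to the Docker host from inside the container; on Linux (Docker Engine) add extra_hosts "host.docker.internal:host-gateway" to the service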
|
||||||
|
model: "" # leave empty to use whatever model LM Studio has loaded
|
||||||
@@ -6,6 +6,8 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
# Cache models on the host so they survive container rebuilds
|
# Cache models on the host so they survive container rebuilds
|
||||||
- huggingface-cache:/cache/huggingface
|
- huggingface-cache:/cache/huggingface
|
||||||
|
# Mount config so you can edit backend settings without rebuilding the image
|
||||||
|
- ./config.yml:/app/config.yml:ro
|
||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
reservations:
|
reservations:
|
||||||
|
|||||||
@@ -13,3 +13,4 @@ numpy
|
|||||||
soundfile
|
soundfile
|
||||||
scipy
|
scipy
|
||||||
python-multipart
|
python-multipart
|
||||||
|
pyyaml
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
import pathlib
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# config.yml sits one directory above the one containing this module.
_CONFIG_PATH = pathlib.Path(__file__).parent.parent.joinpath("config.yml")
|
||||||
|
|
||||||
|
|
||||||
|
def load_config() -> dict:
    """Read and parse the YAML file at ``_CONFIG_PATH``.

    Returns:
        The parsed top-level value of the file. An empty file (or one that
        contains only comments) yields an empty dict instead of ``None`` so
        callers can chain ``.get(...)`` safely.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
    """
    # Decode explicitly as UTF-8 rather than relying on the locale default,
    # which may be ASCII inside a minimal container image.
    with open(_CONFIG_PATH, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty document.
    return {} if data is None else data
|
||||||
|
|
||||||
|
|
||||||
|
# Parsed once when this module is first imported; other modules read this shared value.
config = load_config()
|
||||||
@@ -81,3 +81,66 @@ class LLMEngine:
|
|||||||
|
|
||||||
if remainder:
|
if remainder:
|
||||||
yield remainder
|
yield remainder
|
||||||
|
|
||||||
|
|
||||||
|
# System message sent ahead of the conversation; one entry per sentence,
# joined with single spaces.
SYSTEM_PROMPT = " ".join(
    (
        "You are a helpful voice assistant.",
        "Keep your responses concise and natural for spoken conversation.",
        "Respond in 1-3 short sentences.",
        "Do not use markdown, bullet points, code blocks, emojis, or any"
        " formatting that doesn't work in speech.",
    )
)
|
||||||
|
|
||||||
|
|
||||||
|
class LMStudioEngine:
|
||||||
|
"""LLM engine that delegates to an LM Studio server via its OpenAI-compatible API."""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str, model: str):
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str:
|
||||||
|
import requests
|
||||||
|
|
||||||
|
payload_messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||||
|
payload_messages.extend(messages)
|
||||||
|
|
||||||
|
body: dict = {
|
||||||
|
"messages": payload_messages,
|
||||||
|
"max_tokens": max_new_tokens,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
if self.model:
|
||||||
|
body["model"] = self.model
|
||||||
|
|
||||||
|
resp = requests.post(
|
||||||
|
f"{self.base_url}/v1/chat/completions",
|
||||||
|
json=body,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
response = resp.json()["choices"][0]["message"]["content"].strip()
|
||||||
|
log.info(f"LM Studio response: {response}")
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def generate_sentences(
|
||||||
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
cancel_event: threading.Event | None = None,
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
"""Generate response and yield it sentence by sentence for TTS pipelining."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
response = await asyncio.to_thread(self.generate, messages)
|
||||||
|
|
||||||
|
if cancel_event and cancel_event.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
sentences, remainder = split_sentences(response)
|
||||||
|
for sentence in sentences:
|
||||||
|
if cancel_event and cancel_event.is_set():
|
||||||
|
return
|
||||||
|
yield sentence
|
||||||
|
|
||||||
|
if remainder:
|
||||||
|
yield remainder
|
||||||
|
|||||||
+25
-12
@@ -46,7 +46,7 @@ class ModelManager:
|
|||||||
from server.vad import SileroVADOnnx
|
from server.vad import SileroVADOnnx
|
||||||
|
|
||||||
model_path = hf_hub_download(
|
model_path = hf_hub_download(
|
||||||
repo_id="onnx-community/silero-vad", filename="silero_vad.onnx"
|
repo_id="onnx-community/silero-vad", filename="onnx/model.onnx"
|
||||||
)
|
)
|
||||||
self.vad_model = SileroVADOnnx(model_path)
|
self.vad_model = SileroVADOnnx(model_path)
|
||||||
log.info("Silero VAD loaded (ONNX, CPU).")
|
log.info("Silero VAD loaded (ONNX, CPU).")
|
||||||
@@ -66,19 +66,32 @@ class ModelManager:
|
|||||||
log.info("Qwen3-ASR-0.6B loaded.")
|
log.info("Qwen3-ASR-0.6B loaded.")
|
||||||
|
|
||||||
def _load_llm(self):
    """Create ``self.llm_engine`` for the backend selected under ``llm`` in the config.

    ``llm.backend == "lmstudio"`` builds an ``LMStudioEngine`` from
    ``llm.lmstudio.url`` / ``llm.lmstudio.model``. Any other value loads the
    local Hugging Face model into an ``LLMEngine``; an unrecognized value is
    logged as a warning first.
    """
    from server.config import config

    # A YAML key that is present but left empty parses to None, so guard each
    # level with `or {}` instead of relying on the .get() default.
    llm_cfg = (config or {}).get("llm") or {}
    backend = str(llm_cfg.get("backend") or "local").strip().lower()

    if backend == "lmstudio":
        from server.llm import LMStudioEngine

        lms = llm_cfg.get("lmstudio") or {}
        url = lms.get("url") or "http://host.docker.internal:1234"
        model = lms.get("model", "") or ""
        log.info(f"Using LM Studio backend at {url} (model={model or 'server default'})")
        self.llm_engine = LMStudioEngine(url, model)
        return

    if backend != "local":
        # Keep the original fallback, but make a typo in the config visible.
        log.warning("Unknown llm.backend %r in config; falling back to 'local'.", backend)

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Qwen/Qwen3.5-0.8B"
    # Log the model that is actually loaded rather than a hard-coded name.
    log.info("Loading local LLM %s...", model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    device = get_device()
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device,
    )
    self.llm_engine = LLMEngine(model, tokenizer)
    log.info("Local LLM %s loaded.", model_name)
|
||||||
|
|
||||||
def _load_tts(self):
|
def _load_tts(self):
|
||||||
log.info("Loading Kokoro TTS...")
|
log.info("Loading Kokoro TTS...")
|
||||||
|
|||||||
Reference in New Issue
Block a user