initial commit
This commit is contained in:
@@ -0,0 +1,83 @@
|
||||
import logging
|
||||
import threading
|
||||
from typing import AsyncIterator
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from server.audio_utils import split_sentences
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""Wraps Qwen3 for conversation generation."""
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a helpful voice assistant. Keep your responses concise and natural "
|
||||
"for spoken conversation. Respond in 1-3 short sentences. "
|
||||
"Do not use markdown, bullet points, code blocks, emojis, or any "
|
||||
"formatting that doesn't work in speech."
|
||||
)
|
||||
|
||||
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def _build_inputs(self, messages: list[dict]):
|
||||
"""Build input token ids using the model's chat template."""
|
||||
chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}]
|
||||
for msg in messages:
|
||||
chat_messages.append({"role": msg["role"], "content": msg["content"]})
|
||||
|
||||
text = self.tokenizer.apply_chat_template(
|
||||
chat_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
return self.tokenizer(text, return_tensors="pt").to(self.model.device)
|
||||
|
||||
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str:
|
||||
"""Generate a complete response (blocking)."""
|
||||
inputs = self._build_inputs(messages)
|
||||
input_len = inputs["input_ids"].shape[1]
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
do_sample=True,
|
||||
repetition_penalty=1.2,
|
||||
)
|
||||
|
||||
# Decode only the generated tokens (skip prompt)
|
||||
new_ids = output_ids[0][input_len:]
|
||||
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
|
||||
log.info(f"LLM response: {response}")
|
||||
return response
|
||||
|
||||
async def generate_sentences(
|
||||
self,
|
||||
messages: list[dict],
|
||||
cancel_event: threading.Event | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Generate response and yield it sentence by sentence for TTS pipelining."""
|
||||
import asyncio
|
||||
|
||||
response = await asyncio.to_thread(self.generate, messages)
|
||||
|
||||
if cancel_event and cancel_event.is_set():
|
||||
return
|
||||
|
||||
# Split into sentences and yield each
|
||||
sentences, remainder = split_sentences(response)
|
||||
for sentence in sentences:
|
||||
if cancel_event and cancel_event.is_set():
|
||||
return
|
||||
yield sentence
|
||||
|
||||
if remainder:
|
||||
yield remainder
|
||||
Reference in New Issue
Block a user