initial commit
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ASREngine:
    """Thin adapter around a Qwen3-ASR model for speech-to-text."""

    def __init__(self, model):
        # The wrapped object only needs a ``transcribe(audio=..., language=...)``
        # method; nothing else is accessed on it here.
        self.model = model

    def transcribe(self, audio_16k: np.ndarray) -> str:
        """Transcribe one complete utterance.

        Args:
            audio_16k: Float32 numpy array at 16kHz sample rate.

        Returns:
            The first result's text with surrounding whitespace removed, or
            an empty string when the model yields no result or no text.
        """
        outputs = self.model.transcribe(
            audio=(audio_16k, 16000),
            language=None,  # auto-detect
        )
        if not outputs:
            return ""
        if not outputs[0].text:
            return ""
        return outputs[0].text.strip()
|
||||
@@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
from scipy.signal import resample_poly
|
||||
from math import gcd
|
||||
|
||||
|
||||
def pcm_bytes_to_float32(pcm_bytes: bytes, dtype=np.int16) -> np.ndarray:
    """Convert raw PCM bytes to float32 samples in [-1, 1).

    Args:
        pcm_bytes: Raw sample data in native byte order.
        dtype: Sample type of the input. Signed integer types are scaled by
            their own full-scale value (32768 for the default int16); any
            other type keeps the historical 32768 divisor.

    Returns:
        A new float32 array with one element per complete sample. A trailing
        partial sample (e.g. an odd byte count for int16) is dropped rather
        than raising, so one malformed network frame cannot abort a stream.
    """
    sample_type = np.dtype(dtype)
    usable = len(pcm_bytes) - len(pcm_bytes) % sample_type.itemsize
    if usable == 0:
        return np.zeros(0, dtype=np.float32)

    # memoryview slicing trims the tail without copying the payload.
    samples = np.frombuffer(memoryview(pcm_bytes)[:usable], dtype=sample_type)

    if sample_type.kind == "i":
        full_scale = float(1 << (8 * sample_type.itemsize - 1))
    else:
        full_scale = 32768.0
    return samples.astype(np.float32) / full_scale
|
||||
|
||||
|
||||
def float32_to_pcm_bytes(audio) -> bytes:
    """Convert float audio in [-1, 1] to 16-bit PCM bytes.

    Accepts numpy arrays, PyTorch tensors, or any sequence numpy can convert.

    Args:
        audio: Samples nominally in [-1, 1]. Out-of-range values (including
            infinities) are clamped; NaN samples are encoded as silence
            because casting NaN to an integer is undefined.

    Returns:
        Native-endian int16 bytes, one sample per input element. Scaling
        uses 32767 and truncates toward zero.
    """
    # Tensors are recognised by duck typing so torch need not be imported here.
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()
    samples = np.nan_to_num(np.asarray(audio), nan=0.0)
    clamped = np.clip(samples, -1.0, 1.0)
    return (clamped * 32767).astype(np.int16).tobytes()
|
||||
|
||||
|
||||
def resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Change the sample rate of ``audio`` with a polyphase filter.

    Args:
        audio: Input samples.
        orig_sr: Sample rate of ``audio`` in Hz.
        target_sr: Desired sample rate in Hz.

    Returns:
        ``audio`` itself when the rates already match; otherwise a new array
        cast back to the input's dtype.
    """
    if orig_sr != target_sr:
        # Reduce the rate ratio to lowest terms for the up/down factors.
        common = gcd(orig_sr, target_sr)
        converted = resample_poly(audio, target_sr // common, orig_sr // common)
        return converted.astype(audio.dtype)
    return audio
|
||||
|
||||
|
||||
def split_sentences(text: str) -> tuple[list[str], str]:
    """Split text into completed sentences and a remaining buffer.

    A sentence ends at a run of terminators (``.``, ``!``, ``?``) that is
    followed by any whitespace character or by the end of the string. A
    terminator followed directly by other text (``3.5``, ``a.b``) does not
    end a sentence.

    Args:
        text: Text to split.

    Returns:
        ``(sentences, remaining_buffer)``: the stripped, non-empty completed
        sentences in order, and the stripped text after the last boundary.
    """
    terminators = ".!?"
    sentences = []
    length = len(text)
    start = 0
    i = 0
    while i < length:
        if text[i] not in terminators:
            i += 1
            continue

        # Consume the whole run ("...", "?!") so it stays with its sentence.
        end = i + 1
        while end < length and text[end] in terminators:
            end += 1

        if end == length or text[end].isspace():
            sentence = text[start:end].strip()
            if sentence:
                sentences.append(sentence)
            start = end

        # Whether or not this was a boundary, the run has been fully examined.
        i = end

    return sentences, text[start:].strip()
|
||||
@@ -0,0 +1,83 @@
|
||||
import logging
|
||||
import threading
|
||||
from typing import AsyncIterator
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from server.audio_utils import split_sentences
|
||||
|
||||
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMEngine:
    """Wraps Qwen3 for conversation generation."""

    # Prepended to every conversation as the system message.
    SYSTEM_PROMPT = (
        "You are a helpful voice assistant. Keep your responses concise and natural "
        "for spoken conversation. Respond in 1-3 short sentences. "
        "Do not use markdown, bullet points, code blocks, emojis, or any "
        "formatting that doesn't work in speech."
    )

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def _build_inputs(self, messages: list[dict]):
        """Build model inputs using the tokenizer's chat template.

        Args:
            messages: Conversation turns, each a dict with "role" and
                "content" keys. Other keys are ignored.

        Returns:
            The tokenizer output for the rendered prompt, moved to the
            model's device.
        """
        chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}]
        chat_messages.extend(
            {"role": msg["role"], "content": msg["content"]} for msg in messages
        )

        text = self.tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        return self.tokenizer(text, return_tensors="pt").to(self.model.device)

    def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str:
        """Generate a complete response (blocking).

        Args:
            messages: Conversation turns with "role" and "content" keys.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The decoded reply, without the prompt and special tokens, stripped.
        """
        inputs = self._build_inputs(messages)
        input_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.2,
            )

        # Decode only the generated tokens (skip prompt)
        new_ids = output_ids[0][input_len:]
        response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
        log.info("LLM response: %s", response)
        return response

    async def generate_sentences(
        self,
        messages: list[dict],
        cancel_event: threading.Event | None = None,
    ) -> AsyncIterator[str]:
        """Generate a response and yield it sentence by sentence for TTS pipelining.

        The whole response is generated first (in a worker thread), then split.
        Any trailing text that does not end in a sentence terminator is
        yielded last.

        Args:
            messages: Conversation turns with "role" and "content" keys.
            cancel_event: Optional flag; once set, nothing further is yielded.
        """
        import asyncio

        def cancelled() -> bool:
            return cancel_event is not None and cancel_event.is_set()

        response = await asyncio.to_thread(self.generate, messages)
        if cancelled():
            return

        sentences, remainder = split_sentences(response)
        # Fold the remainder into the same loop so it gets the same
        # cancellation check as every full sentence.
        pieces = sentences + [remainder] if remainder else sentences
        for piece in pieces:
            if cancelled():
                return
            yield piece
|
||||
@@ -0,0 +1,87 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.params import Form
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from server.audio_utils import pcm_bytes_to_float32
|
||||
from server.models import ModelManager
|
||||
from server.pipeline import ConversationSession
|
||||
|
||||
# Configure root logging at import time so model-loading messages are visible.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
log = logging.getLogger(__name__)

# Both asset directories live beside the package, one level above this module.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
REFERENCE_DIR = os.path.join(_PROJECT_ROOT, "reference_audio")
STATIC_DIR = os.path.join(_PROJECT_ROOT, "static")

# Single model registry shared by every request; filled in by lifespan().
model_mgr = ModelManager()
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: load every model before serving requests.

    ``load_all()`` is a synchronous call, so startup does not complete until
    it returns. Nothing is released on shutdown; only a log line is emitted.
    """
    log.info("Starting model loading...")
    model_mgr.load_all()
    log.info("Server ready.")

    yield  # the application serves requests while suspended here

    log.info("Shutting down.")
|
||||
|
||||
|
||||
# ASGI application; models are loaded through the lifespan hook.
app = FastAPI(lifespan=lifespan)

# Serve the browser client's assets under /static.
app.mount(
    "/static",
    StaticFiles(directory=STATIC_DIR),
    name="static",
)
|
||||
|
||||
|
||||
@app.get("/")
async def index():
    """Serve the client page (``index.html`` from the static directory)."""
    index_path = os.path.join(STATIC_DIR, "index.html")
    return FileResponse(index_path)
|
||||
|
||||
|
||||
@app.post("/api/set-voice")
async def set_voice(voice: str = Form(...), lang: str = Form("a")):
    """Change the TTS voice.

    Args:
        voice: Voice identifier passed to the TTS engine.
        lang: Language code passed to the TTS engine (defaults to "a").

    Returns:
        ``{"status": "ok", "voice": voice}`` once the engine has switched.
    """
    import asyncio

    # Switching voices may rebuild the TTS pipeline, which is synchronous.
    # Run it in a worker thread so active WebSocket sessions keep being served.
    await asyncio.to_thread(model_mgr.tts_engine.set_voice, voice, lang)
    return {"status": "ok", "voice": voice}
|
||||
|
||||
|
||||
@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
    """Run one voice-chat session over a WebSocket.

    Binary frames are raw PCM audio and are fed to the session's pipeline.
    Text frames are JSON control messages; only ``{"type": "interrupt"}`` is
    acted on, and anything unparsable or unknown is ignored. The session is
    always stopped when the connection ends, however it ends.
    """
    await ws.accept()
    log.info("WebSocket client connected.")

    async def send_json(data: dict):
        await ws.send_text(json.dumps(data))

    async def send_bytes(data: bytes):
        await ws.send_bytes(data)

    session = ConversationSession(model_mgr, send_json, send_bytes)

    try:
        # Inside the try so a failed first send still reaches session.stop().
        await session.start()

        while True:
            message = await ws.receive()

            # The low-level receive() reports a client disconnect as a message
            # instead of raising; calling receive() again afterwards would
            # raise RuntimeError, so leave the loop here.
            if message.get("type") == "websocket.disconnect":
                log.info("WebSocket client disconnected.")
                break

            # Either payload key may be present with a None value, so test
            # the value rather than key membership.
            pcm_data = message.get("bytes")
            if pcm_data is not None:
                if pcm_data:  # skip empty frames
                    await session.handle_audio_chunk(pcm_bytes_to_float32(pcm_data))
                continue

            raw_text = message.get("text")
            if raw_text is None:
                continue
            try:
                msg = json.loads(raw_text)
            except json.JSONDecodeError:
                continue

            # Valid JSON need not be an object (e.g. "[1]" or "3").
            if isinstance(msg, dict) and msg.get("type") == "interrupt":
                await session.interrupt()

    except WebSocketDisconnect:
        log.info("WebSocket client disconnected.")
    except Exception:
        log.exception("WebSocket error")
    finally:
        await session.stop()
|
||||
@@ -0,0 +1,70 @@
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from server.vad import StreamingVAD
|
||||
from server.asr import ASREngine
|
||||
from server.llm import LLMEngine
|
||||
from server.tts import TTSEngine
|
||||
|
||||
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelManager:
    """Loads and holds all models. Initialized once at server startup.

    Args:
        device: Device map string used for the ASR and LLM models. Defaults
            to "cuda:0". The VAD and TTS loaders do not receive it.
    """

    def __init__(self, device: str = "cuda:0"):
        self.device = device
        self.vad_model = None
        self.asr_engine: ASREngine | None = None
        self.llm_engine: LLMEngine | None = None
        self.tts_engine: TTSEngine | None = None

    def load_all(self):
        """Load all models sequentially. Call from the main process."""
        for loader in (self._load_vad, self._load_asr, self._load_llm, self._load_tts):
            loader()
        log.info("All models loaded successfully.")

    def _load_vad(self):
        """Load the Silero VAD model into ``self.vad_model``."""
        log.info("Loading Silero VAD...")
        # Imported lazily so importing this module stays cheap.
        from silero_vad import load_silero_vad

        self.vad_model = load_silero_vad()
        log.info("Silero VAD loaded (CPU).")

    def _load_asr(self):
        """Load Qwen3-ASR and wrap it in an ASREngine."""
        log.info("Loading Qwen3-ASR-0.6B (transformers backend)...")
        from qwen_asr import Qwen3ASRModel

        asr_model = Qwen3ASRModel.from_pretrained(
            "Qwen/Qwen3-ASR-0.6B",
            dtype=torch.bfloat16,
            device_map=self.device,
            max_new_tokens=4096,
        )
        self.asr_engine = ASREngine(asr_model)
        log.info("Qwen3-ASR-0.6B loaded.")

    def _load_llm(self):
        """Load the chat model and tokenizer and wrap them in an LLMEngine."""
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_name = "Qwen/Qwen3-0.6B"
        # Log the checkpoint id that is actually loaded.
        log.info("Loading %s...", model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
        )
        self.llm_engine = LLMEngine(model, tokenizer)
        log.info("%s loaded.", model_name)

    def _load_tts(self):
        """Construct the Kokoro TTS engine."""
        log.info("Loading Kokoro TTS...")
        self.tts_engine = TTSEngine()
        log.info("Kokoro TTS loaded.")

    def create_vad(self) -> StreamingVAD:
        """Create a new StreamingVAD instance for a client session.

        Raises:
            RuntimeError: If the VAD model has not been loaded yet.
        """
        if self.vad_model is None:
            raise RuntimeError("VAD model is not loaded; call load_all() first.")
        # NOTE(review): every session receives the same model object. If that
        # model keeps per-stream state, concurrent sessions could interfere
        # with each other -- confirm against the silero-vad documentation.
        return StreamingVAD(self.vad_model)
|
||||
@@ -0,0 +1,164 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
from server.audio_utils import float32_to_pcm_bytes
|
||||
from server.models import ModelManager
|
||||
from server.vad import StreamingVAD
|
||||
|
||||
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
||||
|
||||
# End-of-stream marker for the TTS chunk queue: the worker enqueues it last and
# the consumer compares against it by identity. Real items are tuples.
_SENTINEL = None
|
||||
|
||||
|
||||
class ConversationSession:
    """Manages a single client's voice conversation pipeline.

    Orchestrates: VAD -> ASR -> LLM -> TTS streaming with barge-in support.

    Each response runs as its own asyncio task and owns a private cancel
    event. ``self.cancel_event`` always refers to the event of the most
    recently started response, so interrupting one response can never be
    undone by the next one starting up.
    """

    def __init__(self, models: ModelManager, send_json, send_bytes):
        """Create a session.

        Args:
            models: Model registry providing ``asr_engine``, ``llm_engine``,
                ``tts_engine`` and ``create_vad()``.
            send_json: Async callable that sends one dict to the client.
            send_bytes: Async callable that sends one bytes payload to the client.
        """
        self.models = models
        self.send_json = send_json
        self.send_bytes = send_bytes

        self.vad: StreamingVAD = models.create_vad()
        self.conversation_history: list[dict] = []
        self.cancel_event = threading.Event()
        self.is_responding = False
        self._response_task: asyncio.Task | None = None
        # Every unfinished response task, including interrupted ones that are
        # still unwinding, so stop() can cancel all of them.
        self._tasks: set[asyncio.Task] = set()

    async def start(self):
        """Tell the client the session is ready for audio."""
        await self.send_json({"type": "status", "state": "listening"})

    async def stop(self):
        """Cancel all in-flight responses and wait for them to unwind."""
        self.cancel_event.set()
        pending = [task for task in self._tasks if not task.done()]
        for task in pending:
            task.cancel()
        if pending:
            # return_exceptions=True also retrieves each task's outcome so
            # nothing is reported later as a never-retrieved exception.
            await asyncio.gather(*pending, return_exceptions=True)

    async def handle_audio_chunk(self, chunk_16k: np.ndarray):
        """Feed one audio chunk to the VAD and react to speech events.

        A completed utterance interrupts any running response and starts a
        new one in the background. Speech starting while a response is
        running interrupts it (barge-in).
        """
        utterance = self.vad.process_chunk(chunk_16k)

        if utterance is not None:
            if self.is_responding:
                await self._interrupt()
            # Launch response pipeline as a background task so we don't block receives
            task = asyncio.create_task(self._process_utterance(utterance))
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)
            self._response_task = task
        elif self.vad.is_speaking and self.is_responding:
            await self._interrupt()

    async def interrupt(self):
        """Public interrupt method for WebSocket text messages."""
        if self.is_responding:
            await self._interrupt()

    async def _interrupt(self):
        """Cancel the current response and tell the client to stop playback."""
        log.info("Barge-in: cancelling response.")
        self.cancel_event.set()
        self.is_responding = False
        # Tell client to stop audio immediately
        try:
            await self.send_json({"type": "interrupt"})
        except Exception:
            # Best effort: the client may already be gone.
            log.debug("Could not notify client of interrupt.", exc_info=True)

    async def _process_utterance(self, audio_16k: np.ndarray):
        """Full pipeline: ASR -> LLM -> TTS streaming.

        Installs a fresh cancel event for this response and guarantees that
        ``is_responding`` is reset when the response ends for any reason,
        unless a newer response has already taken over.
        """
        cancel = threading.Event()
        self.cancel_event = cancel
        self.is_responding = True
        try:
            await self._run_pipeline(audio_16k, cancel)
        except asyncio.CancelledError:
            cancel.set()  # lets the TTS worker thread stop too
            raise
        except Exception:
            log.exception("Response pipeline failed")
        finally:
            if self.cancel_event is cancel:
                self.is_responding = False

    async def _run_pipeline(self, audio_16k: np.ndarray, cancel: threading.Event):
        """Run ASR, LLM and TTS for one utterance, honouring ``cancel``."""
        # ASR
        await self.send_json({"type": "status", "state": "thinking"})
        text = await asyncio.to_thread(self.models.asr_engine.transcribe, audio_16k)

        if not text:
            log.info("ASR returned empty text, resuming listening.")
            await self._announce_listening(cancel)
            return

        await self.send_json({"type": "transcript", "text": text, "final": True})
        log.info("User: %s", text)
        user_entry = {"role": "user", "content": text}
        self.conversation_history.append(user_entry)

        if cancel.is_set():
            return

        # LLM
        log.info(
            "Conversation history (%d messages): %s",
            len(self.conversation_history),
            [m["content"][:50] for m in self.conversation_history],
        )
        # Pass a snapshot: the worker thread must not iterate a list that the
        # event loop may modify in the meantime.
        response = await asyncio.to_thread(
            self.models.llm_engine.generate, list(self.conversation_history)
        )

        if cancel.is_set():
            return

        # TTS - stream chunks with per-sentence text
        await self.send_json({"type": "status", "state": "speaking"})
        spoken_text = await self._stream_tts(response, cancel)

        # Save only what was actually spoken
        self._record_turn(user_entry, spoken_text)

        if not cancel.is_set():
            await self.send_json({"type": "response_text", "text": "", "final": True})

        await self._announce_listening(cancel)

    async def _stream_tts(self, response: str, cancel: threading.Event) -> str:
        """Synthesize ``response`` and stream text and audio to the client.

        Returns:
            The concatenated, stripped text of every chunk that was sent.
        """
        chunk_queue: queue.Queue = queue.Queue()
        stop_worker = threading.Event()
        tts = self.models.tts_engine

        def _tts_worker():
            try:
                for graphemes, _ps, audio in tts.pipeline(response, voice=tts.voice):
                    if cancel.is_set() or stop_worker.is_set():
                        break
                    if audio is not None and len(audio) > 0:
                        chunk_queue.put((graphemes, audio))
            except Exception:
                log.exception("TTS generation error")
            finally:
                chunk_queue.put(_SENTINEL)

        tts_thread = threading.Thread(target=_tts_worker, daemon=True)
        tts_thread.start()

        spoken_parts: list[str] = []
        try:
            while True:
                try:
                    item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
                except queue.Empty:
                    log.warning("Timed out waiting for TTS audio.")
                    break

                if item is _SENTINEL or cancel.is_set():
                    break

                sentence_text, audio = item
                spoken_parts.append(sentence_text)

                await self.send_json(
                    {"type": "response_text", "text": sentence_text, "final": False}
                )
                try:
                    await self.send_bytes(float32_to_pcm_bytes(audio))
                except Exception:
                    log.warning("Failed to send audio, client disconnected.")
                    cancel.set()
                    break
        finally:
            # Nobody reads the queue after this point; stop the worker too.
            stop_worker.set()

        # Join in a worker thread: a plain join() here would block the event loop.
        await asyncio.to_thread(tts_thread.join, 2.0)
        return "".join(spoken_parts).strip()

    def _record_turn(self, user_entry: dict, spoken_text: str):
        """Store the assistant reply after its user turn, or drop the turn.

        The user entry is located by identity, so a newer turn appended by a
        later response is never touched.
        """
        history = self.conversation_history
        for position, entry in enumerate(history):
            if entry is user_entry:
                break
        else:
            return

        if spoken_text:
            history.insert(position + 1, {"role": "assistant", "content": spoken_text})
        else:
            del history[position]

    async def _announce_listening(self, cancel: threading.Event):
        """Mark the response finished and report "listening" to the client.

        Does nothing when a newer response has started, because the
        client-visible state then belongs to that response.
        """
        if self.cancel_event is not cancel:
            return
        self.is_responding = False
        try:
            await self.send_json({"type": "status", "state": "listening"})
        except Exception:
            # Best effort: the client may already be gone.
            log.debug("Could not send listening status.", exc_info=True)
|
||||
@@ -0,0 +1,38 @@
|
||||
import logging
|
||||
from typing import Iterator
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
||||
|
||||
# Voice a TTSEngine uses until set_voice() selects another one.
DEFAULT_VOICE = "af_heart"
|
||||
# Language code given to KPipeline when a TTSEngine is constructed.
DEFAULT_LANG = "a" # American English
|
||||
|
||||
|
||||
class TTSEngine:
    """Wraps Kokoro TTS for fast streaming text-to-speech."""

    def __init__(self):
        # Imported lazily so importing this module does not pull in kokoro.
        from kokoro import KPipeline

        self.pipeline = KPipeline(lang_code=DEFAULT_LANG)
        self.lang_code = DEFAULT_LANG  # language the current pipeline was built for
        self.voice = DEFAULT_VOICE
        self.sample_rate = 24000

    def set_voice(self, voice: str, lang_code: str = "a"):
        """Change the voice, rebuilding the pipeline only if the language changes.

        Args:
            voice: Voice identifier passed to the pipeline on synthesis.
            lang_code: Language code for the pipeline.

        If constructing the new pipeline raises, the engine keeps its
        previous pipeline, language and voice.
        """
        if lang_code != getattr(self, "lang_code", None):
            from kokoro import KPipeline

            # Build first, assign afterwards: a failure must not leave the
            # new voice paired with the old pipeline.
            new_pipeline = KPipeline(lang_code=lang_code)
            self.pipeline = new_pipeline
            self.lang_code = lang_code

        self.voice = voice
        log.info("Voice set to: %s (lang: %s)", voice, lang_code)

    def synthesize_stream(self, text: str) -> Iterator[np.ndarray]:
        """Yield audio chunks as they are generated.

        Each chunk is a float32 numpy array at self.sample_rate (24kHz).
        Kokoro internally splits text into sentences and yields per-sentence audio.
        Chunks whose audio is None or empty are skipped.
        """
        for _gs, _ps, audio in self.pipeline(text, voice=self.voice):
            if audio is not None and len(audio) > 0:
                yield audio
|
||||
@@ -0,0 +1,52 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class StreamingVAD:
|
||||
"""Wraps Silero VAD for streaming chunk-by-chunk speech detection."""
|
||||
|
||||
def __init__(self, model, threshold: float = 0.5, min_silence_ms: int = 400):
|
||||
from silero_vad import VADIterator
|
||||
|
||||
self.iterator = VADIterator(
|
||||
model,
|
||||
sampling_rate=16000,
|
||||
threshold=threshold,
|
||||
min_silence_duration_ms=min_silence_ms,
|
||||
)
|
||||
self.audio_buffer: list[np.ndarray] = []
|
||||
self.is_speaking = False
|
||||
|
||||
def process_chunk(self, chunk_16k: np.ndarray) -> np.ndarray | None:
|
||||
"""Feed a 512-sample chunk at 16kHz.
|
||||
|
||||
Returns the complete utterance as a numpy array when speech ends,
|
||||
or None if still accumulating.
|
||||
"""
|
||||
tensor = torch.from_numpy(chunk_16k).float()
|
||||
speech_dict = self.iterator(tensor, return_seconds=False)
|
||||
|
||||
if speech_dict:
|
||||
if "start" in speech_dict:
|
||||
self.is_speaking = True
|
||||
self.audio_buffer = []
|
||||
if "end" in speech_dict:
|
||||
self.is_speaking = False
|
||||
if self.audio_buffer:
|
||||
result = np.concatenate(self.audio_buffer)
|
||||
self.audio_buffer = []
|
||||
self.iterator.reset_states()
|
||||
return result
|
||||
self.iterator.reset_states()
|
||||
return None
|
||||
|
||||
if self.is_speaking:
|
||||
self.audio_buffer.append(chunk_16k.copy())
|
||||
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
"""Reset VAD state for a new conversation turn."""
|
||||
self.audio_buffer = []
|
||||
self.is_speaking = False
|
||||
self.iterator.reset_states()
|
||||
Reference in New Issue
Block a user