initial commit
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
.venv
|
||||
__pycache__
|
||||
Binary file not shown.
@@ -0,0 +1,9 @@
|
||||
torch>=2.5.0
|
||||
transformers==4.57.6
|
||||
silero-vad>=5.1
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.30.0
|
||||
numpy
|
||||
soundfile
|
||||
scipy
|
||||
python-multipart
|
||||
@@ -0,0 +1,10 @@
|
||||
import uvicorn
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"server.main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
reload=False,
|
||||
log_level="info",
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ASREngine:
|
||||
"""Wraps Qwen3-ASR for speech-to-text transcription."""
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def transcribe(self, audio_16k: np.ndarray) -> str:
|
||||
"""Transcribe a complete utterance.
|
||||
|
||||
Args:
|
||||
audio_16k: Float32 numpy array at 16kHz sample rate.
|
||||
|
||||
Returns:
|
||||
Transcribed text string.
|
||||
"""
|
||||
results = self.model.transcribe(
|
||||
audio=(audio_16k, 16000),
|
||||
language=None, # auto-detect
|
||||
)
|
||||
if results and results[0].text:
|
||||
return results[0].text.strip()
|
||||
return ""
|
||||
@@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
from scipy.signal import resample_poly
|
||||
from math import gcd
|
||||
|
||||
|
||||
def pcm_bytes_to_float32(pcm_bytes: bytes, dtype=np.int16) -> np.ndarray:
|
||||
"""Convert raw PCM bytes (16-bit signed int) to float32 in [-1, 1]."""
|
||||
audio = np.frombuffer(pcm_bytes, dtype=dtype)
|
||||
return audio.astype(np.float32) / 32768.0
|
||||
|
||||
|
||||
def float32_to_pcm_bytes(audio) -> bytes:
|
||||
"""Convert float32 audio in [-1, 1] to 16-bit PCM bytes.
|
||||
|
||||
Accepts numpy arrays or PyTorch tensors.
|
||||
"""
|
||||
if not isinstance(audio, np.ndarray):
|
||||
audio = audio.detach().cpu().numpy()
|
||||
clamped = np.clip(audio, -1.0, 1.0)
|
||||
return (clamped * 32767).astype(np.int16).tobytes()
|
||||
|
||||
|
||||
def resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
|
||||
"""Resample audio from orig_sr to target_sr using polyphase filtering."""
|
||||
if orig_sr == target_sr:
|
||||
return audio
|
||||
divisor = gcd(orig_sr, target_sr)
|
||||
up = target_sr // divisor
|
||||
down = orig_sr // divisor
|
||||
return resample_poly(audio, up, down).astype(audio.dtype)
|
||||
|
||||
|
||||
def split_sentences(text: str) -> tuple[list[str], str]:
|
||||
"""Split text into completed sentences and a remaining buffer.
|
||||
|
||||
Returns (sentences, remaining_buffer).
|
||||
Splits on sentence-ending punctuation followed by whitespace.
|
||||
"""
|
||||
sentences = []
|
||||
buffer = text
|
||||
terminators = ".!?"
|
||||
|
||||
i = 0
|
||||
start = 0
|
||||
while i < len(buffer):
|
||||
if buffer[i] in terminators:
|
||||
# Look ahead for whitespace or end of string
|
||||
end = i + 1
|
||||
while end < len(buffer) and buffer[end] in terminators:
|
||||
end += 1
|
||||
if end >= len(buffer) or buffer[end] == " " or buffer[end] == "\n":
|
||||
sentence = buffer[start:end].strip()
|
||||
if sentence:
|
||||
sentences.append(sentence)
|
||||
start = end
|
||||
i = end
|
||||
else:
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
remaining = buffer[start:].strip()
|
||||
return sentences, remaining
|
||||
@@ -0,0 +1,83 @@
|
||||
import logging
|
||||
import threading
|
||||
from typing import AsyncIterator
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from server.audio_utils import split_sentences
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""Wraps Qwen3 for conversation generation."""
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a helpful voice assistant. Keep your responses concise and natural "
|
||||
"for spoken conversation. Respond in 1-3 short sentences. "
|
||||
"Do not use markdown, bullet points, code blocks, emojis, or any "
|
||||
"formatting that doesn't work in speech."
|
||||
)
|
||||
|
||||
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def _build_inputs(self, messages: list[dict]):
|
||||
"""Build input token ids using the model's chat template."""
|
||||
chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}]
|
||||
for msg in messages:
|
||||
chat_messages.append({"role": msg["role"], "content": msg["content"]})
|
||||
|
||||
text = self.tokenizer.apply_chat_template(
|
||||
chat_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
return self.tokenizer(text, return_tensors="pt").to(self.model.device)
|
||||
|
||||
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str:
|
||||
"""Generate a complete response (blocking)."""
|
||||
inputs = self._build_inputs(messages)
|
||||
input_len = inputs["input_ids"].shape[1]
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
do_sample=True,
|
||||
repetition_penalty=1.2,
|
||||
)
|
||||
|
||||
# Decode only the generated tokens (skip prompt)
|
||||
new_ids = output_ids[0][input_len:]
|
||||
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
|
||||
log.info(f"LLM response: {response}")
|
||||
return response
|
||||
|
||||
async def generate_sentences(
|
||||
self,
|
||||
messages: list[dict],
|
||||
cancel_event: threading.Event | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Generate response and yield it sentence by sentence for TTS pipelining."""
|
||||
import asyncio
|
||||
|
||||
response = await asyncio.to_thread(self.generate, messages)
|
||||
|
||||
if cancel_event and cancel_event.is_set():
|
||||
return
|
||||
|
||||
# Split into sentences and yield each
|
||||
sentences, remainder = split_sentences(response)
|
||||
for sentence in sentences:
|
||||
if cancel_event and cancel_event.is_set():
|
||||
return
|
||||
yield sentence
|
||||
|
||||
if remainder:
|
||||
yield remainder
|
||||
@@ -0,0 +1,87 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.params import Form
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from server.audio_utils import pcm_bytes_to_float32
|
||||
from server.models import ModelManager
|
||||
from server.pipeline import ConversationSession
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio")
|
||||
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
|
||||
|
||||
model_mgr = ModelManager()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
log.info("Starting model loading...")
|
||||
model_mgr.load_all()
|
||||
log.info("Server ready.")
|
||||
yield
|
||||
log.info("Shutting down.")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def index():
|
||||
return FileResponse(os.path.join(STATIC_DIR, "index.html"))
|
||||
|
||||
|
||||
@app.post("/api/set-voice")
|
||||
async def set_voice(voice: str = Form(...), lang: str = Form("a")):
|
||||
"""Change the TTS voice."""
|
||||
model_mgr.tts_engine.set_voice(voice, lang)
|
||||
return {"status": "ok", "voice": voice}
|
||||
|
||||
|
||||
@app.websocket("/ws/chat")
|
||||
async def websocket_chat(ws: WebSocket):
|
||||
await ws.accept()
|
||||
log.info("WebSocket client connected.")
|
||||
|
||||
async def send_json(data: dict):
|
||||
await ws.send_text(json.dumps(data))
|
||||
|
||||
async def send_bytes(data: bytes):
|
||||
await ws.send_bytes(data)
|
||||
|
||||
session = ConversationSession(model_mgr, send_json, send_bytes)
|
||||
await session.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await ws.receive()
|
||||
|
||||
if "bytes" in message:
|
||||
pcm_data = message["bytes"]
|
||||
chunk = pcm_bytes_to_float32(pcm_data)
|
||||
await session.handle_audio_chunk(chunk)
|
||||
|
||||
elif "text" in message:
|
||||
try:
|
||||
msg = json.loads(message["text"])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if msg.get("type") == "interrupt":
|
||||
await session.interrupt()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
log.info("WebSocket client disconnected.")
|
||||
except Exception:
|
||||
log.exception("WebSocket error")
|
||||
finally:
|
||||
await session.stop()
|
||||
@@ -0,0 +1,70 @@
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from server.vad import StreamingVAD
|
||||
from server.asr import ASREngine
|
||||
from server.llm import LLMEngine
|
||||
from server.tts import TTSEngine
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""Loads and holds all models. Initialized once at server startup."""
|
||||
|
||||
def __init__(self):
|
||||
self.vad_model = None
|
||||
self.asr_engine: ASREngine | None = None
|
||||
self.llm_engine: LLMEngine | None = None
|
||||
self.tts_engine: TTSEngine | None = None
|
||||
|
||||
def load_all(self):
|
||||
"""Load all models sequentially. Call from the main process."""
|
||||
self._load_vad()
|
||||
self._load_asr()
|
||||
self._load_llm()
|
||||
self._load_tts()
|
||||
log.info("All models loaded successfully.")
|
||||
|
||||
def _load_vad(self):
|
||||
log.info("Loading Silero VAD...")
|
||||
from silero_vad import load_silero_vad
|
||||
|
||||
self.vad_model = load_silero_vad()
|
||||
log.info("Silero VAD loaded (CPU).")
|
||||
|
||||
def _load_asr(self):
|
||||
log.info("Loading Qwen3-ASR-0.6B (transformers backend)...")
|
||||
from qwen_asr import Qwen3ASRModel
|
||||
|
||||
asr_model = Qwen3ASRModel.from_pretrained(
|
||||
"Qwen/Qwen3-ASR-0.6B",
|
||||
dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
max_new_tokens=4096,
|
||||
)
|
||||
self.asr_engine = ASREngine(asr_model)
|
||||
log.info("Qwen3-ASR-0.6B loaded.")
|
||||
|
||||
def _load_llm(self):
|
||||
log.info("Loading Qwen3-0.6B-Instruct...")
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_name = "Qwen/Qwen3-0.6B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
)
|
||||
self.llm_engine = LLMEngine(model, tokenizer)
|
||||
log.info("Qwen3-0.6B-Instruct loaded.")
|
||||
|
||||
def _load_tts(self):
|
||||
log.info("Loading Kokoro TTS...")
|
||||
self.tts_engine = TTSEngine()
|
||||
log.info("Kokoro TTS loaded.")
|
||||
|
||||
def create_vad(self) -> StreamingVAD:
|
||||
"""Create a new StreamingVAD instance for a client session."""
|
||||
return StreamingVAD(self.vad_model)
|
||||
@@ -0,0 +1,164 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
from server.audio_utils import float32_to_pcm_bytes
|
||||
from server.models import ModelManager
|
||||
from server.vad import StreamingVAD
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_SENTINEL = None
|
||||
|
||||
|
||||
class ConversationSession:
|
||||
"""Manages a single client's voice conversation pipeline.
|
||||
|
||||
Orchestrates: VAD -> ASR -> LLM -> TTS streaming with barge-in support.
|
||||
"""
|
||||
|
||||
def __init__(self, models: ModelManager, send_json, send_bytes):
|
||||
self.models = models
|
||||
self.send_json = send_json
|
||||
self.send_bytes = send_bytes
|
||||
|
||||
self.vad: StreamingVAD = models.create_vad()
|
||||
self.conversation_history: list[dict] = []
|
||||
self.cancel_event = threading.Event()
|
||||
self.is_responding = False
|
||||
self._response_task: asyncio.Task | None = None
|
||||
|
||||
async def start(self):
|
||||
await self.send_json({"type": "status", "state": "listening"})
|
||||
|
||||
async def stop(self):
|
||||
self.cancel_event.set()
|
||||
if self._response_task and not self._response_task.done():
|
||||
self._response_task.cancel()
|
||||
|
||||
async def handle_audio_chunk(self, chunk_16k: np.ndarray):
|
||||
utterance = self.vad.process_chunk(chunk_16k)
|
||||
|
||||
if utterance is not None:
|
||||
if self.is_responding:
|
||||
await self._interrupt()
|
||||
# Launch response pipeline as a background task so we don't block receives
|
||||
self._response_task = asyncio.create_task(self._process_utterance(utterance))
|
||||
elif self.vad.is_speaking and self.is_responding:
|
||||
await self._interrupt()
|
||||
|
||||
async def interrupt(self):
|
||||
"""Public interrupt method for WebSocket text messages."""
|
||||
if self.is_responding:
|
||||
await self._interrupt()
|
||||
|
||||
async def _interrupt(self):
|
||||
log.info("Barge-in: cancelling response.")
|
||||
self.cancel_event.set()
|
||||
self.is_responding = False
|
||||
# Tell client to stop audio immediately
|
||||
try:
|
||||
await self.send_json({"type": "interrupt"})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _process_utterance(self, audio_16k: np.ndarray):
|
||||
"""Full pipeline: ASR -> LLM -> TTS streaming."""
|
||||
self.is_responding = True
|
||||
self.cancel_event.clear()
|
||||
|
||||
# ASR
|
||||
await self.send_json({"type": "status", "state": "thinking"})
|
||||
text = await asyncio.to_thread(self.models.asr_engine.transcribe, audio_16k)
|
||||
|
||||
if not text:
|
||||
log.info("ASR returned empty text, resuming listening.")
|
||||
self.is_responding = False
|
||||
await self.send_json({"type": "status", "state": "listening"})
|
||||
return
|
||||
|
||||
await self.send_json({"type": "transcript", "text": text, "final": True})
|
||||
log.info(f"User: {text}")
|
||||
self.conversation_history.append({"role": "user", "content": text})
|
||||
|
||||
if self.cancel_event.is_set():
|
||||
self.is_responding = False
|
||||
return
|
||||
|
||||
# LLM
|
||||
log.info(f"Conversation history ({len(self.conversation_history)} messages): "
|
||||
+ str([m['content'][:50] for m in self.conversation_history]))
|
||||
response = await asyncio.to_thread(
|
||||
self.models.llm_engine.generate, self.conversation_history
|
||||
)
|
||||
|
||||
if self.cancel_event.is_set():
|
||||
self.is_responding = False
|
||||
return
|
||||
|
||||
# TTS - stream chunks with per-sentence text
|
||||
await self.send_json({"type": "status", "state": "speaking"})
|
||||
chunk_queue = queue.Queue()
|
||||
|
||||
def _tts_worker():
|
||||
try:
|
||||
for graphemes, _ps, audio in self.models.tts_engine.pipeline(
|
||||
response, voice=self.models.tts_engine.voice
|
||||
):
|
||||
if self.cancel_event.is_set():
|
||||
break
|
||||
if audio is not None and len(audio) > 0:
|
||||
chunk_queue.put((graphemes, audio))
|
||||
except Exception:
|
||||
log.exception("TTS generation error")
|
||||
finally:
|
||||
chunk_queue.put(_SENTINEL)
|
||||
|
||||
tts_thread = threading.Thread(target=_tts_worker, daemon=True)
|
||||
tts_thread.start()
|
||||
|
||||
spoken_text = ""
|
||||
while True:
|
||||
try:
|
||||
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
|
||||
except Exception:
|
||||
break
|
||||
|
||||
if item is _SENTINEL:
|
||||
break
|
||||
if self.cancel_event.is_set():
|
||||
break
|
||||
|
||||
sentence_text, audio = item
|
||||
spoken_text += sentence_text
|
||||
|
||||
await self.send_json({"type": "response_text", "text": sentence_text, "final": False})
|
||||
pcm_bytes = float32_to_pcm_bytes(audio)
|
||||
try:
|
||||
await self.send_bytes(pcm_bytes)
|
||||
except Exception:
|
||||
log.warning("Failed to send audio, client disconnected.")
|
||||
self.cancel_event.set()
|
||||
break
|
||||
|
||||
tts_thread.join(timeout=2.0)
|
||||
|
||||
# Save only what was actually spoken
|
||||
if spoken_text.strip():
|
||||
self.conversation_history.append(
|
||||
{"role": "assistant", "content": spoken_text.strip()}
|
||||
)
|
||||
elif self.conversation_history and self.conversation_history[-1]["role"] == "user":
|
||||
self.conversation_history.pop()
|
||||
|
||||
if not self.cancel_event.is_set():
|
||||
await self.send_json({"type": "response_text", "text": "", "final": True})
|
||||
|
||||
self.is_responding = False
|
||||
try:
|
||||
await self.send_json({"type": "status", "state": "listening"})
|
||||
except Exception:
|
||||
pass
|
||||
@@ -0,0 +1,38 @@
|
||||
import logging
|
||||
from typing import Iterator
|
||||
|
||||
import numpy as np
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_VOICE = "af_heart"
|
||||
DEFAULT_LANG = "a" # American English
|
||||
|
||||
|
||||
class TTSEngine:
|
||||
"""Wraps Kokoro TTS for fast streaming text-to-speech."""
|
||||
|
||||
def __init__(self):
|
||||
from kokoro import KPipeline
|
||||
|
||||
self.pipeline = KPipeline(lang_code=DEFAULT_LANG)
|
||||
self.voice = DEFAULT_VOICE
|
||||
self.sample_rate = 24000
|
||||
|
||||
def set_voice(self, voice: str, lang_code: str = "a"):
|
||||
"""Change the voice."""
|
||||
from kokoro import KPipeline
|
||||
|
||||
self.voice = voice
|
||||
self.pipeline = KPipeline(lang_code=lang_code)
|
||||
log.info(f"Voice set to: {voice} (lang: {lang_code})")
|
||||
|
||||
def synthesize_stream(self, text: str) -> Iterator[np.ndarray]:
|
||||
"""Yield audio chunks as they are generated.
|
||||
|
||||
Each chunk is a float32 numpy array at self.sample_rate (24kHz).
|
||||
Kokoro internally splits text into sentences and yields per-sentence audio.
|
||||
"""
|
||||
for _gs, _ps, audio in self.pipeline(text, voice=self.voice):
|
||||
if audio is not None and len(audio) > 0:
|
||||
yield audio
|
||||
@@ -0,0 +1,52 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class StreamingVAD:
|
||||
"""Wraps Silero VAD for streaming chunk-by-chunk speech detection."""
|
||||
|
||||
def __init__(self, model, threshold: float = 0.5, min_silence_ms: int = 400):
|
||||
from silero_vad import VADIterator
|
||||
|
||||
self.iterator = VADIterator(
|
||||
model,
|
||||
sampling_rate=16000,
|
||||
threshold=threshold,
|
||||
min_silence_duration_ms=min_silence_ms,
|
||||
)
|
||||
self.audio_buffer: list[np.ndarray] = []
|
||||
self.is_speaking = False
|
||||
|
||||
def process_chunk(self, chunk_16k: np.ndarray) -> np.ndarray | None:
|
||||
"""Feed a 512-sample chunk at 16kHz.
|
||||
|
||||
Returns the complete utterance as a numpy array when speech ends,
|
||||
or None if still accumulating.
|
||||
"""
|
||||
tensor = torch.from_numpy(chunk_16k).float()
|
||||
speech_dict = self.iterator(tensor, return_seconds=False)
|
||||
|
||||
if speech_dict:
|
||||
if "start" in speech_dict:
|
||||
self.is_speaking = True
|
||||
self.audio_buffer = []
|
||||
if "end" in speech_dict:
|
||||
self.is_speaking = False
|
||||
if self.audio_buffer:
|
||||
result = np.concatenate(self.audio_buffer)
|
||||
self.audio_buffer = []
|
||||
self.iterator.reset_states()
|
||||
return result
|
||||
self.iterator.reset_states()
|
||||
return None
|
||||
|
||||
if self.is_speaking:
|
||||
self.audio_buffer.append(chunk_16k.copy())
|
||||
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
"""Reset VAD state for a new conversation turn."""
|
||||
self.audio_buffer = []
|
||||
self.is_speaking = False
|
||||
self.iterator.reset_states()
|
||||
+305
@@ -0,0 +1,305 @@
|
||||
// --- State ---
|
||||
let ws = null;
|
||||
let audioCtx = null;
|
||||
let micStream = null;
|
||||
let workletNode = null;
|
||||
let micActive = false;
|
||||
let nextPlayTime = 0;
|
||||
let isPlaying = false;
|
||||
|
||||
const PLAYBACK_SR = 24000; // TTS output sample rate
|
||||
const MIC_SR = 16000;
|
||||
const BARGE_IN_THRESHOLD = 0.03; // RMS energy threshold for barge-in
|
||||
const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger
|
||||
let bargeInCount = 0;
|
||||
|
||||
const chatArea = document.getElementById("chat-area");
|
||||
const statusBadge = document.getElementById("status-badge");
|
||||
const micBtn = document.getElementById("mic-btn");
|
||||
|
||||
// --- WebSocket ---
|
||||
|
||||
function connectWS() {
|
||||
const proto = location.protocol === "https:" ? "wss:" : "ws:";
|
||||
ws = new WebSocket(`${proto}//${location.host}/ws/chat`);
|
||||
ws.binaryType = "arraybuffer";
|
||||
|
||||
ws.onopen = () => {
|
||||
setStatus("listening");
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
setStatus("disconnected");
|
||||
setTimeout(connectWS, 2000);
|
||||
};
|
||||
|
||||
ws.onerror = () => {
|
||||
ws.close();
|
||||
};
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
if (event.data instanceof ArrayBuffer) {
|
||||
playAudioChunk(event.data);
|
||||
} else {
|
||||
handleJSON(JSON.parse(event.data));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function handleJSON(msg) {
|
||||
switch (msg.type) {
|
||||
case "status":
|
||||
setStatus(msg.state);
|
||||
break;
|
||||
|
||||
case "interrupt":
|
||||
stopPlayback();
|
||||
// Trim the assistant message to what was spoken, then finalize
|
||||
finalizeAssistantMessage();
|
||||
break;
|
||||
|
||||
case "transcript":
|
||||
addMessage("user", msg.text);
|
||||
break;
|
||||
|
||||
case "response_text":
|
||||
if (msg.final) {
|
||||
finalizeAssistantMessage();
|
||||
} else {
|
||||
appendAssistantText(msg.text);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Status ---
|
||||
|
||||
function setStatus(state) {
|
||||
statusBadge.textContent =
|
||||
state === "listening"
|
||||
? "Listening"
|
||||
: state === "thinking"
|
||||
? "Thinking..."
|
||||
: state === "speaking"
|
||||
? "Speaking"
|
||||
: state === "disconnected"
|
||||
? "Disconnected"
|
||||
: state;
|
||||
statusBadge.className = state;
|
||||
}
|
||||
|
||||
// --- Chat Messages ---
|
||||
|
||||
let currentAssistantEl = null;
|
||||
let currentAssistantText = "";
|
||||
|
||||
function addMessage(role, text) {
|
||||
const el = document.createElement("div");
|
||||
el.className = `message ${role}`;
|
||||
el.textContent = text;
|
||||
chatArea.appendChild(el);
|
||||
chatArea.scrollTop = chatArea.scrollHeight;
|
||||
}
|
||||
|
||||
function appendAssistantText(text) {
|
||||
if (!currentAssistantEl) {
|
||||
currentAssistantEl = document.createElement("div");
|
||||
currentAssistantEl.className = "message assistant";
|
||||
chatArea.appendChild(currentAssistantEl);
|
||||
currentAssistantText = "";
|
||||
}
|
||||
currentAssistantText += (currentAssistantText ? " " : "") + text;
|
||||
currentAssistantEl.textContent = currentAssistantText;
|
||||
chatArea.scrollTop = chatArea.scrollHeight;
|
||||
}
|
||||
|
||||
function finalizeAssistantMessage() {
|
||||
currentAssistantEl = null;
|
||||
currentAssistantText = "";
|
||||
}
|
||||
|
||||
// --- Audio Playback ---
|
||||
|
||||
let activeSources = [];
|
||||
|
||||
function getPlaybackCtx() {
|
||||
if (!audioCtx || audioCtx.state === "closed") {
|
||||
audioCtx = new AudioContext({ sampleRate: PLAYBACK_SR });
|
||||
}
|
||||
return audioCtx;
|
||||
}
|
||||
|
||||
function playAudioChunk(arrayBuffer) {
|
||||
const ctx = getPlaybackCtx();
|
||||
const int16 = new Int16Array(arrayBuffer);
|
||||
const float32 = new Float32Array(int16.length);
|
||||
for (let i = 0; i < int16.length; i++) {
|
||||
float32[i] = int16[i] / 32768;
|
||||
}
|
||||
|
||||
const buffer = ctx.createBuffer(1, float32.length, PLAYBACK_SR);
|
||||
buffer.getChannelData(0).set(float32);
|
||||
|
||||
const source = ctx.createBufferSource();
|
||||
source.buffer = buffer;
|
||||
source.connect(ctx.destination);
|
||||
|
||||
activeSources.push(source);
|
||||
isPlaying = true;
|
||||
source.onended = () => {
|
||||
activeSources = activeSources.filter((s) => s !== source);
|
||||
if (activeSources.length === 0) {
|
||||
isPlaying = false;
|
||||
bargeInCount = 0;
|
||||
}
|
||||
};
|
||||
|
||||
const now = ctx.currentTime;
|
||||
if (nextPlayTime < now) {
|
||||
nextPlayTime = now + 0.01;
|
||||
}
|
||||
source.start(nextPlayTime);
|
||||
nextPlayTime += buffer.duration;
|
||||
}
|
||||
|
||||
function stopPlayback() {
|
||||
for (const source of activeSources) {
|
||||
try {
|
||||
source.stop();
|
||||
} catch (_) {}
|
||||
}
|
||||
activeSources = [];
|
||||
nextPlayTime = 0;
|
||||
isPlaying = false;
|
||||
bargeInCount = 0;
|
||||
}
|
||||
|
||||
// --- Microphone ---
|
||||
|
||||
async function toggleMic() {
|
||||
if (micActive) {
|
||||
stopMic();
|
||||
} else {
|
||||
await startMic();
|
||||
}
|
||||
}
|
||||
|
||||
async function startMic() {
|
||||
try {
|
||||
// Ensure playback context exists (needed for user gesture)
|
||||
getPlaybackCtx();
|
||||
if (audioCtx.state === "suspended") {
|
||||
await audioCtx.resume();
|
||||
}
|
||||
|
||||
micStream = await navigator.mediaDevices.getUserMedia({
|
||||
audio: {
|
||||
sampleRate: MIC_SR,
|
||||
channelCount: 1,
|
||||
echoCancellation: true,
|
||||
noiseSuppression: true,
|
||||
autoGainControl: true,
|
||||
},
|
||||
});
|
||||
|
||||
// Create a separate context at 16kHz for mic capture
|
||||
const micCtx = new AudioContext({ sampleRate: MIC_SR });
|
||||
const source = micCtx.createMediaStreamSource(micStream);
|
||||
|
||||
await micCtx.audioWorklet.addModule("/static/processor.js");
|
||||
workletNode = new AudioWorkletNode(micCtx, "pcm-processor");
|
||||
source.connect(workletNode);
|
||||
|
||||
workletNode.port.onmessage = (e) => {
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(e.data);
|
||||
|
||||
// Client-side barge-in: detect mic energy while playing
|
||||
if (isPlaying) {
|
||||
const samples = new Int16Array(e.data);
|
||||
let sum = 0;
|
||||
for (let i = 0; i < samples.length; i++) {
|
||||
const s = samples[i] / 32768;
|
||||
sum += s * s;
|
||||
}
|
||||
const rms = Math.sqrt(sum / samples.length);
|
||||
|
||||
if (rms > BARGE_IN_THRESHOLD) {
|
||||
bargeInCount++;
|
||||
if (bargeInCount >= BARGE_IN_FRAMES) {
|
||||
// User is speaking over the assistant - interrupt
|
||||
stopPlayback();
|
||||
finalizeAssistantMessage();
|
||||
ws.send(JSON.stringify({ type: "interrupt" }));
|
||||
isPlaying = false;
|
||||
bargeInCount = 0;
|
||||
}
|
||||
} else {
|
||||
bargeInCount = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Store for cleanup
|
||||
workletNode._micCtx = micCtx;
|
||||
|
||||
micActive = true;
|
||||
micBtn.classList.add("active");
|
||||
|
||||
// Connect WebSocket if not already
|
||||
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||
connectWS();
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Mic access failed:", err);
|
||||
alert("Could not access microphone. Please allow mic permissions.");
|
||||
}
|
||||
}
|
||||
|
||||
function stopMic() {
|
||||
if (workletNode) {
|
||||
workletNode.disconnect();
|
||||
if (workletNode._micCtx) {
|
||||
workletNode._micCtx.close();
|
||||
}
|
||||
workletNode = null;
|
||||
}
|
||||
if (micStream) {
|
||||
micStream.getTracks().forEach((t) => t.stop());
|
||||
micStream = null;
|
||||
}
|
||||
micActive = false;
|
||||
micBtn.classList.remove("active");
|
||||
}
|
||||
|
||||
// --- Voice Selection ---
|
||||
|
||||
async function applyVoice() {
|
||||
const voice = document.getElementById("voice-select").value;
|
||||
const statusEl = document.getElementById("voice-status");
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append("voice", voice);
|
||||
formData.append("lang", "a");
|
||||
|
||||
statusEl.textContent = "Applying...";
|
||||
try {
|
||||
const resp = await fetch("/api/set-voice", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
});
|
||||
const data = await resp.json();
|
||||
if (data.status === "ok") {
|
||||
statusEl.textContent = "Voice: " + voice;
|
||||
} else {
|
||||
statusEl.textContent = "Failed.";
|
||||
}
|
||||
} catch (err) {
|
||||
statusEl.textContent = "Error: " + err.message;
|
||||
}
|
||||
}
|
||||
|
||||
// Expose to HTML onclick
|
||||
window.toggleMic = toggleMic;
|
||||
window.applyVoice = applyVoice;
|
||||
@@ -0,0 +1,49 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Voice Chat</title>
|
||||
<link rel="stylesheet" href="/static/style.css" />
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>Voice Chat</h1>
|
||||
<span id="status-badge">Disconnected</span>
|
||||
</header>
|
||||
|
||||
<div id="chat-area"></div>
|
||||
|
||||
<details id="voice-panel">
|
||||
<summary>Voice Settings</summary>
|
||||
<div class="panel-content">
|
||||
<label>
|
||||
Voice
|
||||
<select id="voice-select">
|
||||
<optgroup label="Female">
|
||||
<option value="af_heart" selected>Heart</option>
|
||||
<option value="af_nicole">Nicole</option>
|
||||
<option value="af_bella">Bella</option>
|
||||
<option value="af_sarah">Sarah</option>
|
||||
<option value="af_nova">Nova</option>
|
||||
<option value="af_jessica">Jessica</option>
|
||||
<option value="af_river">River</option>
|
||||
</optgroup>
|
||||
<optgroup label="Male">
|
||||
<option value="am_adam">Adam</option>
|
||||
<option value="am_michael">Michael</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
</label>
|
||||
<button id="apply-voice-btn" onclick="applyVoice()">Apply</button>
|
||||
<span id="voice-status"></span>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<div id="controls">
|
||||
<button id="mic-btn" onclick="toggleMic()">🎤</button>
|
||||
</div>
|
||||
|
||||
<script src="/static/app.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,42 @@
|
||||
/**
|
||||
* AudioWorkletProcessor that collects 512-sample chunks of PCM audio
|
||||
* and posts them to the main thread for WebSocket transmission.
|
||||
*/
|
||||
class PCMProcessor extends AudioWorkletProcessor {
|
||||
constructor() {
|
||||
super();
|
||||
this.buffer = new Float32Array(0);
|
||||
this.chunkSize = 512; // 512 samples at 16kHz = 32ms
|
||||
}
|
||||
|
||||
process(inputs) {
|
||||
const input = inputs[0];
|
||||
if (!input || !input[0]) return true;
|
||||
|
||||
const channelData = input[0]; // mono
|
||||
|
||||
// Append to buffer
|
||||
const newBuffer = new Float32Array(this.buffer.length + channelData.length);
|
||||
newBuffer.set(this.buffer);
|
||||
newBuffer.set(channelData, this.buffer.length);
|
||||
this.buffer = newBuffer;
|
||||
|
||||
// Send complete chunks
|
||||
while (this.buffer.length >= this.chunkSize) {
|
||||
const chunk = this.buffer.slice(0, this.chunkSize);
|
||||
this.buffer = this.buffer.slice(this.chunkSize);
|
||||
|
||||
// Convert float32 to int16 for transmission
|
||||
const int16 = new Int16Array(chunk.length);
|
||||
for (let i = 0; i < chunk.length; i++) {
|
||||
const s = Math.max(-1, Math.min(1, chunk[i]));
|
||||
int16[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
|
||||
}
|
||||
this.port.postMessage(int16.buffer, [int16.buffer]);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
registerProcessor("pcm-processor", PCMProcessor);
|
||||
@@ -0,0 +1,185 @@
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
background: #0f0f0f;
|
||||
color: #e0e0e0;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
header {
|
||||
padding: 16px 24px;
|
||||
border-bottom: 1px solid #222;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
header h1 {
|
||||
font-size: 18px;
|
||||
font-weight: 600;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#status-badge {
|
||||
padding: 4px 12px;
|
||||
border-radius: 12px;
|
||||
font-size: 13px;
|
||||
font-weight: 500;
|
||||
background: #1a1a2e;
|
||||
color: #888;
|
||||
transition: all 0.3s;
|
||||
}
|
||||
|
||||
#status-badge.listening {
|
||||
background: #0a2a1a;
|
||||
color: #4ade80;
|
||||
}
|
||||
|
||||
#status-badge.thinking {
|
||||
background: #2a1a0a;
|
||||
color: #fbbf24;
|
||||
}
|
||||
|
||||
#status-badge.speaking {
|
||||
background: #1a0a2a;
|
||||
color: #a78bfa;
|
||||
}
|
||||
|
||||
#chat-area {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 24px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.message {
|
||||
max-width: 70%;
|
||||
padding: 10px 16px;
|
||||
border-radius: 16px;
|
||||
font-size: 15px;
|
||||
line-height: 1.5;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
|
||||
.message.user {
|
||||
align-self: flex-end;
|
||||
background: #1d4ed8;
|
||||
color: #fff;
|
||||
border-bottom-right-radius: 4px;
|
||||
}
|
||||
|
||||
.message.assistant {
|
||||
align-self: flex-start;
|
||||
background: #1e1e1e;
|
||||
color: #e0e0e0;
|
||||
border-bottom-left-radius: 4px;
|
||||
}
|
||||
|
||||
#controls {
|
||||
padding: 16px 24px;
|
||||
border-top: 1px solid #222;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
#mic-btn {
|
||||
width: 56px;
|
||||
height: 56px;
|
||||
border-radius: 50%;
|
||||
border: 2px solid #333;
|
||||
background: #1a1a1a;
|
||||
color: #e0e0e0;
|
||||
font-size: 24px;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
#mic-btn:hover {
|
||||
border-color: #555;
|
||||
background: #222;
|
||||
}
|
||||
|
||||
#mic-btn.active {
|
||||
border-color: #ef4444;
|
||||
background: #2a0a0a;
|
||||
color: #ef4444;
|
||||
animation: pulse 1.5s infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% { box-shadow: 0 0 0 0 rgba(239, 68, 68, 0.3); }
|
||||
50% { box-shadow: 0 0 0 12px rgba(239, 68, 68, 0); }
|
||||
}
|
||||
|
||||
/* Voice clone panel */
|
||||
#voice-panel {
|
||||
padding: 12px 24px;
|
||||
border-top: 1px solid #222;
|
||||
background: #0a0a0a;
|
||||
}
|
||||
|
||||
#voice-panel summary {
|
||||
cursor: pointer;
|
||||
font-size: 13px;
|
||||
color: #888;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
#voice-panel .panel-content {
|
||||
margin-top: 12px;
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
align-items: flex-end;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
#voice-panel label {
|
||||
font-size: 13px;
|
||||
color: #aaa;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
#voice-panel input[type="file"],
|
||||
#voice-panel input[type="text"] {
|
||||
background: #1a1a1a;
|
||||
border: 1px solid #333;
|
||||
border-radius: 6px;
|
||||
padding: 6px 10px;
|
||||
color: #e0e0e0;
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
#upload-btn {
|
||||
padding: 6px 16px;
|
||||
border-radius: 6px;
|
||||
border: 1px solid #333;
|
||||
background: #1a1a1a;
|
||||
color: #e0e0e0;
|
||||
font-size: 13px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
#upload-btn:hover {
|
||||
background: #222;
|
||||
}
|
||||
|
||||
#upload-status {
|
||||
font-size: 12px;
|
||||
color: #888;
|
||||
margin-left: 8px;
|
||||
}
|
||||
Reference in New Issue
Block a user