commit ce41bca42226e83174d7998a88720077707913c8 Author: Brian Date: Tue Apr 7 03:58:35 2026 -0400 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0e5ac79 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.venv +__pycache__ \ No newline at end of file diff --git a/reference_audio/JARVIS II.mp3 b/reference_audio/JARVIS II.mp3 new file mode 100644 index 0000000..cd445e2 Binary files /dev/null and b/reference_audio/JARVIS II.mp3 differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..bc1e861 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +torch>=2.5.0 +transformers==4.57.6 +silero-vad>=5.1 +fastapi>=0.115.0 +uvicorn[standard]>=0.30.0 +numpy +soundfile +scipy +python-multipart diff --git a/run.py b/run.py new file mode 100644 index 0000000..9092129 --- /dev/null +++ b/run.py @@ -0,0 +1,10 @@ +import uvicorn + +if __name__ == "__main__": + uvicorn.run( + "server.main:app", + host="0.0.0.0", + port=8000, + reload=False, + log_level="info", + ) diff --git a/server/__init__.py b/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/asr.py b/server/asr.py new file mode 100644 index 0000000..737a74f --- /dev/null +++ b/server/asr.py @@ -0,0 +1,25 @@ +import numpy as np + + +class ASREngine: + """Wraps Qwen3-ASR for speech-to-text transcription.""" + + def __init__(self, model): + self.model = model + + def transcribe(self, audio_16k: np.ndarray) -> str: + """Transcribe a complete utterance. + + Args: + audio_16k: Float32 numpy array at 16kHz sample rate. + + Returns: + Transcribed text string. + """ + results = self.model.transcribe( + audio=(audio_16k, 16000), + language=None, # auto-detect + ) + if results and results[0].text: + return results[0].text.strip() + return "" diff --git a/server/audio_utils.py b/server/audio_utils.py new file mode 100644 index 0000000..bba0739 --- /dev/null +++ b/server/audio_utils.py @@ -0,0 +1,63 @@ +import numpy as np +from scipy.signal import resample_poly +from math import gcd + + +def pcm_bytes_to_float32(pcm_bytes: bytes, dtype=np.int16) -> np.ndarray: + """Convert raw PCM bytes (16-bit signed int) to float32 in [-1, 1].""" + audio = np.frombuffer(pcm_bytes, dtype=dtype) + return audio.astype(np.float32) / 32768.0 + + +def float32_to_pcm_bytes(audio) -> bytes: + """Convert float32 audio in [-1, 1] to 16-bit PCM bytes. + + Accepts numpy arrays or PyTorch tensors. + """ + if not isinstance(audio, np.ndarray): + audio = audio.detach().cpu().numpy() + clamped = np.clip(audio, -1.0, 1.0) + return (clamped * 32767).astype(np.int16).tobytes() + + +def resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray: + """Resample audio from orig_sr to target_sr using polyphase filtering.""" + if orig_sr == target_sr: + return audio + divisor = gcd(orig_sr, target_sr) + up = target_sr // divisor + down = orig_sr // divisor + return resample_poly(audio, up, down).astype(audio.dtype) + + +def split_sentences(text: str) -> tuple[list[str], str]: + """Split text into completed sentences and a remaining buffer. + + Returns (sentences, remaining_buffer). + Splits on sentence-ending punctuation followed by whitespace. + """ + sentences = [] + buffer = text + terminators = ".!?" + + i = 0 + start = 0 + while i < len(buffer): + if buffer[i] in terminators: + # Look ahead for whitespace or end of string + end = i + 1 + while end < len(buffer) and buffer[end] in terminators: + end += 1 + if end >= len(buffer) or buffer[end] == " " or buffer[end] == "\n": + sentence = buffer[start:end].strip() + if sentence: + sentences.append(sentence) + start = end + i = end + else: + i += 1 + else: + i += 1 + + remaining = buffer[start:].strip() + return sentences, remaining diff --git a/server/llm.py b/server/llm.py new file mode 100644 index 0000000..d6fd2fc --- /dev/null +++ b/server/llm.py @@ -0,0 +1,83 @@ +import logging +import threading +from typing import AsyncIterator + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from server.audio_utils import split_sentences + +log = logging.getLogger(__name__) + + +class LLMEngine: + """Wraps Qwen3 for conversation generation.""" + + SYSTEM_PROMPT = ( + "You are a helpful voice assistant. Keep your responses concise and natural " + "for spoken conversation. Respond in 1-3 short sentences. " + "Do not use markdown, bullet points, code blocks, emojis, or any " + "formatting that doesn't work in speech." + ) + + def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer): + self.model = model + self.tokenizer = tokenizer + + def _build_inputs(self, messages: list[dict]): + """Build input token ids using the model's chat template.""" + chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}] + for msg in messages: + chat_messages.append({"role": msg["role"], "content": msg["content"]}) + + text = self.tokenizer.apply_chat_template( + chat_messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + return self.tokenizer(text, return_tensors="pt").to(self.model.device) + + def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str: + """Generate a complete response (blocking).""" + inputs = self._build_inputs(messages) + input_len = inputs["input_ids"].shape[1] + + with torch.no_grad(): + output_ids = self.model.generate( + **inputs, + max_new_tokens=max_new_tokens, + temperature=0.7, + top_p=0.9, + do_sample=True, + repetition_penalty=1.2, + ) + + # Decode only the generated tokens (skip prompt) + new_ids = output_ids[0][input_len:] + response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip() + log.info(f"LLM response: {response}") + return response + + async def generate_sentences( + self, + messages: list[dict], + cancel_event: threading.Event | None = None, + ) -> AsyncIterator[str]: + """Generate response and yield it sentence by sentence for TTS pipelining.""" + import asyncio + + response = await asyncio.to_thread(self.generate, messages) + + if cancel_event and cancel_event.is_set(): + return + + # Split into sentences and yield each + sentences, remainder = split_sentences(response) + for sentence in sentences: + if cancel_event and cancel_event.is_set(): + return + yield sentence + + if remainder: + yield remainder diff --git a/server/main.py b/server/main.py new file mode 100644 index 0000000..6d3fdb3 --- /dev/null +++ b/server/main.py @@ -0,0 +1,87 @@ +import json +import logging +import os +from contextlib import asynccontextmanager + +import numpy as np +from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect +from fastapi.params import Form +from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles + +from server.audio_utils import pcm_bytes_to_float32 +from server.models import ModelManager +from server.pipeline import ConversationSession + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +log = logging.getLogger(__name__) + +REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio") +STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static") + +model_mgr = ModelManager() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + log.info("Starting model loading...") + model_mgr.load_all() + log.info("Server ready.") + yield + log.info("Shutting down.") + + +app = FastAPI(lifespan=lifespan) +app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") + + +@app.get("/") +async def index(): + return FileResponse(os.path.join(STATIC_DIR, "index.html")) + + +@app.post("/api/set-voice") +async def set_voice(voice: str = Form(...), lang: str = Form("a")): + """Change the TTS voice.""" + model_mgr.tts_engine.set_voice(voice, lang) + return {"status": "ok", "voice": voice} + + +@app.websocket("/ws/chat") +async def websocket_chat(ws: WebSocket): + await ws.accept() + log.info("WebSocket client connected.") + + async def send_json(data: dict): + await ws.send_text(json.dumps(data)) + + async def send_bytes(data: bytes): + await ws.send_bytes(data) + + session = ConversationSession(model_mgr, send_json, send_bytes) + await session.start() + + try: + while True: + message = await ws.receive() + + if "bytes" in message: + pcm_data = message["bytes"] + chunk = pcm_bytes_to_float32(pcm_data) + await session.handle_audio_chunk(chunk) + + elif "text" in message: + try: + msg = json.loads(message["text"]) + except json.JSONDecodeError: + continue + + if msg.get("type") == "interrupt": + await session.interrupt() + + except WebSocketDisconnect: + log.info("WebSocket client disconnected.") + except Exception: + log.exception("WebSocket error") + finally: + await session.stop() diff --git a/server/models.py b/server/models.py new file mode 100644 index 0000000..e02d521 --- /dev/null +++ b/server/models.py @@ -0,0 +1,70 @@ +import logging +import torch + +from server.vad import StreamingVAD +from server.asr import ASREngine +from server.llm import LLMEngine +from server.tts import TTSEngine + +log = logging.getLogger(__name__) + + +class ModelManager: + """Loads and holds all models. Initialized once at server startup.""" + + def __init__(self): + self.vad_model = None + self.asr_engine: ASREngine | None = None + self.llm_engine: LLMEngine | None = None + self.tts_engine: TTSEngine | None = None + + def load_all(self): + """Load all models sequentially. Call from the main process.""" + self._load_vad() + self._load_asr() + self._load_llm() + self._load_tts() + log.info("All models loaded successfully.") + + def _load_vad(self): + log.info("Loading Silero VAD...") + from silero_vad import load_silero_vad + + self.vad_model = load_silero_vad() + log.info("Silero VAD loaded (CPU).") + + def _load_asr(self): + log.info("Loading Qwen3-ASR-0.6B (transformers backend)...") + from qwen_asr import Qwen3ASRModel + + asr_model = Qwen3ASRModel.from_pretrained( + "Qwen/Qwen3-ASR-0.6B", + dtype=torch.bfloat16, + device_map="cuda:0", + max_new_tokens=4096, + ) + self.asr_engine = ASREngine(asr_model) + log.info("Qwen3-ASR-0.6B loaded.") + + def _load_llm(self): + log.info("Loading Qwen3-0.6B-Instruct...") + from transformers import AutoModelForCausalLM, AutoTokenizer + + model_name = "Qwen/Qwen3-0.6B" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map="cuda:0", + ) + self.llm_engine = LLMEngine(model, tokenizer) + log.info("Qwen3-0.6B-Instruct loaded.") + + def _load_tts(self): + log.info("Loading Kokoro TTS...") + self.tts_engine = TTSEngine() + log.info("Kokoro TTS loaded.") + + def create_vad(self) -> StreamingVAD: + """Create a new StreamingVAD instance for a client session.""" + return StreamingVAD(self.vad_model) diff --git a/server/pipeline.py b/server/pipeline.py new file mode 100644 index 0000000..064d3ac --- /dev/null +++ b/server/pipeline.py @@ -0,0 +1,164 @@ +import asyncio +import logging +import queue +import threading + +import numpy as np + +from server.audio_utils import float32_to_pcm_bytes +from server.models import ModelManager +from server.vad import StreamingVAD + +log = logging.getLogger(__name__) + +_SENTINEL = None + + +class ConversationSession: + """Manages a single client's voice conversation pipeline. + + Orchestrates: VAD -> ASR -> LLM -> TTS streaming with barge-in support. + """ + + def __init__(self, models: ModelManager, send_json, send_bytes): + self.models = models + self.send_json = send_json + self.send_bytes = send_bytes + + self.vad: StreamingVAD = models.create_vad() + self.conversation_history: list[dict] = [] + self.cancel_event = threading.Event() + self.is_responding = False + self._response_task: asyncio.Task | None = None + + async def start(self): + await self.send_json({"type": "status", "state": "listening"}) + + async def stop(self): + self.cancel_event.set() + if self._response_task and not self._response_task.done(): + self._response_task.cancel() + + async def handle_audio_chunk(self, chunk_16k: np.ndarray): + utterance = self.vad.process_chunk(chunk_16k) + + if utterance is not None: + if self.is_responding: + await self._interrupt() + # Launch response pipeline as a background task so we don't block receives + self._response_task = asyncio.create_task(self._process_utterance(utterance)) + elif self.vad.is_speaking and self.is_responding: + await self._interrupt() + + async def interrupt(self): + """Public interrupt method for WebSocket text messages.""" + if self.is_responding: + await self._interrupt() + + async def _interrupt(self): + log.info("Barge-in: cancelling response.") + self.cancel_event.set() + self.is_responding = False + # Tell client to stop audio immediately + try: + await self.send_json({"type": "interrupt"}) + except Exception: + pass + + async def _process_utterance(self, audio_16k: np.ndarray): + """Full pipeline: ASR -> LLM -> TTS streaming.""" + self.is_responding = True + self.cancel_event.clear() + + # ASR + await self.send_json({"type": "status", "state": "thinking"}) + text = await asyncio.to_thread(self.models.asr_engine.transcribe, audio_16k) + + if not text: + log.info("ASR returned empty text, resuming listening.") + self.is_responding = False + await self.send_json({"type": "status", "state": "listening"}) + return + + await self.send_json({"type": "transcript", "text": text, "final": True}) + log.info(f"User: {text}") + self.conversation_history.append({"role": "user", "content": text}) + + if self.cancel_event.is_set(): + self.is_responding = False + return + + # LLM + log.info(f"Conversation history ({len(self.conversation_history)} messages): " + + str([m['content'][:50] for m in self.conversation_history])) + response = await asyncio.to_thread( + self.models.llm_engine.generate, self.conversation_history + ) + + if self.cancel_event.is_set(): + self.is_responding = False + return + + # TTS - stream chunks with per-sentence text + await self.send_json({"type": "status", "state": "speaking"}) + chunk_queue = queue.Queue() + + def _tts_worker(): + try: + for graphemes, _ps, audio in self.models.tts_engine.pipeline( + response, voice=self.models.tts_engine.voice + ): + if self.cancel_event.is_set(): + break + if audio is not None and len(audio) > 0: + chunk_queue.put((graphemes, audio)) + except Exception: + log.exception("TTS generation error") + finally: + chunk_queue.put(_SENTINEL) + + tts_thread = threading.Thread(target=_tts_worker, daemon=True) + tts_thread.start() + + spoken_text = "" + while True: + try: + item = await asyncio.to_thread(chunk_queue.get, timeout=10.0) + except Exception: + break + + if item is _SENTINEL: + break + if self.cancel_event.is_set(): + break + + sentence_text, audio = item + spoken_text += sentence_text + + await self.send_json({"type": "response_text", "text": sentence_text, "final": False}) + pcm_bytes = float32_to_pcm_bytes(audio) + try: + await self.send_bytes(pcm_bytes) + except Exception: + log.warning("Failed to send audio, client disconnected.") + self.cancel_event.set() + break + + tts_thread.join(timeout=2.0) + + # Save only what was actually spoken + if spoken_text.strip(): + self.conversation_history.append( + {"role": "assistant", "content": spoken_text.strip()} + ) + elif self.conversation_history and self.conversation_history[-1]["role"] == "user": + self.conversation_history.pop() + + if not self.cancel_event.is_set(): + await self.send_json({"type": "response_text", "text": "", "final": True}) + + self.is_responding = False + try: + await self.send_json({"type": "status", "state": "listening"}) + except Exception: + pass diff --git a/server/tts.py b/server/tts.py new file mode 100644 index 0000000..5996dda --- /dev/null +++ b/server/tts.py @@ -0,0 +1,38 @@ +import logging +from typing import Iterator + +import numpy as np + +log = logging.getLogger(__name__) + +DEFAULT_VOICE = "af_heart" +DEFAULT_LANG = "a" # American English + + +class TTSEngine: + """Wraps Kokoro TTS for fast streaming text-to-speech.""" + + def __init__(self): + from kokoro import KPipeline + + self.pipeline = KPipeline(lang_code=DEFAULT_LANG) + self.voice = DEFAULT_VOICE + self.sample_rate = 24000 + + def set_voice(self, voice: str, lang_code: str = "a"): + """Change the voice.""" + from kokoro import KPipeline + + self.voice = voice + self.pipeline = KPipeline(lang_code=lang_code) + log.info(f"Voice set to: {voice} (lang: {lang_code})") + + def synthesize_stream(self, text: str) -> Iterator[np.ndarray]: + """Yield audio chunks as they are generated. + + Each chunk is a float32 numpy array at self.sample_rate (24kHz). + Kokoro internally splits text into sentences and yields per-sentence audio. + """ + for _gs, _ps, audio in self.pipeline(text, voice=self.voice): + if audio is not None and len(audio) > 0: + yield audio diff --git a/server/vad.py b/server/vad.py new file mode 100644 index 0000000..887f94c --- /dev/null +++ b/server/vad.py @@ -0,0 +1,52 @@ +import numpy as np +import torch + + +class StreamingVAD: + """Wraps Silero VAD for streaming chunk-by-chunk speech detection.""" + + def __init__(self, model, threshold: float = 0.5, min_silence_ms: int = 400): + from silero_vad import VADIterator + + self.iterator = VADIterator( + model, + sampling_rate=16000, + threshold=threshold, + min_silence_duration_ms=min_silence_ms, + ) + self.audio_buffer: list[np.ndarray] = [] + self.is_speaking = False + + def process_chunk(self, chunk_16k: np.ndarray) -> np.ndarray | None: + """Feed a 512-sample chunk at 16kHz. + + Returns the complete utterance as a numpy array when speech ends, + or None if still accumulating. + """ + tensor = torch.from_numpy(chunk_16k).float() + speech_dict = self.iterator(tensor, return_seconds=False) + + if speech_dict: + if "start" in speech_dict: + self.is_speaking = True + self.audio_buffer = [] + if "end" in speech_dict: + self.is_speaking = False + if self.audio_buffer: + result = np.concatenate(self.audio_buffer) + self.audio_buffer = [] + self.iterator.reset_states() + return result + self.iterator.reset_states() + return None + + if self.is_speaking: + self.audio_buffer.append(chunk_16k.copy()) + + return None + + def reset(self): + """Reset VAD state for a new conversation turn.""" + self.audio_buffer = [] + self.is_speaking = False + self.iterator.reset_states() diff --git a/static/app.js b/static/app.js new file mode 100644 index 0000000..f234d8c --- /dev/null +++ b/static/app.js @@ -0,0 +1,305 @@ +// --- State --- +let ws = null; +let audioCtx = null; +let micStream = null; +let workletNode = null; +let micActive = false; +let nextPlayTime = 0; +let isPlaying = false; + +const PLAYBACK_SR = 24000; // TTS output sample rate +const MIC_SR = 16000; +const BARGE_IN_THRESHOLD = 0.03; // RMS energy threshold for barge-in +const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger +let bargeInCount = 0; + +const chatArea = document.getElementById("chat-area"); +const statusBadge = document.getElementById("status-badge"); +const micBtn = document.getElementById("mic-btn"); + +// --- WebSocket --- + +function connectWS() { + const proto = location.protocol === "https:" ? "wss:" : "ws:"; + ws = new WebSocket(`${proto}//${location.host}/ws/chat`); + ws.binaryType = "arraybuffer"; + + ws.onopen = () => { + setStatus("listening"); + }; + + ws.onclose = () => { + setStatus("disconnected"); + setTimeout(connectWS, 2000); + }; + + ws.onerror = () => { + ws.close(); + }; + + ws.onmessage = (event) => { + if (event.data instanceof ArrayBuffer) { + playAudioChunk(event.data); + } else { + handleJSON(JSON.parse(event.data)); + } + }; +} + +function handleJSON(msg) { + switch (msg.type) { + case "status": + setStatus(msg.state); + break; + + case "interrupt": + stopPlayback(); + // Trim the assistant message to what was spoken, then finalize + finalizeAssistantMessage(); + break; + + case "transcript": + addMessage("user", msg.text); + break; + + case "response_text": + if (msg.final) { + finalizeAssistantMessage(); + } else { + appendAssistantText(msg.text); + } + break; + } +} + +// --- Status --- + +function setStatus(state) { + statusBadge.textContent = + state === "listening" + ? "Listening" + : state === "thinking" + ? "Thinking..." + : state === "speaking" + ? "Speaking" + : state === "disconnected" + ? "Disconnected" + : state; + statusBadge.className = state; +} + +// --- Chat Messages --- + +let currentAssistantEl = null; +let currentAssistantText = ""; + +function addMessage(role, text) { + const el = document.createElement("div"); + el.className = `message ${role}`; + el.textContent = text; + chatArea.appendChild(el); + chatArea.scrollTop = chatArea.scrollHeight; +} + +function appendAssistantText(text) { + if (!currentAssistantEl) { + currentAssistantEl = document.createElement("div"); + currentAssistantEl.className = "message assistant"; + chatArea.appendChild(currentAssistantEl); + currentAssistantText = ""; + } + currentAssistantText += (currentAssistantText ? " " : "") + text; + currentAssistantEl.textContent = currentAssistantText; + chatArea.scrollTop = chatArea.scrollHeight; +} + +function finalizeAssistantMessage() { + currentAssistantEl = null; + currentAssistantText = ""; +} + +// --- Audio Playback --- + +let activeSources = []; + +function getPlaybackCtx() { + if (!audioCtx || audioCtx.state === "closed") { + audioCtx = new AudioContext({ sampleRate: PLAYBACK_SR }); + } + return audioCtx; +} + +function playAudioChunk(arrayBuffer) { + const ctx = getPlaybackCtx(); + const int16 = new Int16Array(arrayBuffer); + const float32 = new Float32Array(int16.length); + for (let i = 0; i < int16.length; i++) { + float32[i] = int16[i] / 32768; + } + + const buffer = ctx.createBuffer(1, float32.length, PLAYBACK_SR); + buffer.getChannelData(0).set(float32); + + const source = ctx.createBufferSource(); + source.buffer = buffer; + source.connect(ctx.destination); + + activeSources.push(source); + isPlaying = true; + source.onended = () => { + activeSources = activeSources.filter((s) => s !== source); + if (activeSources.length === 0) { + isPlaying = false; + bargeInCount = 0; + } + }; + + const now = ctx.currentTime; + if (nextPlayTime < now) { + nextPlayTime = now + 0.01; + } + source.start(nextPlayTime); + nextPlayTime += buffer.duration; +} + +function stopPlayback() { + for (const source of activeSources) { + try { + source.stop(); + } catch (_) {} + } + activeSources = []; + nextPlayTime = 0; + isPlaying = false; + bargeInCount = 0; +} + +// --- Microphone --- + +async function toggleMic() { + if (micActive) { + stopMic(); + } else { + await startMic(); + } +} + +async function startMic() { + try { + // Ensure playback context exists (needed for user gesture) + getPlaybackCtx(); + if (audioCtx.state === "suspended") { + await audioCtx.resume(); + } + + micStream = await navigator.mediaDevices.getUserMedia({ + audio: { + sampleRate: MIC_SR, + channelCount: 1, + echoCancellation: true, + noiseSuppression: true, + autoGainControl: true, + }, + }); + + // Create a separate context at 16kHz for mic capture + const micCtx = new AudioContext({ sampleRate: MIC_SR }); + const source = micCtx.createMediaStreamSource(micStream); + + await micCtx.audioWorklet.addModule("/static/processor.js"); + workletNode = new AudioWorkletNode(micCtx, "pcm-processor"); + source.connect(workletNode); + + workletNode.port.onmessage = (e) => { + if (ws && ws.readyState === WebSocket.OPEN) { + ws.send(e.data); + + // Client-side barge-in: detect mic energy while playing + if (isPlaying) { + const samples = new Int16Array(e.data); + let sum = 0; + for (let i = 0; i < samples.length; i++) { + const s = samples[i] / 32768; + sum += s * s; + } + const rms = Math.sqrt(sum / samples.length); + + if (rms > BARGE_IN_THRESHOLD) { + bargeInCount++; + if (bargeInCount >= BARGE_IN_FRAMES) { + // User is speaking over the assistant - interrupt + stopPlayback(); + finalizeAssistantMessage(); + ws.send(JSON.stringify({ type: "interrupt" })); + isPlaying = false; + bargeInCount = 0; + } + } else { + bargeInCount = 0; + } + } + } + }; + + // Store for cleanup + workletNode._micCtx = micCtx; + + micActive = true; + micBtn.classList.add("active"); + + // Connect WebSocket if not already + if (!ws || ws.readyState !== WebSocket.OPEN) { + connectWS(); + } + } catch (err) { + console.error("Mic access failed:", err); + alert("Could not access microphone. Please allow mic permissions."); + } +} + +function stopMic() { + if (workletNode) { + workletNode.disconnect(); + if (workletNode._micCtx) { + workletNode._micCtx.close(); + } + workletNode = null; + } + if (micStream) { + micStream.getTracks().forEach((t) => t.stop()); + micStream = null; + } + micActive = false; + micBtn.classList.remove("active"); +} + +// --- Voice Selection --- + +async function applyVoice() { + const voice = document.getElementById("voice-select").value; + const statusEl = document.getElementById("voice-status"); + + const formData = new FormData(); + formData.append("voice", voice); + formData.append("lang", "a"); + + statusEl.textContent = "Applying..."; + try { + const resp = await fetch("/api/set-voice", { + method: "POST", + body: formData, + }); + const data = await resp.json(); + if (data.status === "ok") { + statusEl.textContent = "Voice: " + voice; + } else { + statusEl.textContent = "Failed."; + } + } catch (err) { + statusEl.textContent = "Error: " + err.message; + } +} + +// Expose to HTML onclick +window.toggleMic = toggleMic; +window.applyVoice = applyVoice; diff --git a/static/index.html b/static/index.html new file mode 100644 index 0000000..83a88ca --- /dev/null +++ b/static/index.html @@ -0,0 +1,49 @@ + + + + + + Voice Chat + + + +
+

Voice Chat

+ Disconnected +
+ +
+ +
+ Voice Settings +
+ + + +
+
+ +
+ +
+ + + + diff --git a/static/processor.js b/static/processor.js new file mode 100644 index 0000000..283dee0 --- /dev/null +++ b/static/processor.js @@ -0,0 +1,42 @@ +/** + * AudioWorkletProcessor that collects 512-sample chunks of PCM audio + * and posts them to the main thread for WebSocket transmission. + */ +class PCMProcessor extends AudioWorkletProcessor { + constructor() { + super(); + this.buffer = new Float32Array(0); + this.chunkSize = 512; // 512 samples at 16kHz = 32ms + } + + process(inputs) { + const input = inputs[0]; + if (!input || !input[0]) return true; + + const channelData = input[0]; // mono + + // Append to buffer + const newBuffer = new Float32Array(this.buffer.length + channelData.length); + newBuffer.set(this.buffer); + newBuffer.set(channelData, this.buffer.length); + this.buffer = newBuffer; + + // Send complete chunks + while (this.buffer.length >= this.chunkSize) { + const chunk = this.buffer.slice(0, this.chunkSize); + this.buffer = this.buffer.slice(this.chunkSize); + + // Convert float32 to int16 for transmission + const int16 = new Int16Array(chunk.length); + for (let i = 0; i < chunk.length; i++) { + const s = Math.max(-1, Math.min(1, chunk[i])); + int16[i] = s < 0 ? s * 0x8000 : s * 0x7fff; + } + this.port.postMessage(int16.buffer, [int16.buffer]); + } + + return true; + } +} + +registerProcessor("pcm-processor", PCMProcessor); diff --git a/static/style.css b/static/style.css new file mode 100644 index 0000000..dde72f1 --- /dev/null +++ b/static/style.css @@ -0,0 +1,185 @@ +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; + background: #0f0f0f; + color: #e0e0e0; + height: 100vh; + display: flex; + flex-direction: column; +} + +header { + padding: 16px 24px; + border-bottom: 1px solid #222; + display: flex; + align-items: center; + justify-content: space-between; +} + +header h1 { + font-size: 18px; + font-weight: 600; + color: #fff; +} + +#status-badge { + padding: 4px 12px; + border-radius: 12px; + font-size: 13px; + font-weight: 500; + background: #1a1a2e; + color: #888; + transition: all 0.3s; +} + +#status-badge.listening { + background: #0a2a1a; + color: #4ade80; +} + +#status-badge.thinking { + background: #2a1a0a; + color: #fbbf24; +} + +#status-badge.speaking { + background: #1a0a2a; + color: #a78bfa; +} + +#chat-area { + flex: 1; + overflow-y: auto; + padding: 24px; + display: flex; + flex-direction: column; + gap: 12px; +} + +.message { + max-width: 70%; + padding: 10px 16px; + border-radius: 16px; + font-size: 15px; + line-height: 1.5; + word-wrap: break-word; +} + +.message.user { + align-self: flex-end; + background: #1d4ed8; + color: #fff; + border-bottom-right-radius: 4px; +} + +.message.assistant { + align-self: flex-start; + background: #1e1e1e; + color: #e0e0e0; + border-bottom-left-radius: 4px; +} + +#controls { + padding: 16px 24px; + border-top: 1px solid #222; + display: flex; + align-items: center; + gap: 16px; +} + +#mic-btn { + width: 56px; + height: 56px; + border-radius: 50%; + border: 2px solid #333; + background: #1a1a1a; + color: #e0e0e0; + font-size: 24px; + cursor: pointer; + transition: all 0.2s; + display: flex; + align-items: center; + justify-content: center; +} + +#mic-btn:hover { + border-color: #555; + background: #222; +} + +#mic-btn.active { + border-color: #ef4444; + background: #2a0a0a; + color: #ef4444; + animation: pulse 1.5s infinite; +} + +@keyframes pulse { + 0%, 100% { box-shadow: 0 0 0 0 rgba(239, 68, 68, 0.3); } + 50% { box-shadow: 0 0 0 12px rgba(239, 68, 68, 0); } +} + +/* Voice clone panel */ +#voice-panel { + padding: 12px 24px; + border-top: 1px solid #222; + background: #0a0a0a; +} + +#voice-panel summary { + cursor: pointer; + font-size: 13px; + color: #888; + user-select: none; +} + +#voice-panel .panel-content { + margin-top: 12px; + display: flex; + gap: 12px; + align-items: flex-end; + flex-wrap: wrap; +} + +#voice-panel label { + font-size: 13px; + color: #aaa; + display: flex; + flex-direction: column; + gap: 4px; +} + +#voice-panel input[type="file"], +#voice-panel input[type="text"] { + background: #1a1a1a; + border: 1px solid #333; + border-radius: 6px; + padding: 6px 10px; + color: #e0e0e0; + font-size: 13px; +} + +#upload-btn { + padding: 6px 16px; + border-radius: 6px; + border: 1px solid #333; + background: #1a1a1a; + color: #e0e0e0; + font-size: 13px; + cursor: pointer; +} + +#upload-btn:hover { + background: #222; +} + +#upload-status { + font-size: 12px; + color: #888; + margin-left: 8px; +}