initial commit
This commit is contained in:
@@ -0,0 +1,87 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.params import Form
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from server.audio_utils import pcm_bytes_to_float32
|
||||
from server.models import ModelManager
|
||||
from server.pipeline import ConversationSession
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio")
|
||||
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
|
||||
|
||||
model_mgr = ModelManager()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
log.info("Starting model loading...")
|
||||
model_mgr.load_all()
|
||||
log.info("Server ready.")
|
||||
yield
|
||||
log.info("Shutting down.")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def index():
|
||||
return FileResponse(os.path.join(STATIC_DIR, "index.html"))
|
||||
|
||||
|
||||
@app.post("/api/set-voice")
|
||||
async def set_voice(voice: str = Form(...), lang: str = Form("a")):
|
||||
"""Change the TTS voice."""
|
||||
model_mgr.tts_engine.set_voice(voice, lang)
|
||||
return {"status": "ok", "voice": voice}
|
||||
|
||||
|
||||
@app.websocket("/ws/chat")
|
||||
async def websocket_chat(ws: WebSocket):
|
||||
await ws.accept()
|
||||
log.info("WebSocket client connected.")
|
||||
|
||||
async def send_json(data: dict):
|
||||
await ws.send_text(json.dumps(data))
|
||||
|
||||
async def send_bytes(data: bytes):
|
||||
await ws.send_bytes(data)
|
||||
|
||||
session = ConversationSession(model_mgr, send_json, send_bytes)
|
||||
await session.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await ws.receive()
|
||||
|
||||
if "bytes" in message:
|
||||
pcm_data = message["bytes"]
|
||||
chunk = pcm_bytes_to_float32(pcm_data)
|
||||
await session.handle_audio_chunk(chunk)
|
||||
|
||||
elif "text" in message:
|
||||
try:
|
||||
msg = json.loads(message["text"])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if msg.get("type") == "interrupt":
|
||||
await session.interrupt()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
log.info("WebSocket client disconnected.")
|
||||
except Exception:
|
||||
log.exception("WebSocket error")
|
||||
finally:
|
||||
await session.stop()
|
||||
Reference in New Issue
Block a user