Files
live-voice-chat/tests/component/test_13_vae_decode.py
T
2026-04-16 10:00:37 -04:00

105 lines
3.4 KiB
Python

"""Smoke test: VAE decoder under the GGUF pipeline.

Builds the pipeline, runs all input encoders (T5 + CLIP + VAE), initializes
the scheduler, executes one DiT denoising step, then decodes the resulting
latents back to pixel frames via the VAE decoder, exercising the full
encode→denoise→decode path.

The DiT quantization scheme is read from the DIT_QUANT environment variable
(default: gguf-Q8_0).

Run:
    docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
        python -m tests.component.test_13_vae_decode
"""
from __future__ import annotations
import os
import sys
import torch
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
log = get_logger("test_13")

# lightx2v config file and DiT weight repository used to build the pipeline.
# NOTE(review): the /app path presumably refers to the container image used in
# the docker compose example above; DIT_REPO looks like a Hugging Face repo id.
CONFIG_JSON: str = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO: str = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"

# DiT quantization scheme passed to the pipeline as dit_quant_scheme;
# overridable via the DIT_QUANT environment variable.
DIT_QUANT: str = os.getenv("DIT_QUANT", "gguf-Q8_0")
def _tensor_problem(tensor: torch.Tensor, label: str) -> str | None:
    """Return why *tensor* is unusable as decoded output, or None if it passes.

    Flags empty tensors and floating-point tensors that contain NaN/Inf.
    """
    if tensor.numel() == 0:
        return f"{label} is empty"
    if tensor.is_floating_point() and not bool(torch.isfinite(tensor).all()):
        return f"{label} contains non-finite (NaN/Inf) values"
    return None


def _check_decoder_output(video_out: object) -> str | None:
    """Log a summary of the VAE decoder output and return a failure reason.

    Returns None when the output passes the checks, otherwise a short
    description of the problem. ``None`` and empty lists fail; tensors (bare,
    or as list items) must be non-empty and finite. Any other output type is
    only logged, because its structure is not known here.
    """
    log.info("VAE decoder output type: %s", type(video_out))
    if isinstance(video_out, torch.Tensor):
        log.info("video_out: shape=%s dtype=%s device=%s",
                 video_out.shape, video_out.dtype, video_out.device)
        return _tensor_problem(video_out, "video_out")
    if isinstance(video_out, list):
        log.info("video_out: list of %d items", len(video_out))
        if not video_out:
            return "decoder returned an empty list"
        if isinstance(video_out[0], torch.Tensor):
            log.info(" first item: shape=%s dtype=%s",
                     video_out[0].shape, video_out[0].dtype)
        for index, item in enumerate(video_out):
            if isinstance(item, torch.Tensor):
                problem = _tensor_problem(item, f"video_out[{index}]")
                if problem is not None:
                    return problem
        return None
    log.info("video_out: %s", video_out)
    if video_out is None:
        return "decoder returned None"
    return None


def run():
    """Build the GGUF Wan2.2 pipeline, run one DiT step, and VAE-decode it.

    Steps: (1) run the input encoders, (2) ``init_run`` to prepare the
    scheduler and initial latents, (3) a single DiT denoising step, and
    (4) VAE decoding, followed by a sanity check of the decoded output.

    Exits with status 0 (after logging an error) when ``Wan22Pipeline`` cannot
    be imported. Exits with status 1 when the decoder output fails the checks
    in ``_check_decoder_output``. Exceptions raised by the pipeline propagate.

    NOTE(review): reaches into private attributes (``pipe._runner``,
    ``pipe._input_info_template``) and mutates the template in place. That is
    acceptable for a one-shot smoke test but couples it to pipeline internals.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        # NOTE(review): exit status 0 makes a missing import look like
        # success; presumably intentional so the test is skipped outside the
        # container — confirm.
        log.error("Import failed: %s", e)
        sys.exit(0)

    avatar = ensure_sample_avatar()
    log.info("Avatar: %s", avatar)

    log.info("Building pipeline (quant=%s)...", DIT_QUANT)
    pipe = Wan22Pipeline(
        base_repo="Wan-AI/Wan2.2-TI2V-5B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")
    runner = pipe._runner

    # Set up input_info for a short clip.
    from lightx2v.utils.input_info import update_input_info_from_dict

    update_input_info_from_dict(
        pipe._input_info_template,
        {
            "seed": 42,
            "prompt": "a person looking at the camera, natural lighting",
            "negative_prompt": "",
            "image_path": avatar,
            "target_video_length": 17,  # 1 second at 16fps + 1
        },
    )
    runner.input_info = pipe._input_info_template

    # 1. Run all encoders (T5 + CLIP + VAE)
    log.info("Running all input encoders (T5 + CLIP + VAE)...")
    runner.inputs = runner.run_input_encoder()
    log.info("Encoder outputs ready.")

    # 2. Initialize run (sets up scheduler, creates noise latents)
    log.info("Initializing run (scheduler.prepare)...")
    runner.init_run()
    log.info("Initial latents: shape=%s dtype=%s",
             runner.model.scheduler.latents.shape,
             runner.model.scheduler.latents.dtype)

    # 3. Single DIT step (so we have realistic latents to decode)
    log.info("Running single DIT step...")
    runner.model.scheduler.step_pre(step_index=0)
    runner.model.infer(runner.inputs)
    runner.model.scheduler.step_post()
    latents = runner.model.scheduler.latents
    log.info("Latents after step: shape=%s dtype=%s", latents.shape, latents.dtype)

    # 4. VAE decode, then check the output instead of passing unconditionally.
    log.info("Running VAE decoder...")
    video_out = runner.run_vae_decoder(latents)
    problem = _check_decoder_output(video_out)
    if problem is not None:
        log.error("FAIL: VAE decoder output invalid under %s pipeline: %s",
                  DIT_QUANT, problem)
        sys.exit(1)
    log.info("PASS: VAE decoder succeeded under %s pipeline.", DIT_QUANT)


if __name__ == "__main__":
    run()