test passing
This commit is contained in:
@@ -0,0 +1,109 @@
"""Smoke test: VAE decoder under GGUF pipeline.

Builds the pipeline, runs all encoders, initializes the scheduler, executes
one DIT denoising step, then decodes the resulting latents back to pixel
frames via the VAE decoder. Validates the full encode→denoise→decode path.

Run:
    docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
        python -m tests.component.test_13_vae_decode
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
|
||||
|
||||
# Module-level logger for this smoke test, obtained from the shared
# tests.component._common.get_logger helper under the name "test_13".
log = get_logger("test_13")
|
||||
|
||||
# Quantization scheme for the DIT weights, read from the DIT_QUANT environment
# variable; falls back to "gguf-Q4_K_M" when the variable is not set.
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")

# Pick the pipeline config file and the DIT weight repository that go with the
# selected scheme: any "gguf-" scheme uses the GGUF config/repo, everything
# else uses the FP8 distill config/repo.
if DIT_QUANT.startswith("gguf-"):
    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
    DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
    DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
|
||||
|
||||
|
||||
def run():
    """Smoke-test the VAE decoder after a single DIT denoising step.

    Builds a ``Wan22Pipeline`` for the quantization scheme in ``DIT_QUANT``,
    runs the input encoders, initializes the run, executes one DIT step, and
    hands the resulting latents to the runner's VAE decoder. The decoder
    output is only logged (its type, plus shape/dtype where it is a tensor or
    a list of tensors); no pixel values are checked.

    Raises:
        SystemExit: with status 1 when ``server.video_models.wan22`` cannot be
            imported, so callers that check the exit status see the failure.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Import failed: %s", e)
        # Nothing was exercised if the pipeline cannot be imported; exit
        # non-zero so the run is not reported as a pass.
        sys.exit(1)

    avatar = ensure_sample_avatar()
    log.info("Avatar: %s", avatar)

    log.info("Building pipeline (quant=%s)...", DIT_QUANT)
    pipe = Wan22Pipeline(
        base_repo="Wan-AI/Wan2.2-I2V-A14B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2_moe_distill",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")

    # The steps below drive the pipeline's underlying runner directly so each
    # stage can be executed and logged on its own.
    runner = pipe._runner

    # Set up input_info for a short clip
    from lightx2v.utils.input_info import update_input_info_from_dict

    update_input_info_from_dict(
        pipe._input_info_template,
        {
            "seed": 42,
            "prompt": "a person looking at the camera, natural lighting",
            "negative_prompt": "",
            "image_path": avatar,
            "target_video_length": 17,  # 1 second at 16fps + 1
        },
    )
    runner.input_info = pipe._input_info_template

    # 1. Run all encoders (T5 + CLIP + VAE)
    log.info("Running all input encoders (T5 + CLIP + VAE)...")
    runner.inputs = runner.run_input_encoder()
    log.info("Encoder outputs ready.")

    # 2. Initialize run (sets up scheduler, creates noise latents)
    log.info("Initializing run (scheduler.prepare)...")
    runner.init_run()
    log.info(
        "Initial latents: shape=%s dtype=%s",
        runner.model.scheduler.latents.shape,
        runner.model.scheduler.latents.dtype,
    )

    # 3. Single DIT step (so we have realistic latents to decode)
    log.info("Running single DIT step...")
    runner.model.scheduler.step_pre(step_index=0)
    runner.model.infer(runner.inputs)
    runner.model.scheduler.step_post()
    latents = runner.model.scheduler.latents
    log.info("Latents after step: shape=%s dtype=%s", latents.shape, latents.dtype)

    # 4. VAE decode
    log.info("Running VAE decoder...")
    video_out = runner.run_vae_decoder(latents)
    log.info("VAE decoder output type: %s", type(video_out))

    # The decoder's return type is not assumed here: describe a tensor or a
    # list of tensors in detail, and fall back to logging the value itself.
    if isinstance(video_out, torch.Tensor):
        log.info(
            "video_out: shape=%s dtype=%s device=%s",
            video_out.shape,
            video_out.dtype,
            video_out.device,
        )
    elif isinstance(video_out, list):
        log.info("video_out: list of %d items", len(video_out))
        if len(video_out) > 0 and isinstance(video_out[0], torch.Tensor):
            log.info("  first item: shape=%s dtype=%s", video_out[0].shape, video_out[0].dtype)
    else:
        log.info("video_out: %s", video_out)

    log.info("PASS: VAE decoder succeeded under %s pipeline.", DIT_QUANT)
|
||||
|
||||
|
||||
# Script entry point: execute the smoke test when the module is run directly
# (e.g. via ``python -m tests.component.test_13_vae_decode``).
if __name__ == "__main__":
    run()
|
||||
Reference in New Issue
Block a user