test passing

This commit is contained in:
2026-04-12 16:38:44 -04:00
parent fcf0be38bc
commit 56923ff424
5 changed files with 435 additions and 4 deletions
+107
View File
@@ -0,0 +1,107 @@
"""Smoke test: single DIT denoising step with GGUF weights.
Builds the pipeline, runs all encoders, initializes the scheduler, then
executes exactly one DIT forward pass (step_pre → infer → step_post).
This isolates the GGUF fp16 DIT from the rest of the pipeline.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_12_dit_single_step
"""
from __future__ import annotations
import copy
import os
import sys
import torch
from tests.component._common import ensure_sample_avatar, get_logger
log = get_logger("test_12")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
avatar = ensure_sample_avatar()
log.info("Avatar: %s", avatar)
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
runner = pipe._runner
# Set up input_info for a short clip
from lightx2v.utils.input_info import update_input_info_from_dict
update_input_info_from_dict(
pipe._input_info_template,
{
"seed": 42,
"prompt": "a person looking at the camera, natural lighting",
"negative_prompt": "",
"image_path": avatar,
"target_video_length": 17, # 1 second at 16fps + 1
},
)
runner.input_info = pipe._input_info_template
# 1. Run all encoders (T5 + CLIP + VAE)
log.info("Running all input encoders (T5 + CLIP + VAE)...")
runner.inputs = runner.run_input_encoder()
log.info("Encoder outputs ready.")
for k, v in runner.inputs.items():
if isinstance(v, torch.Tensor):
log.info(" inputs[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
elif isinstance(v, dict):
for k2, v2 in v.items():
if isinstance(v2, torch.Tensor):
log.info(" inputs[%s][%s]: shape=%s dtype=%s", k, k2, v2.shape, v2.dtype)
# 2. Initialize run (sets up scheduler, creates noise latents)
log.info("Initializing run (scheduler.prepare)...")
runner.init_run()
latents = runner.model.scheduler.latents
log.info("Initial latents: shape=%s dtype=%s", latents.shape, latents.dtype)
# 3. Single DIT step
log.info("Running single DIT step (step_pre → infer → step_post)...")
runner.model.scheduler.step_pre(step_index=0)
runner.model.infer(runner.inputs)
runner.model.scheduler.step_post()
latents_after = runner.model.scheduler.latents
log.info("Latents after step: shape=%s dtype=%s", latents_after.shape, latents_after.dtype)
# Verify latents changed (denoising did something)
assert not torch.equal(latents, latents_after), "Latents unchanged after DIT step"
log.info("PASS: DIT single step completed, latents updated.")
log.info("PASS: DIT forward pass succeeded under %s pipeline.", DIT_QUANT)
if __name__ == "__main__":
run()