"""Smoke test: image reading + VAE encoder (+ CLIP if enabled) under the GGUF pipeline.

Builds the Wan22Pipeline, loads a sample avatar, reads the image input,
runs the CLIP image encoder (if use_image_encoder is true in the config),
and runs the VAE encoder. Validates outputs under DTYPE=FP16.

Note: The GGUF distill config sets use_image_encoder=false, so CLIP is
skipped by default. The VAE encoder is always exercised.

Run:
    docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
        python -m tests.component.test_11_image_encode
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
|
|
from tests.component._common import ensure_sample_avatar, get_logger
|
|
|
|
# Logger for this smoke test; run() reports its progress and PASS lines through it.
log = get_logger("test_11")
|
|
|
|
# DiT quantization scheme, overridable from the environment. Any "gguf-*"
# value selects the GGUF checkpoint + config; every other value falls back to
# the FP8 distill checkpoint + config.
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")

CONFIG_JSON, DIT_REPO = (
    (
        "/app/configs/lightx2v/wan22_i2v_gguf_distill.json",
        "QuantStack/Wan2.2-I2V-A14B-GGUF",
    )
    if DIT_QUANT.startswith("gguf-")
    else (
        "/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
        "lightx2v/Wan2.2-Distill-Models",
    )
)
|
|
|
|
|
|
def _build_pipeline(pipeline_cls):
    """Construct the Wan2.2 I2V pipeline for the selected DIT_QUANT scheme.

    Args:
        pipeline_cls: The ``Wan22Pipeline`` class. It is passed in because
            ``run`` imports it lazily.

    Returns:
        The constructed pipeline instance.
    """
    log.info("Building pipeline (quant=%s)...", DIT_QUANT)
    pipe = pipeline_cls(
        base_repo="Wan-AI/Wan2.2-I2V-A14B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2_moe_distill",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")
    return pipe


def _prepare_input_info(pipe, runner, avatar):
    """Fill the pipeline's input-info template and attach it to the runner.

    The runner's stage methods are called directly below (not through the
    pipeline), so ``runner.input_info`` is set up first so they can access it.

    Args:
        pipe: The pipeline whose ``_input_info_template`` is updated in place.
        runner: The pipeline's runner; receives the populated template.
        avatar: Path of the input image, stored as ``image_path``.
    """
    from lightx2v.utils.input_info import update_input_info_from_dict

    update_input_info_from_dict(
        pipe._input_info_template,
        {
            "seed": 42,
            "prompt": "a person looking at the camera, natural lighting",
            "negative_prompt": "",
            "image_path": avatar,
            "target_video_length": 17,
        },
    )
    runner.input_info = pipe._input_info_template


def _assert_finite(name, tensor):
    """Fail the smoke test if *tensor* holds any NaN or +/-Inf value.

    Non-finite values are a common symptom of FP16 overflow, which is the
    failure mode this test is meant to surface.
    """
    assert bool(torch.isfinite(tensor).all()), f"{name} contains non-finite values (NaN/Inf)"


def _run_clip_stage(runner, img):
    """Run the CLIP image encoder when the config enables it; otherwise skip.

    ``use_image_encoder`` defaults to True when the key is absent from the
    runner config.

    Raises:
        AssertionError: If the encoder output is not a tensor or contains NaN/Inf.
    """
    if not runner.config.get("use_image_encoder", True):
        log.info("CLIP image encoder disabled (use_image_encoder=false) — skipping.")
        return

    log.info("Running CLIP image encoder...")
    clip_out = runner.run_image_encoder(img)
    # Check the type before touching .shape/.dtype/.device so that a wrong
    # return type fails with this message instead of an AttributeError.
    assert isinstance(clip_out, torch.Tensor), f"Expected tensor, got {type(clip_out)}"
    log.info("clip_out: shape=%s dtype=%s device=%s", clip_out.shape, clip_out.dtype, clip_out.device)
    _assert_finite("clip_out", clip_out)
    log.info("PASS: CLIP image encoder succeeded.")


def _log_and_check_vae_output(vae_out):
    """Log the VAE encoder output and assert that every tensor in it is finite.

    Accepts a bare tensor, a dict whose values may or may not be tensors, or
    any other object (in which case only its type is logged).
    """
    if isinstance(vae_out, torch.Tensor):
        log.info("vae_out: shape=%s dtype=%s device=%s", vae_out.shape, vae_out.dtype, vae_out.device)
        _assert_finite("vae_out", vae_out)
    elif isinstance(vae_out, dict):
        for key, value in vae_out.items():
            if isinstance(value, torch.Tensor):
                log.info("vae_out[%s]: shape=%s dtype=%s", key, value.shape, value.dtype)
                _assert_finite(f"vae_out[{key}]", value)
            else:
                log.info("vae_out[%s]: type=%s", key, type(value))
    else:
        log.info("vae_out: type=%s", type(vae_out))


def _run_vae_stage(runner, img, img_ori):
    """Run the VAE encoder on the input form the runner requests, then check it.

    ``img_ori`` is used when ``runner.vae_encoder_need_img_original`` is truthy;
    otherwise the ``img`` tensor is used.
    """
    need_original = runner.vae_encoder_need_img_original
    vae_input = img_ori if need_original else img
    log.info("Running VAE encoder (using %s)...", "img_ori" if need_original else "img tensor")

    vae_out, latent_shape = runner.run_vae_encoder(vae_input)
    log.info("latent_shape: %s", latent_shape)
    _log_and_check_vae_output(vae_out)
    log.info("PASS: VAE encoder succeeded.")


def run():
    """Run the image-encoding smoke test: read image, CLIP (if enabled), then VAE.

    Exits with status 0, after logging an error, when
    ``server.video_models.wan22`` cannot be imported. Any other failure
    propagates as an exception. This includes ``AssertionError`` when an
    encoder output has the wrong type or contains NaN/Inf.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Import failed: %s", e)
        # NOTE(review): exit status 0 means a missing module is not reported
        # as a failed run. Presumably this is a deliberate "skip" when run
        # outside the container; confirm.
        sys.exit(0)

    avatar = ensure_sample_avatar()
    log.info("Avatar: %s", avatar)

    pipe = _build_pipeline(Wan22Pipeline)
    runner = pipe._runner
    _prepare_input_info(pipe, runner, avatar)

    # 1. Load image
    log.info("Reading image input...")
    img, img_ori = runner.read_image_input(avatar)
    log.info("img: shape=%s dtype=%s device=%s", img.shape, img.dtype, img.device)

    # 2. CLIP image encoder (only if enabled in config)
    _run_clip_stage(runner, img)

    # 3. VAE encoder (always exercised)
    _run_vae_stage(runner, img, img_ori)

    log.info("PASS: All image encoding stages completed under %s pipeline.", DIT_QUANT)


if __name__ == "__main__":
    run()
|