test passing

This commit is contained in:
2026-04-12 16:38:44 -04:00
parent fcf0be38bc
commit 56923ff424
5 changed files with 435 additions and 4 deletions
+9 -3
View File
@@ -11,9 +11,15 @@
"target_width": 480, "target_width": 480,
"fps": 16, "fps": 16,
"self_attn_1_type": "flash_attn3", "_comment_attn": "flash_attn3/sageattn3 aren't installed (no Blackwell-ready pre-built wheels). Use PyTorch SDPA which works on SM120.",
"cross_attn_1_type": "flash_attn3", "self_attn_1_type": "torch_sdpa",
"cross_attn_2_type": "flash_attn3", "cross_attn_1_type": "torch_sdpa",
"cross_attn_2_type": "torch_sdpa",
"_comment_modulate": "Triton fuse_scale_shift_kernel segfaults during JIT compile on Blackwell SM120 (triton 3.4 + cu128). Use the PyTorch modulate fallback until the Triton issue is resolved.",
"modulate_type": "torch",
"_comment_rope": "flashinfer not installed; fall back to PyTorch rope.",
"rope_type": "torch",
"sample_guide_scale": [3.5, 3.5], "sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0, "sample_shift": 5.0,
+98 -1
View File
@@ -220,6 +220,8 @@ class Wan22Pipeline:
GET_DTYPE.cache_clear() GET_DTYPE.cache_clear()
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE()) log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
self._patch_t5_dtype_for_gguf() self._patch_t5_dtype_for_gguf()
self._patch_vae_dtype_for_gguf()
self._patch_dit_fp32_weights_for_gguf()
# --- GGUF dtype compatibility patch ---------------------------------------- # --- GGUF dtype compatibility patch ----------------------------------------
@@ -240,6 +242,7 @@ class Wan22Pipeline:
orig_run_text_encoder = runner.run_text_encoder.__func__ orig_run_text_encoder = runner.run_text_encoder.__func__
def bf16_text_encoder(self_runner, *args, **kwargs): def bf16_text_encoder(self_runner, *args, **kwargs):
import torch
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights. # Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
os.environ["DTYPE"] = "BF16" os.environ["DTYPE"] = "BF16"
GET_DTYPE.cache_clear() GET_DTYPE.cache_clear()
@@ -251,11 +254,105 @@ class Wan22Pipeline:
os.environ["DTYPE"] = "FP16" os.environ["DTYPE"] = "FP16"
GET_DTYPE.cache_clear() GET_DTYPE.cache_clear()
GET_SENSITIVE_DTYPE.cache_clear() GET_SENSITIVE_DTYPE.cache_clear()
return result # Cast bf16 T5 outputs to fp16 so they match the GGUF DIT dtype.
def _to_fp16(x):
if isinstance(x, torch.Tensor) and x.dtype == torch.bfloat16:
return x.to(torch.float16)
if isinstance(x, list):
return [_to_fp16(v) for v in x]
if isinstance(x, tuple):
return tuple(_to_fp16(v) for v in x)
if isinstance(x, dict):
return {k: _to_fp16(v) for k, v in x.items()}
return x
return _to_fp16(result)
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner) runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.") log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
def _patch_vae_dtype_for_gguf(self) -> None:
"""Cast VAE encoder/decoder weights to fp16 to match GGUF DIT dtype.
The VAE weights load as bf16 (the default). Under GGUF the DIT runs in
fp16 and the runner casts VAE inputs via ``.to(GET_DTYPE())`` — which
under DTYPE=FP16 collides with bf16 VAE weights in Conv3d. Since the
VAE is a plain float model (not quantized), simply converting its
weights to fp16 avoids both input-vs-weight mismatches and the need
for any runtime dtype juggling.
"""
import torch
runner = self._runner
for name in ("vae_encoder", "vae_decoder"):
mod = getattr(runner, name, None)
if mod is None:
continue
inner = getattr(mod, "model", mod)
if hasattr(inner, "to"):
inner.to(dtype=torch.float16)
# The outer WanVAE wrapper also holds mean/inv_std/scale tensors
# used by encode/decode (z = z/inv_std + mean). Cast them too, or
# the first op upcasts fp16 latents back to fp32/bf16.
for attr in ("mean", "inv_std"):
t = getattr(mod, attr, None)
if isinstance(t, torch.Tensor):
setattr(mod, attr, t.to(torch.float16))
scale = getattr(mod, "scale", None)
if isinstance(scale, list):
mod.scale = [
t.to(torch.float16) if isinstance(t, torch.Tensor) else t
for t in scale
]
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
def _patch_dit_fp32_weights_for_gguf(self) -> None:
"""Cast leftover fp32 DIT pre/post weights to fp16.
GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
of non-quantised weights (notably ``patch_embedding.pin_weight``) end
up loaded as fp32. That breaks the first conv in the DIT forward pass
(fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
runs uniformly in fp16.
"""
import torch
runner = self._runner
models = getattr(runner.model, "model", None)
if models is None:
return
if not isinstance(models, (list, tuple)):
models = [models]
n_cast = 0
for m in models:
for weights_attr in ("pre_weight", "post_weight"):
w = getattr(m, weights_attr, None)
if w is None:
continue
for sub_name in dir(w):
if sub_name.startswith("_"):
continue
try:
sub = getattr(w, sub_name)
except Exception:
continue
if sub is None:
continue
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
t = getattr(sub, t_name, None)
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
casted = t.to(torch.float16)
# Preserve pinned-memory status on pin_* tensors so
# move_attr_to_cuda's non-blocking H2D copy is safe.
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
try:
casted = casted.pin_memory()
except RuntimeError:
pass
setattr(sub, t_name, casted)
n_cast += 1
log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
# --- Weight provisioning ------------------------------------------------- # --- Weight provisioning -------------------------------------------------
@staticmethod @staticmethod
+112
View File
@@ -0,0 +1,112 @@
"""Smoke test: image reading + VAE encoder (+ CLIP if enabled) under GGUF pipeline.
Builds the Wan22Pipeline, loads a sample avatar, reads the image input,
runs the CLIP image encoder (if use_image_encoder is true in the config),
and runs the VAE encoder. Validates outputs under DTYPE=FP16.
Note: The GGUF distill config sets use_image_encoder=false, so CLIP is
skipped by default. The VAE encoder is always exercised.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_11_image_encode
"""
from __future__ import annotations
import os
import sys
import torch
from tests.component._common import ensure_sample_avatar, get_logger
log = get_logger("test_11")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
avatar = ensure_sample_avatar()
log.info("Avatar: %s", avatar)
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
runner = pipe._runner
# Set up input_info so runner methods can access it
from lightx2v.utils.input_info import update_input_info_from_dict
update_input_info_from_dict(
pipe._input_info_template,
{
"seed": 42,
"prompt": "a person looking at the camera, natural lighting",
"negative_prompt": "",
"image_path": avatar,
"target_video_length": 17,
},
)
runner.input_info = pipe._input_info_template
# 1. Load image
log.info("Reading image input...")
img, img_ori = runner.read_image_input(avatar)
log.info("img: shape=%s dtype=%s device=%s", img.shape, img.dtype, img.device)
# 2. CLIP image encoder (only if enabled in config)
use_clip = runner.config.get("use_image_encoder", True)
if use_clip:
log.info("Running CLIP image encoder...")
clip_out = runner.run_image_encoder(img)
log.info("clip_out: shape=%s dtype=%s device=%s", clip_out.shape, clip_out.dtype, clip_out.device)
assert isinstance(clip_out, torch.Tensor), f"Expected tensor, got {type(clip_out)}"
log.info("PASS: CLIP image encoder succeeded.")
else:
log.info("CLIP image encoder disabled (use_image_encoder=false) — skipping.")
# 3. VAE encoder
vae_input = img_ori if runner.vae_encoder_need_img_original else img
log.info("Running VAE encoder (using %s)...",
"img_ori" if runner.vae_encoder_need_img_original else "img tensor")
vae_out, latent_shape = runner.run_vae_encoder(vae_input)
log.info("latent_shape: %s", latent_shape)
if isinstance(vae_out, torch.Tensor):
log.info("vae_out: shape=%s dtype=%s device=%s", vae_out.shape, vae_out.dtype, vae_out.device)
elif isinstance(vae_out, dict):
for k, v in vae_out.items():
if isinstance(v, torch.Tensor):
log.info("vae_out[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
else:
log.info("vae_out[%s]: type=%s", k, type(v))
else:
log.info("vae_out: type=%s", type(vae_out))
log.info("PASS: VAE encoder succeeded.")
log.info("PASS: All image encoding stages completed under %s pipeline.", DIT_QUANT)
if __name__ == "__main__":
run()
+107
View File
@@ -0,0 +1,107 @@
"""Smoke test: single DIT denoising step with GGUF weights.
Builds the pipeline, runs all encoders, initializes the scheduler, then
executes exactly one DIT forward pass (step_pre → infer → step_post).
This isolates the GGUF fp16 DIT from the rest of the pipeline.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_12_dit_single_step
"""
from __future__ import annotations
import copy
import os
import sys
import torch
from tests.component._common import ensure_sample_avatar, get_logger
log = get_logger("test_12")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
avatar = ensure_sample_avatar()
log.info("Avatar: %s", avatar)
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
runner = pipe._runner
# Set up input_info for a short clip
from lightx2v.utils.input_info import update_input_info_from_dict
update_input_info_from_dict(
pipe._input_info_template,
{
"seed": 42,
"prompt": "a person looking at the camera, natural lighting",
"negative_prompt": "",
"image_path": avatar,
"target_video_length": 17, # 1 second at 16fps + 1
},
)
runner.input_info = pipe._input_info_template
# 1. Run all encoders (T5 + CLIP + VAE)
log.info("Running all input encoders (T5 + CLIP + VAE)...")
runner.inputs = runner.run_input_encoder()
log.info("Encoder outputs ready.")
for k, v in runner.inputs.items():
if isinstance(v, torch.Tensor):
log.info(" inputs[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
elif isinstance(v, dict):
for k2, v2 in v.items():
if isinstance(v2, torch.Tensor):
log.info(" inputs[%s][%s]: shape=%s dtype=%s", k, k2, v2.shape, v2.dtype)
# 2. Initialize run (sets up scheduler, creates noise latents)
log.info("Initializing run (scheduler.prepare)...")
runner.init_run()
latents = runner.model.scheduler.latents
log.info("Initial latents: shape=%s dtype=%s", latents.shape, latents.dtype)
# 3. Single DIT step
log.info("Running single DIT step (step_pre → infer → step_post)...")
runner.model.scheduler.step_pre(step_index=0)
runner.model.infer(runner.inputs)
runner.model.scheduler.step_post()
latents_after = runner.model.scheduler.latents
log.info("Latents after step: shape=%s dtype=%s", latents_after.shape, latents_after.dtype)
# Verify latents changed (denoising did something)
assert not torch.equal(latents, latents_after), "Latents unchanged after DIT step"
log.info("PASS: DIT single step completed, latents updated.")
log.info("PASS: DIT forward pass succeeded under %s pipeline.", DIT_QUANT)
if __name__ == "__main__":
run()
+109
View File
@@ -0,0 +1,109 @@
"""Smoke test: VAE decoder under GGUF pipeline.
Builds the pipeline, runs all encoders, initializes the scheduler, executes
one DIT denoising step, then decodes the resulting latents back to pixel
frames via the VAE decoder. Validates the full encode→denoise→decode path.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_13_vae_decode
"""
from __future__ import annotations
import os
import sys
import torch
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
log = get_logger("test_13")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
avatar = ensure_sample_avatar()
log.info("Avatar: %s", avatar)
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
runner = pipe._runner
# Set up input_info for a short clip
from lightx2v.utils.input_info import update_input_info_from_dict
update_input_info_from_dict(
pipe._input_info_template,
{
"seed": 42,
"prompt": "a person looking at the camera, natural lighting",
"negative_prompt": "",
"image_path": avatar,
"target_video_length": 17, # 1 second at 16fps + 1
},
)
runner.input_info = pipe._input_info_template
# 1. Run all encoders (T5 + CLIP + VAE)
log.info("Running all input encoders (T5 + CLIP + VAE)...")
runner.inputs = runner.run_input_encoder()
log.info("Encoder outputs ready.")
# 2. Initialize run (sets up scheduler, creates noise latents)
log.info("Initializing run (scheduler.prepare)...")
runner.init_run()
log.info("Initial latents: shape=%s dtype=%s",
runner.model.scheduler.latents.shape,
runner.model.scheduler.latents.dtype)
# 3. Single DIT step (so we have realistic latents to decode)
log.info("Running single DIT step...")
runner.model.scheduler.step_pre(step_index=0)
runner.model.infer(runner.inputs)
runner.model.scheduler.step_post()
latents = runner.model.scheduler.latents
log.info("Latents after step: shape=%s dtype=%s", latents.shape, latents.dtype)
# 4. VAE decode
log.info("Running VAE decoder...")
video_out = runner.run_vae_decoder(latents)
log.info("VAE decoder output type: %s", type(video_out))
if isinstance(video_out, torch.Tensor):
log.info("video_out: shape=%s dtype=%s device=%s",
video_out.shape, video_out.dtype, video_out.device)
elif isinstance(video_out, list):
log.info("video_out: list of %d items", len(video_out))
if len(video_out) > 0 and isinstance(video_out[0], torch.Tensor):
log.info(" first item: shape=%s dtype=%s", video_out[0].shape, video_out[0].dtype)
else:
log.info("video_out: %s", video_out)
log.info("PASS: VAE decoder succeeded under %s pipeline.", DIT_QUANT)
if __name__ == "__main__":
run()