diff --git a/configs/lightx2v/wan22_i2v_gguf_distill.json b/configs/lightx2v/wan22_i2v_gguf_distill.json index f8730cf..6be131d 100644 --- a/configs/lightx2v/wan22_i2v_gguf_distill.json +++ b/configs/lightx2v/wan22_i2v_gguf_distill.json @@ -11,9 +11,15 @@ "target_width": 480, "fps": 16, - "self_attn_1_type": "flash_attn3", - "cross_attn_1_type": "flash_attn3", - "cross_attn_2_type": "flash_attn3", + "_comment_attn": "flash_attn3/sageattn3 aren't installed (no Blackwell-ready pre-built wheels). Use PyTorch SDPA which works on SM120.", + "self_attn_1_type": "torch_sdpa", + "cross_attn_1_type": "torch_sdpa", + "cross_attn_2_type": "torch_sdpa", + + "_comment_modulate": "Triton fuse_scale_shift_kernel segfaults during JIT compile on Blackwell SM120 (triton 3.4 + cu128). Use the PyTorch modulate fallback until the Triton issue is resolved.", + "modulate_type": "torch", + "_comment_rope": "flashinfer not installed; fall back to PyTorch rope.", + "rope_type": "torch", "sample_guide_scale": [3.5, 3.5], "sample_shift": 5.0, diff --git a/server/video_models/wan22.py b/server/video_models/wan22.py index 3ecf5bd..15d8e85 100644 --- a/server/video_models/wan22.py +++ b/server/video_models/wan22.py @@ -220,6 +220,8 @@ class Wan22Pipeline: GET_DTYPE.cache_clear() log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE()) self._patch_t5_dtype_for_gguf() + self._patch_vae_dtype_for_gguf() + self._patch_dit_fp32_weights_for_gguf() # --- GGUF dtype compatibility patch ---------------------------------------- @@ -240,6 +242,7 @@ class Wan22Pipeline: orig_run_text_encoder = runner.run_text_encoder.__func__ def bf16_text_encoder(self_runner, *args, **kwargs): + import torch # Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights. 
def _patch_vae_dtype_for_gguf(self) -> None:
    """Cast the VAE encoder/decoder and their latent statistics to fp16.

    Rationale (carried over from the original patch notes): the VAE weights
    load as bf16, while under GGUF the DIT runs in fp16 and the runner casts
    VAE inputs via ``.to(GET_DTYPE())``. With DTYPE=FP16 those fp16 inputs
    collide with bf16 weights in Conv3d. The VAE is a plain float model (not
    quantized), so converting it to fp16 once removes the mismatch without
    any runtime dtype juggling.

    For each of ``vae_encoder`` / ``vae_decoder`` found on ``self._runner``:

    * the wrapped ``.model`` (or the object itself when it has no
      ``.model``) is converted with ``.to(dtype=torch.float16)``;
    * the wrapper's ``mean`` and ``inv_std`` tensors are replaced by fp16
      copies;
    * tensor entries of a ``scale`` list or tuple are replaced by fp16
      copies, keeping non-tensor entries and the container type.

    Missing attributes are skipped. Returns nothing; the runner's VAE
    objects are modified in place.
    """
    import torch

    fp16 = torch.float16
    runner = self._runner
    patched = []

    for name in ("vae_encoder", "vae_decoder"):
        mod = getattr(runner, name, None)
        if mod is None:
            continue
        patched.append(name)

        # The return value of .to() is discarded, so this relies on an
        # in-place conversion (which is how nn.Module.to behaves).
        # NOTE(review): assumes `inner` is an nn.Module - confirm.
        inner = getattr(mod, "model", mod)
        if hasattr(inner, "to"):
            inner.to(dtype=fp16)

        # The outer wrapper also holds mean/inv_std/scale tensors used by
        # encode/decode (z = z/inv_std + mean). Cast them too, or the first
        # op upcasts fp16 latents back to fp32/bf16.
        for attr in ("mean", "inv_std"):
            stat = getattr(mod, attr, None)
            if isinstance(stat, torch.Tensor):
                setattr(mod, attr, stat.to(fp16))

        scale = getattr(mod, "scale", None)
        if isinstance(scale, (list, tuple)):
            converted = [
                item.to(fp16) if isinstance(item, torch.Tensor) else item
                for item in scale
            ]
            # Hand back the same kind of container that was found.
            mod.scale = converted if isinstance(scale, list) else tuple(converted)

    # Only claim success when something was actually touched.
    if patched:
        log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
    else:
        log.info("No VAE encoder/decoder found on runner; nothing cast to fp16.")


def _patch_dit_fp32_weights_for_gguf(self) -> None:
    """Cast leftover fp32 DIT pre/post weights to fp16.

    Rationale (carried over from the original patch notes): GGUF Q4_K_M
    dequantises the transformer blocks to fp16, but a handful of
    non-quantised weights (notably ``patch_embedding.pin_weight``) end up
    loaded as fp32, which breaks the first conv of the DIT forward pass
    (fp16 input vs fp32 weight).

    The method looks at ``self._runner.model.model`` (a single object or a
    list/tuple of them). On each one it walks the public attributes of
    ``pre_weight`` and ``post_weight`` and, for every attribute object,
    replaces any fp32 tensor stored under ``weight``, ``bias``,
    ``pin_weight`` or ``pin_bias`` with an fp16 copy. Only this one level
    is inspected. The number of replaced tensors is logged.

    Returns nothing. Returns early (without logging) when
    ``runner.model`` has no ``model`` attribute or it is ``None``.
    """
    import torch

    models = getattr(self._runner.model, "model", None)
    if models is None:
        return
    if not isinstance(models, (list, tuple)):
        models = [models]

    tensor_slots = ("weight", "bias", "pin_weight", "pin_bias")
    n_cast = 0

    for m in models:
        for weights_attr in ("pre_weight", "post_weight"):
            w = getattr(m, weights_attr, None)
            if w is None:
                continue

            for sub_name in dir(w):
                if sub_name.startswith("_"):
                    continue
                # dir() also lists properties/descriptors; reading one may
                # raise, and such entries are simply not weight holders.
                try:
                    sub = getattr(w, sub_name)
                except Exception:
                    continue
                if sub is None:
                    continue

                for t_name in tensor_slots:
                    t = getattr(sub, t_name, None)
                    if not (isinstance(t, torch.Tensor) and t.dtype == torch.float32):
                        continue

                    casted = t.to(torch.float16)
                    # Preserve pinned-memory status on pin_* tensors so
                    # move_attr_to_cuda's non-blocking H2D copy is safe.
                    if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
                        try:
                            casted = casted.pin_memory()
                        except RuntimeError:
                            # Best effort: keep the unpinned fp16 copy, but
                            # leave a trace instead of failing silently.
                            log.info(
                                "Could not re-pin %s.%s.%s after fp16 cast; leaving it unpinned.",
                                weights_attr, sub_name, t_name,
                            )
                    setattr(sub, t_name, casted)
                    n_cast += 1

    log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
def run():
    """Smoke-test image reading, the optional CLIP encoder and the VAE encoder.

    Steps:
      1. Build a ``Wan22Pipeline`` from the module-level ``DIT_REPO`` /
         ``CONFIG_JSON`` / ``DIT_QUANT`` settings.
      2. Fill the pipeline's input-info template for a 17-frame clip and
         attach it to the runner.
      3. Read the sample avatar image through the runner.
      4. Run the CLIP image encoder when the runner config enables
         ``use_image_encoder`` (treated as enabled when the key is absent).
      5. Run the VAE encoder and log what it returned.

    Results are only logged; the single hard check is that the CLIP output,
    when produced, is a ``torch.Tensor``.

    Raises:
        AssertionError: if the CLIP encoder returns something other than a
            tensor.
        SystemExit: if the pipeline module cannot be imported.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Import failed: %s", e)
        # NOTE(review): exit status 0 reports a missing dependency as
        # success (skip-like behaviour) - confirm this is intended.
        sys.exit(0)

    avatar = ensure_sample_avatar()
    log.info("Avatar: %s", avatar)

    log.info("Building pipeline (quant=%s)...", DIT_QUANT)
    pipe = Wan22Pipeline(
        base_repo="Wan-AI/Wan2.2-I2V-A14B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2_moe_distill",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")

    runner = pipe._runner

    # Set up input_info so runner methods can access it.
    from lightx2v.utils.input_info import update_input_info_from_dict
    update_input_info_from_dict(
        pipe._input_info_template,
        {
            "seed": 42,
            "prompt": "a person looking at the camera, natural lighting",
            "negative_prompt": "",
            "image_path": avatar,
            "target_video_length": 17,
        },
    )
    runner.input_info = pipe._input_info_template

    # 1. Load image
    log.info("Reading image input...")
    img, img_ori = runner.read_image_input(avatar)
    log.info("img: shape=%s dtype=%s device=%s", img.shape, img.dtype, img.device)

    # 2. CLIP image encoder (only if enabled in config)
    use_clip = runner.config.get("use_image_encoder", True)
    if use_clip:
        log.info("Running CLIP image encoder...")
        clip_out = runner.run_image_encoder(img)
        # Check the type before touching tensor attributes, so a wrong
        # return type fails with this message rather than an AttributeError
        # raised while building the log line.
        assert isinstance(clip_out, torch.Tensor), f"Expected tensor, got {type(clip_out)}"
        log.info("clip_out: shape=%s dtype=%s device=%s", clip_out.shape, clip_out.dtype, clip_out.device)
        log.info("PASS: CLIP image encoder succeeded.")
    else:
        log.info("CLIP image encoder disabled (use_image_encoder=false) — skipping.")

    # 3. VAE encoder: the runner states whether it wants the original image
    #    or the preprocessed tensor.
    vae_input = img_ori if runner.vae_encoder_need_img_original else img
    log.info("Running VAE encoder (using %s)...",
             "img_ori" if runner.vae_encoder_need_img_original else "img tensor")
    vae_out, latent_shape = runner.run_vae_encoder(vae_input)
    log.info("latent_shape: %s", latent_shape)
    if isinstance(vae_out, torch.Tensor):
        log.info("vae_out: shape=%s dtype=%s device=%s", vae_out.shape, vae_out.dtype, vae_out.device)
    elif isinstance(vae_out, dict):
        for k, v in vae_out.items():
            if isinstance(v, torch.Tensor):
                log.info("vae_out[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
            else:
                log.info("vae_out[%s]: type=%s", k, type(v))
    else:
        log.info("vae_out: type=%s", type(vae_out))
    log.info("PASS: VAE encoder succeeded.")

    log.info("PASS: All image encoding stages completed under %s pipeline.", DIT_QUANT)
def run():
    """Smoke-test exactly one DIT denoising step.

    Steps:
      1. Build a ``Wan22Pipeline`` from the module-level ``DIT_REPO`` /
         ``CONFIG_JSON`` / ``DIT_QUANT`` settings.
      2. Fill the pipeline's input-info template for a 17-frame clip and
         attach it to the runner.
      3. Run the runner's input encoders and log the tensors they return.
      4. Call ``runner.init_run()`` and snapshot the scheduler's latents.
      5. Execute ``step_pre(step_index=0)`` -> ``infer`` -> ``step_post``.
      6. Check that the scheduler's latents differ from the snapshot.

    Raises:
        AssertionError: if the latents are unchanged after the step.
        SystemExit: if the pipeline module cannot be imported.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Import failed: %s", e)
        # NOTE(review): exit status 0 reports a missing dependency as
        # success (skip-like behaviour) - confirm this is intended.
        sys.exit(0)

    avatar = ensure_sample_avatar()
    log.info("Avatar: %s", avatar)

    log.info("Building pipeline (quant=%s)...", DIT_QUANT)
    pipe = Wan22Pipeline(
        base_repo="Wan-AI/Wan2.2-I2V-A14B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2_moe_distill",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")

    runner = pipe._runner

    # Set up input_info for a short clip.
    from lightx2v.utils.input_info import update_input_info_from_dict
    update_input_info_from_dict(
        pipe._input_info_template,
        {
            "seed": 42,
            "prompt": "a person looking at the camera, natural lighting",
            "negative_prompt": "",
            "image_path": avatar,
            "target_video_length": 17,  # 1 second at 16fps + 1
        },
    )
    runner.input_info = pipe._input_info_template

    # 1. Run all encoders (T5 + CLIP + VAE)
    log.info("Running all input encoders (T5 + CLIP + VAE)...")
    runner.inputs = runner.run_input_encoder()
    log.info("Encoder outputs ready.")
    for k, v in runner.inputs.items():
        if isinstance(v, torch.Tensor):
            log.info(" inputs[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
        elif isinstance(v, dict):
            for k2, v2 in v.items():
                if isinstance(v2, torch.Tensor):
                    log.info(" inputs[%s][%s]: shape=%s dtype=%s", k, k2, v2.shape, v2.dtype)

    # 2. Initialize run (sets up scheduler, creates noise latents)
    log.info("Initializing run (scheduler.prepare)...")
    runner.init_run()
    # Snapshot by value: if the scheduler updates its latents in place, a
    # plain reference would alias the post-step tensor and the comparison
    # below would always report "unchanged".
    latents = runner.model.scheduler.latents.clone()
    log.info("Initial latents: shape=%s dtype=%s", latents.shape, latents.dtype)

    # 3. Single DIT step
    log.info("Running single DIT step (step_pre → infer → step_post)...")
    runner.model.scheduler.step_pre(step_index=0)
    runner.model.infer(runner.inputs)
    runner.model.scheduler.step_post()

    latents_after = runner.model.scheduler.latents
    log.info("Latents after step: shape=%s dtype=%s", latents_after.shape, latents_after.dtype)

    # Verify latents changed (denoising did something).
    assert not torch.equal(latents, latents_after), "Latents unchanged after DIT step"
    log.info("PASS: DIT single step completed, latents updated.")

    log.info("PASS: DIT forward pass succeeded under %s pipeline.", DIT_QUANT)
def run():
    """Smoke-test the VAE decoder on latents produced by one DIT step.

    Builds a ``Wan22Pipeline`` from the module-level ``DIT_REPO`` /
    ``CONFIG_JSON`` / ``DIT_QUANT`` settings, prepares a 17-frame request,
    runs the input encoders, initialises the run, performs a single
    ``step_pre`` -> ``infer`` -> ``step_post`` cycle and passes the
    resulting latents to ``runner.run_vae_decoder``. The decoder output is
    only logged (tensor, list, or anything else); nothing is asserted.

    Raises:
        SystemExit: if the pipeline module cannot be imported.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as exc:
        log.error("Import failed: %s", exc)
        # Exits with status 0, exactly as the other component tests do.
        sys.exit(0)

    avatar_path = ensure_sample_avatar()
    log.info("Avatar: %s", avatar_path)

    log.info("Building pipeline (quant=%s)...", DIT_QUANT)
    pipeline_kwargs = {
        "base_repo": "Wan-AI/Wan2.2-I2V-A14B",
        "dit_repo": DIT_REPO,
        "config_json": CONFIG_JSON,
        "model_cls": "wan2.2_moe_distill",
        "resolution": 480,
        "fps": 16,
        "dit_quant_scheme": DIT_QUANT,
        "t5_quantized": True,
    }
    pipeline = Wan22Pipeline(**pipeline_kwargs)
    log.info("Pipeline ready.")

    runner = pipeline._runner

    # Describe a short clip and hand the populated template to the runner.
    from lightx2v.utils.input_info import update_input_info_from_dict
    request = {
        "seed": 42,
        "prompt": "a person looking at the camera, natural lighting",
        "negative_prompt": "",
        "image_path": avatar_path,
        "target_video_length": 17,  # 1 second at 16fps + 1
    }
    update_input_info_from_dict(pipeline._input_info_template, request)
    runner.input_info = pipeline._input_info_template

    # Stage 1: input encoders (T5 + CLIP + VAE).
    log.info("Running all input encoders (T5 + CLIP + VAE)...")
    runner.inputs = runner.run_input_encoder()
    log.info("Encoder outputs ready.")

    # Stage 2: initialise the run (scheduler setup, noise latents).
    log.info("Initializing run (scheduler.prepare)...")
    runner.init_run()
    log.info("Initial latents: shape=%s dtype=%s",
             runner.model.scheduler.latents.shape,
             runner.model.scheduler.latents.dtype)

    # Stage 3: one DIT step, so the decoder gets realistic latents.
    log.info("Running single DIT step...")
    runner.model.scheduler.step_pre(step_index=0)
    runner.model.infer(runner.inputs)
    runner.model.scheduler.step_post()
    stepped = runner.model.scheduler.latents
    log.info("Latents after step: shape=%s dtype=%s", stepped.shape, stepped.dtype)

    # Stage 4: decode and report whatever shape of result comes back.
    log.info("Running VAE decoder...")
    decoded = runner.run_vae_decoder(stepped)
    log.info("VAE decoder output type: %s", type(decoded))
    if isinstance(decoded, torch.Tensor):
        log.info("video_out: shape=%s dtype=%s device=%s",
                 decoded.shape, decoded.dtype, decoded.device)
    elif isinstance(decoded, list):
        log.info("video_out: list of %d items", len(decoded))
        head = decoded[0] if decoded else None
        if isinstance(head, torch.Tensor):
            log.info(" first item: shape=%s dtype=%s", head.shape, head.dtype)
    else:
        log.info("video_out: %s", decoded)

    log.info("PASS: VAE decoder succeeded under %s pipeline.", DIT_QUANT)