test passing
This commit is contained in:
@@ -11,9 +11,15 @@
|
|||||||
"target_width": 480,
|
"target_width": 480,
|
||||||
"fps": 16,
|
"fps": 16,
|
||||||
|
|
||||||
"self_attn_1_type": "flash_attn3",
|
"_comment_attn": "flash_attn3/sageattn3 aren't installed (no Blackwell-ready pre-built wheels). Use PyTorch SDPA which works on SM120.",
|
||||||
"cross_attn_1_type": "flash_attn3",
|
"self_attn_1_type": "torch_sdpa",
|
||||||
"cross_attn_2_type": "flash_attn3",
|
"cross_attn_1_type": "torch_sdpa",
|
||||||
|
"cross_attn_2_type": "torch_sdpa",
|
||||||
|
|
||||||
|
"_comment_modulate": "Triton fuse_scale_shift_kernel segfaults during JIT compile on Blackwell SM120 (triton 3.4 + cu128). Use the PyTorch modulate fallback until the Triton issue is resolved.",
|
||||||
|
"modulate_type": "torch",
|
||||||
|
"_comment_rope": "flashinfer not installed; fall back to PyTorch rope.",
|
||||||
|
"rope_type": "torch",
|
||||||
|
|
||||||
"sample_guide_scale": [3.5, 3.5],
|
"sample_guide_scale": [3.5, 3.5],
|
||||||
"sample_shift": 5.0,
|
"sample_shift": 5.0,
|
||||||
|
|||||||
@@ -220,6 +220,8 @@ class Wan22Pipeline:
|
|||||||
GET_DTYPE.cache_clear()
|
GET_DTYPE.cache_clear()
|
||||||
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
|
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
|
||||||
self._patch_t5_dtype_for_gguf()
|
self._patch_t5_dtype_for_gguf()
|
||||||
|
self._patch_vae_dtype_for_gguf()
|
||||||
|
self._patch_dit_fp32_weights_for_gguf()
|
||||||
|
|
||||||
# --- GGUF dtype compatibility patch ----------------------------------------
|
# --- GGUF dtype compatibility patch ----------------------------------------
|
||||||
|
|
||||||
@@ -240,6 +242,7 @@ class Wan22Pipeline:
|
|||||||
orig_run_text_encoder = runner.run_text_encoder.__func__
|
orig_run_text_encoder = runner.run_text_encoder.__func__
|
||||||
|
|
||||||
def bf16_text_encoder(self_runner, *args, **kwargs):
|
def bf16_text_encoder(self_runner, *args, **kwargs):
|
||||||
|
import torch
|
||||||
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
|
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
|
||||||
os.environ["DTYPE"] = "BF16"
|
os.environ["DTYPE"] = "BF16"
|
||||||
GET_DTYPE.cache_clear()
|
GET_DTYPE.cache_clear()
|
||||||
@@ -251,11 +254,105 @@ class Wan22Pipeline:
|
|||||||
os.environ["DTYPE"] = "FP16"
|
os.environ["DTYPE"] = "FP16"
|
||||||
GET_DTYPE.cache_clear()
|
GET_DTYPE.cache_clear()
|
||||||
GET_SENSITIVE_DTYPE.cache_clear()
|
GET_SENSITIVE_DTYPE.cache_clear()
|
||||||
return result
|
# Cast bf16 T5 outputs to fp16 so they match the GGUF DIT dtype.
|
||||||
|
def _to_fp16(x):
|
||||||
|
if isinstance(x, torch.Tensor) and x.dtype == torch.bfloat16:
|
||||||
|
return x.to(torch.float16)
|
||||||
|
if isinstance(x, list):
|
||||||
|
return [_to_fp16(v) for v in x]
|
||||||
|
if isinstance(x, tuple):
|
||||||
|
return tuple(_to_fp16(v) for v in x)
|
||||||
|
if isinstance(x, dict):
|
||||||
|
return {k: _to_fp16(v) for k, v in x.items()}
|
||||||
|
return x
|
||||||
|
return _to_fp16(result)
|
||||||
|
|
||||||
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
|
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
|
||||||
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
|
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
|
||||||
|
|
||||||
|
def _patch_vae_dtype_for_gguf(self) -> None:
    """Cast the VAE encoder/decoder to fp16 to match the GGUF DIT dtype.

    The VAE weights load as bf16 (the default). Under GGUF the DIT runs in
    fp16 and the runner casts VAE inputs via ``.to(GET_DTYPE())`` — which
    under DTYPE=FP16 collides with bf16 VAE weights in Conv3d. Since the
    VAE is a plain float model (not quantized), converting its weights to
    fp16 avoids both input-vs-weight mismatches and the need for any
    runtime dtype juggling.

    For each of ``vae_encoder`` / ``vae_decoder`` present on
    ``self._runner`` this casts the wrapped model (its ``.model`` attribute
    when there is one, otherwise the object itself), the wrapper's
    ``mean`` / ``inv_std`` tensors, and every tensor held in its ``scale``
    list or tuple. Components that are missing or ``None`` are skipped; a
    warning is logged (instead of the success message) when none is found.
    """
    import torch

    def _as_fp16(value):
        # Non-tensor entries (e.g. plain floats) are passed through untouched.
        if isinstance(value, torch.Tensor):
            return value.to(torch.float16)
        return value

    runner = self._runner
    patched = []
    for name in ("vae_encoder", "vae_decoder"):
        mod = getattr(runner, name, None)
        if mod is None:
            continue

        inner = getattr(mod, "model", mod)
        if hasattr(inner, "to"):
            # The return value is ignored, so this relies on ``.to``
            # converting in place (as torch.nn.Module.to does).
            # NOTE(review): confirm ``inner`` is always an nn.Module.
            inner.to(dtype=torch.float16)

        # The outer WanVAE wrapper also holds mean/inv_std/scale tensors
        # used by encode/decode (z = z/inv_std + mean). Cast them too, or
        # the first op upcasts fp16 latents back to fp32/bf16.
        for attr in ("mean", "inv_std"):
            t = getattr(mod, attr, None)
            if isinstance(t, torch.Tensor):
                setattr(mod, attr, t.to(torch.float16))

        scale = getattr(mod, "scale", None)
        if isinstance(scale, (list, tuple)):
            converted = [_as_fp16(t) for t in scale]
            # Keep the container type the wrapper originally used.
            mod.scale = converted if isinstance(scale, list) else tuple(converted)

        patched.append(name)

    if not patched:
        # Nothing was cast: do not claim success in the log.
        log.warning("No VAE encoder/decoder found on runner; nothing cast to fp16.")
        return
    log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
|
||||||
|
|
||||||
|
def _patch_dit_fp32_weights_for_gguf(self) -> None:
    """Cast leftover fp32 DIT pre/post weights to fp16.

    GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
    of non-quantised weights (notably ``patch_embedding.pin_weight``) end
    up loaded as fp32. That breaks the first conv in the DIT forward pass
    (fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
    runs uniformly in fp16.

    Walks ``self._runner.model.model`` (a single model, or a list/tuple of
    models). For each model's ``pre_weight`` / ``post_weight`` container,
    every public attribute is inspected and any ``weight`` / ``bias`` /
    ``pin_weight`` / ``pin_bias`` tensor whose dtype is fp32 is replaced by
    an fp16 copy. Returns silently when the runner's model has no
    ``model`` attribute. The number of replaced tensors is logged.
    """
    import torch

    tensor_names = ("weight", "bias", "pin_weight", "pin_bias")

    def _public_children(container):
        """Yield each non-None public attribute value of *container*."""
        for sub_name in dir(container):
            if sub_name.startswith("_"):
                continue
            try:
                sub = getattr(container, sub_name)
            except Exception:
                # Deliberately best-effort: an attribute whose getter raises
                # is skipped rather than aborting the whole patch.
                continue
            if sub is not None:
                yield sub

    def _cast_holder(holder) -> int:
        """Replace fp32 tensors on *holder* with fp16 copies; return the count."""
        replaced = 0
        for t_name in tensor_names:
            t = getattr(holder, t_name, None)
            if not (isinstance(t, torch.Tensor) and t.dtype == torch.float32):
                continue
            casted = t.to(torch.float16)
            # Preserve pinned-memory status on pin_* tensors so
            # move_attr_to_cuda's non-blocking H2D copy is safe.
            if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
                try:
                    casted = casted.pin_memory()
                except RuntimeError as exc:
                    # Still install the fp16 tensor, but make the lost
                    # pinned status visible instead of hiding it.
                    log.warning(
                        "Could not re-pin %s after fp16 cast (%s); leaving it unpinned.",
                        t_name,
                        exc,
                    )
            setattr(holder, t_name, casted)
            replaced += 1
        return replaced

    runner = self._runner
    models = getattr(runner.model, "model", None)
    if models is None:
        return
    if not isinstance(models, (list, tuple)):
        models = [models]

    n_cast = 0
    for m in models:
        for weights_attr in ("pre_weight", "post_weight"):
            w = getattr(m, weights_attr, None)
            if w is None:
                continue
            for sub in _public_children(w):
                n_cast += _cast_holder(sub)
    log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
|
||||||
|
|
||||||
# --- Weight provisioning -------------------------------------------------
|
# --- Weight provisioning -------------------------------------------------
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -0,0 +1,112 @@
|
|||||||
|
"""Smoke test: image reading + VAE encoder (+ CLIP if enabled) under GGUF pipeline.
|
||||||
|
|
||||||
|
Builds the Wan22Pipeline, loads a sample avatar, reads the image input,
|
||||||
|
runs the CLIP image encoder (if use_image_encoder is true in the config),
|
||||||
|
and runs the VAE encoder. Validates outputs under DTYPE=FP16.
|
||||||
|
|
||||||
|
Note: The GGUF distill config sets use_image_encoder=false, so CLIP is
|
||||||
|
skipped by default. The VAE encoder is always exercised.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||||
|
python -m tests.component.test_11_image_encode
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger
|
||||||
|
|
||||||
|
log = get_logger("test_11")
|
||||||
|
|
||||||
|
# Quantization scheme for the DIT weights; overridable through the environment.
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")

# "gguf-*" schemes use the GGUF distill config and repo; any other value
# selects the FP8 distill config and repo.
CONFIG_JSON, DIT_REPO = (
    ("/app/configs/lightx2v/wan22_i2v_gguf_distill.json", "QuantStack/Wan2.2-I2V-A14B-GGUF")
    if DIT_QUANT.startswith("gguf-")
    else ("/app/configs/lightx2v/wan22_i2v_fp8_distill.json", "lightx2v/Wan2.2-Distill-Models")
)
|
||||||
|
|
||||||
|
|
||||||
|
def run():
    """Smoke-test image reading, the optional CLIP encoder and the VAE encoder.

    Builds a ``Wan22Pipeline`` for the configured ``DIT_QUANT`` scheme, fills
    the pipeline's input-info template for a 17-frame clip of the sample
    avatar, then runs ``read_image_input``, ``run_image_encoder`` (only when
    the runner config enables ``use_image_encoder``) and ``run_vae_encoder``,
    logging the shape/dtype of each result.

    Raises:
        AssertionError: if the CLIP encoder returns something other than a
            ``torch.Tensor``.
        SystemExit: if ``server.video_models.wan22`` cannot be imported.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Import failed: %s", e)
        # NOTE(review): exits with status 0, so a failed import is reported
        # as success to the caller — confirm this skip-like behaviour is
        # intended.
        sys.exit(0)

    avatar = ensure_sample_avatar()
    log.info("Avatar: %s", avatar)

    log.info("Building pipeline (quant=%s)...", DIT_QUANT)
    pipe = Wan22Pipeline(
        base_repo="Wan-AI/Wan2.2-I2V-A14B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2_moe_distill",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")

    runner = pipe._runner

    # Set up input_info so runner methods can access it
    from lightx2v.utils.input_info import update_input_info_from_dict

    update_input_info_from_dict(
        pipe._input_info_template,
        {
            "seed": 42,
            "prompt": "a person looking at the camera, natural lighting",
            "negative_prompt": "",
            "image_path": avatar,
            "target_video_length": 17,
        },
    )
    runner.input_info = pipe._input_info_template

    # 1. Load image
    log.info("Reading image input...")
    img, img_ori = runner.read_image_input(avatar)
    log.info("img: shape=%s dtype=%s device=%s", img.shape, img.dtype, img.device)

    # 2. CLIP image encoder (only if enabled in config)
    use_clip = runner.config.get("use_image_encoder", True)
    if use_clip:
        log.info("Running CLIP image encoder...")
        clip_out = runner.run_image_encoder(img)
        # Check the type before touching tensor attributes, so a wrong
        # return type fails with this message rather than an AttributeError
        # raised while formatting the log line below.
        assert isinstance(clip_out, torch.Tensor), f"Expected tensor, got {type(clip_out)}"
        log.info("clip_out: shape=%s dtype=%s device=%s", clip_out.shape, clip_out.dtype, clip_out.device)
        log.info("PASS: CLIP image encoder succeeded.")
    else:
        log.info("CLIP image encoder disabled (use_image_encoder=false) — skipping.")

    # 3. VAE encoder
    need_original = runner.vae_encoder_need_img_original
    vae_input = img_ori if need_original else img
    log.info("Running VAE encoder (using %s)...",
             "img_ori" if need_original else "img tensor")
    vae_out, latent_shape = runner.run_vae_encoder(vae_input)
    log.info("latent_shape: %s", latent_shape)
    if isinstance(vae_out, torch.Tensor):
        log.info("vae_out: shape=%s dtype=%s device=%s", vae_out.shape, vae_out.dtype, vae_out.device)
    elif isinstance(vae_out, dict):
        for k, v in vae_out.items():
            if isinstance(v, torch.Tensor):
                log.info("vae_out[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
            else:
                log.info("vae_out[%s]: type=%s", k, type(v))
    else:
        log.info("vae_out: type=%s", type(vae_out))
    log.info("PASS: VAE encoder succeeded.")

    log.info("PASS: All image encoding stages completed under %s pipeline.", DIT_QUANT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
"""Smoke test: single DIT denoising step with GGUF weights.
|
||||||
|
|
||||||
|
Builds the pipeline, runs all encoders, initializes the scheduler, then
|
||||||
|
executes exactly one DIT forward pass (step_pre → infer → step_post).
|
||||||
|
This isolates the GGUF fp16 DIT from the rest of the pipeline.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||||
|
python -m tests.component.test_12_dit_single_step
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger
|
||||||
|
|
||||||
|
log = get_logger("test_12")
|
||||||
|
|
||||||
|
# Quantization scheme for the DIT weights; overridable through the environment.
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")

# "gguf-*" schemes use the GGUF distill config and repo; any other value
# selects the FP8 distill config and repo.
CONFIG_JSON, DIT_REPO = (
    ("/app/configs/lightx2v/wan22_i2v_gguf_distill.json", "QuantStack/Wan2.2-I2V-A14B-GGUF")
    if DIT_QUANT.startswith("gguf-")
    else ("/app/configs/lightx2v/wan22_i2v_fp8_distill.json", "lightx2v/Wan2.2-Distill-Models")
)
|
||||||
|
|
||||||
|
|
||||||
|
def run():
    """Smoke-test a single DIT denoising step.

    Builds a ``Wan22Pipeline`` for the configured ``DIT_QUANT`` scheme, fills
    the pipeline's input-info template for a 17-frame clip of the sample
    avatar, runs the runner's input encoders and ``init_run``, then executes
    one ``step_pre`` → ``infer`` → ``step_post`` cycle and checks that the
    scheduler's latents differ from their values before the step.

    Raises:
        AssertionError: if the latents are unchanged after the step.
        SystemExit: if ``server.video_models.wan22`` cannot be imported.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Import failed: %s", e)
        # NOTE(review): exits with status 0, so a failed import is reported
        # as success to the caller — confirm this skip-like behaviour is
        # intended.
        sys.exit(0)

    avatar = ensure_sample_avatar()
    log.info("Avatar: %s", avatar)

    log.info("Building pipeline (quant=%s)...", DIT_QUANT)
    pipe = Wan22Pipeline(
        base_repo="Wan-AI/Wan2.2-I2V-A14B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2_moe_distill",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")

    runner = pipe._runner

    # Set up input_info for a short clip
    from lightx2v.utils.input_info import update_input_info_from_dict

    update_input_info_from_dict(
        pipe._input_info_template,
        {
            "seed": 42,
            "prompt": "a person looking at the camera, natural lighting",
            "negative_prompt": "",
            "image_path": avatar,
            "target_video_length": 17,  # 1 second at 16fps + 1
        },
    )
    runner.input_info = pipe._input_info_template

    # 1. Run all encoders (T5 + CLIP + VAE)
    log.info("Running all input encoders (T5 + CLIP + VAE)...")
    runner.inputs = runner.run_input_encoder()
    log.info("Encoder outputs ready.")
    for k, v in runner.inputs.items():
        if isinstance(v, torch.Tensor):
            log.info(" inputs[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
        elif isinstance(v, dict):
            for k2, v2 in v.items():
                if isinstance(v2, torch.Tensor):
                    log.info(" inputs[%s][%s]: shape=%s dtype=%s", k, k2, v2.shape, v2.dtype)

    # 2. Initialize run (sets up scheduler, creates noise latents)
    log.info("Initializing run (scheduler.prepare)...")
    runner.init_run()
    # Snapshot the values: if the scheduler updates its latents tensor in
    # place, a bare reference would alias the post-step tensor and the
    # "latents changed" check below would compare a tensor with itself.
    latents = runner.model.scheduler.latents.detach().clone()
    log.info("Initial latents: shape=%s dtype=%s", latents.shape, latents.dtype)

    # 3. Single DIT step
    log.info("Running single DIT step (step_pre → infer → step_post)...")
    runner.model.scheduler.step_pre(step_index=0)
    runner.model.infer(runner.inputs)
    runner.model.scheduler.step_post()

    latents_after = runner.model.scheduler.latents
    log.info("Latents after step: shape=%s dtype=%s", latents_after.shape, latents_after.dtype)

    # Verify latents changed (denoising did something)
    assert not torch.equal(latents, latents_after), "Latents unchanged after DIT step"
    log.info("PASS: DIT single step completed, latents updated.")

    log.info("PASS: DIT forward pass succeeded under %s pipeline.", DIT_QUANT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,109 @@
|
|||||||
|
"""Smoke test: VAE decoder under GGUF pipeline.
|
||||||
|
|
||||||
|
Builds the pipeline, runs all encoders, initializes the scheduler, executes
|
||||||
|
one DIT denoising step, then decodes the resulting latents back to pixel
|
||||||
|
frames via the VAE decoder. Validates the full encode→denoise→decode path.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||||
|
python -m tests.component.test_13_vae_decode
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
|
||||||
|
|
||||||
|
log = get_logger("test_13")
|
||||||
|
|
||||||
|
# Quantization scheme for the DIT weights; overridable through the environment.
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")

# "gguf-*" schemes use the GGUF distill config and repo; any other value
# selects the FP8 distill config and repo.
CONFIG_JSON, DIT_REPO = (
    ("/app/configs/lightx2v/wan22_i2v_gguf_distill.json", "QuantStack/Wan2.2-I2V-A14B-GGUF")
    if DIT_QUANT.startswith("gguf-")
    else ("/app/configs/lightx2v/wan22_i2v_fp8_distill.json", "lightx2v/Wan2.2-Distill-Models")
)
|
||||||
|
|
||||||
|
|
||||||
|
def run():
    """Smoke-test the VAE decoder on latents produced by one DIT step.

    Builds a ``Wan22Pipeline`` for the configured ``DIT_QUANT`` scheme, fills
    the pipeline's input-info template for a 17-frame clip of the sample
    avatar, runs the input encoders, ``init_run`` and a single
    ``step_pre`` / ``infer`` / ``step_post`` cycle, then passes the
    scheduler's latents to ``run_vae_decoder`` and logs what comes back.
    Nothing is asserted: the test passes if no stage raises.

    Raises:
        SystemExit: if ``server.video_models.wan22`` cannot be imported.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as import_err:
        log.error("Import failed: %s", import_err)
        # NOTE(review): status 0 means a failed import is reported as success.
        sys.exit(0)

    avatar_path = ensure_sample_avatar()
    log.info("Avatar: %s", avatar_path)

    log.info("Building pipeline (quant=%s)...", DIT_QUANT)
    pipeline_kwargs = {
        "base_repo": "Wan-AI/Wan2.2-I2V-A14B",
        "dit_repo": DIT_REPO,
        "config_json": CONFIG_JSON,
        "model_cls": "wan2.2_moe_distill",
        "resolution": 480,
        "fps": 16,
        "dit_quant_scheme": DIT_QUANT,
        "t5_quantized": True,
    }
    pipeline = Wan22Pipeline(**pipeline_kwargs)
    log.info("Pipeline ready.")

    runner = pipeline._runner

    # Describe a short clip and hand it to the runner through input_info.
    from lightx2v.utils.input_info import update_input_info_from_dict

    clip_request = {
        "seed": 42,
        "prompt": "a person looking at the camera, natural lighting",
        "negative_prompt": "",
        "image_path": avatar_path,
        "target_video_length": 17,  # 1 second at 16fps + 1
    }
    update_input_info_from_dict(pipeline._input_info_template, clip_request)
    runner.input_info = pipeline._input_info_template

    # Stage 1: all input encoders (T5 + CLIP + VAE).
    log.info("Running all input encoders (T5 + CLIP + VAE)...")
    runner.inputs = runner.run_input_encoder()
    log.info("Encoder outputs ready.")

    # Stage 2: initialise the run (scheduler set-up, noise latents).
    log.info("Initializing run (scheduler.prepare)...")
    runner.init_run()
    initial = runner.model.scheduler.latents
    log.info("Initial latents: shape=%s dtype=%s", initial.shape, initial.dtype)

    # Stage 3: one DIT step, so the decoder receives realistic latents.
    log.info("Running single DIT step...")
    runner.model.scheduler.step_pre(step_index=0)
    runner.model.infer(runner.inputs)
    runner.model.scheduler.step_post()
    denoised = runner.model.scheduler.latents
    log.info("Latents after step: shape=%s dtype=%s", denoised.shape, denoised.dtype)

    # Stage 4: decode the latents back to frames.
    log.info("Running VAE decoder...")
    decoded = runner.run_vae_decoder(denoised)
    log.info("VAE decoder output type: %s", type(decoded))

    if isinstance(decoded, torch.Tensor):
        log.info(
            "video_out: shape=%s dtype=%s device=%s",
            decoded.shape,
            decoded.dtype,
            decoded.device,
        )
    elif not isinstance(decoded, list):
        log.info("video_out: %s", decoded)
    else:
        log.info("video_out: list of %d items", len(decoded))
        if len(decoded) > 0 and isinstance(decoded[0], torch.Tensor):
            head = decoded[0]
            log.info(" first item: shape=%s dtype=%s", head.shape, head.dtype)

    log.info("PASS: VAE decoder succeeded under %s pipeline.", DIT_QUANT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
Reference in New Issue
Block a user