test passing

This commit is contained in:
2026-04-12 16:38:44 -04:00
parent fcf0be38bc
commit 56923ff424
5 changed files with 435 additions and 4 deletions
+98 -1
View File
@@ -220,6 +220,8 @@ class Wan22Pipeline:
GET_DTYPE.cache_clear()
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
self._patch_t5_dtype_for_gguf()
self._patch_vae_dtype_for_gguf()
self._patch_dit_fp32_weights_for_gguf()
# --- GGUF dtype compatibility patch ----------------------------------------
@@ -240,6 +242,7 @@ class Wan22Pipeline:
orig_run_text_encoder = runner.run_text_encoder.__func__
def bf16_text_encoder(self_runner, *args, **kwargs):
import torch
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
os.environ["DTYPE"] = "BF16"
GET_DTYPE.cache_clear()
@@ -251,11 +254,105 @@ class Wan22Pipeline:
os.environ["DTYPE"] = "FP16"
GET_DTYPE.cache_clear()
GET_SENSITIVE_DTYPE.cache_clear()
return result
# Cast bf16 T5 outputs to fp16 so they match the GGUF DIT dtype.
def _to_fp16(x):
if isinstance(x, torch.Tensor) and x.dtype == torch.bfloat16:
return x.to(torch.float16)
if isinstance(x, list):
return [_to_fp16(v) for v in x]
if isinstance(x, tuple):
return tuple(_to_fp16(v) for v in x)
if isinstance(x, dict):
return {k: _to_fp16(v) for k, v in x.items()}
return x
return _to_fp16(result)
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
def _patch_vae_dtype_for_gguf(self) -> None:
    """Cast VAE encoder/decoder weights to fp16 to match GGUF DIT dtype.

    The VAE weights load as bf16 (the default). Under GGUF the DIT runs in
    fp16 and the runner casts VAE inputs via ``.to(GET_DTYPE())`` — which
    under DTYPE=FP16 collides with bf16 VAE weights in Conv3d. Since the
    VAE is a plain float model (not quantized), simply converting its
    weights to fp16 avoids both input-vs-weight mismatches and the need
    for any runtime dtype juggling.

    For each of ``vae_encoder`` / ``vae_decoder`` present on the runner:

    * the wrapped ``.model`` (or the wrapper itself when it has no
      ``model`` attribute) is cast with ``.to(dtype=torch.float16)``;
    * tensor-valued ``mean`` / ``inv_std`` attributes on the wrapper are
      replaced with fp16 copies;
    * tensor entries of a list- or tuple-valued ``scale`` are replaced
      with fp16 copies (non-tensor entries are kept, and the container
      type is preserved).

    Modules that are missing (``None`` or absent) are skipped. A warning is
    logged instead of the success message when nothing was found to cast.
    """
    import torch

    def _as_fp16(value):
        """Return *value* cast to fp16 if it is a tensor, else unchanged."""
        if isinstance(value, torch.Tensor):
            return value.to(torch.float16)
        return value

    runner = self._runner
    patched = []
    for name in ("vae_encoder", "vae_decoder"):
        mod = getattr(runner, name, None)
        if mod is None:
            continue

        inner = getattr(mod, "model", mod)
        if hasattr(inner, "to"):
            inner.to(dtype=torch.float16)

        # The outer WanVAE wrapper also holds mean/inv_std/scale tensors
        # used by encode/decode (z = z/inv_std + mean). Cast them too, or
        # the first op upcasts fp16 latents back to fp32/bf16.
        for attr in ("mean", "inv_std"):
            t = getattr(mod, attr, None)
            if isinstance(t, torch.Tensor):
                setattr(mod, attr, t.to(torch.float16))

        scale = getattr(mod, "scale", None)
        if isinstance(scale, (list, tuple)):
            # A tuple-valued scale needs the same treatment as a list;
            # rebuild it with the same container type it had before.
            cast_scale = [_as_fp16(t) for t in scale]
            mod.scale = cast_scale if isinstance(scale, list) else tuple(cast_scale)

        patched.append(name)

    if patched:
        log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
    else:
        # Do not claim success when neither VAE module was available.
        log.warning("No VAE encoder/decoder found on runner; nothing cast to fp16 for GGUF FP16 pipeline.")
def _patch_dit_fp32_weights_for_gguf(self) -> None:
    """Cast leftover fp32 DIT pre/post weights to fp16.

    GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
    of non-quantised weights (notably ``patch_embedding.pin_weight``) end
    up loaded as fp32. That breaks the first conv in the DIT forward pass
    (fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
    runs uniformly in fp16.

    The method walks ``runner.model.model`` (a single object or a
    list/tuple of them), looks at each one's ``pre_weight`` and
    ``post_weight`` containers, and for every public attribute of those
    containers replaces fp32 ``weight`` / ``bias`` / ``pin_weight`` /
    ``pin_bias`` tensors with fp16 copies. It returns without logging when
    the runner has no ``model`` or ``model.model``; otherwise it logs the
    number of tensors that were cast.
    """
    import torch

    container_attrs = ("pre_weight", "post_weight")
    tensor_attrs = ("weight", "bias", "pin_weight", "pin_bias")

    def _cast_fp32_attrs(holder, holder_name):
        """Cast fp32 tensor attributes of *holder* to fp16; return the count."""
        count = 0
        for t_name in tensor_attrs:
            t = getattr(holder, t_name, None)
            if not isinstance(t, torch.Tensor) or t.dtype != torch.float32:
                continue
            casted = t.to(torch.float16)
            # Preserve pinned-memory status on pin_* tensors so
            # move_attr_to_cuda's non-blocking H2D copy is safe.
            if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
                try:
                    casted = casted.pin_memory()
                except RuntimeError as exc:
                    # Best effort: keep the unpinned fp16 tensor, but make
                    # the lost pinning visible instead of hiding it.
                    log.warning(
                        "Could not re-pin %s.%s after fp16 cast: %s",
                        holder_name, t_name, exc,
                    )
            setattr(holder, t_name, casted)
            count += 1
        return count

    runner = self._runner
    # Guard the runner.model lookup the same way as model.model, so a
    # runner without a loaded model is skipped rather than raising.
    models = getattr(getattr(runner, "model", None), "model", None)
    if models is None:
        return
    if not isinstance(models, (list, tuple)):
        models = [models]

    n_cast = 0
    for m in models:
        for weights_attr in container_attrs:
            w = getattr(m, weights_attr, None)
            if w is None:
                continue
            for sub_name in dir(w):
                if sub_name.startswith("_"):
                    continue
                try:
                    sub = getattr(w, sub_name)
                except Exception:
                    # dir() can list attributes whose access raises;
                    # such entries are skipped.
                    continue
                if sub is None:
                    continue
                n_cast += _cast_fp32_attrs(sub, f"{weights_attr}.{sub_name}")
    log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
# --- Weight provisioning -------------------------------------------------
@staticmethod