test passing
This commit is contained in:
@@ -220,6 +220,8 @@ class Wan22Pipeline:
|
||||
GET_DTYPE.cache_clear()
|
||||
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
|
||||
self._patch_t5_dtype_for_gguf()
|
||||
self._patch_vae_dtype_for_gguf()
|
||||
self._patch_dit_fp32_weights_for_gguf()
|
||||
|
||||
# --- GGUF dtype compatibility patch ----------------------------------------
|
||||
|
||||
@@ -240,6 +242,7 @@ class Wan22Pipeline:
|
||||
orig_run_text_encoder = runner.run_text_encoder.__func__
|
||||
|
||||
def bf16_text_encoder(self_runner, *args, **kwargs):
|
||||
import torch
|
||||
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
|
||||
os.environ["DTYPE"] = "BF16"
|
||||
GET_DTYPE.cache_clear()
|
||||
@@ -251,11 +254,105 @@ class Wan22Pipeline:
|
||||
os.environ["DTYPE"] = "FP16"
|
||||
GET_DTYPE.cache_clear()
|
||||
GET_SENSITIVE_DTYPE.cache_clear()
|
||||
return result
|
||||
# Cast bf16 T5 outputs to fp16 so they match the GGUF DIT dtype.
|
||||
def _to_fp16(x):
|
||||
if isinstance(x, torch.Tensor) and x.dtype == torch.bfloat16:
|
||||
return x.to(torch.float16)
|
||||
if isinstance(x, list):
|
||||
return [_to_fp16(v) for v in x]
|
||||
if isinstance(x, tuple):
|
||||
return tuple(_to_fp16(v) for v in x)
|
||||
if isinstance(x, dict):
|
||||
return {k: _to_fp16(v) for k, v in x.items()}
|
||||
return x
|
||||
return _to_fp16(result)
|
||||
|
||||
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
|
||||
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
|
||||
|
||||
def _patch_vae_dtype_for_gguf(self) -> None:
|
||||
"""Cast VAE encoder/decoder weights to fp16 to match GGUF DIT dtype.
|
||||
|
||||
The VAE weights load as bf16 (the default). Under GGUF the DIT runs in
|
||||
fp16 and the runner casts VAE inputs via ``.to(GET_DTYPE())`` — which
|
||||
under DTYPE=FP16 collides with bf16 VAE weights in Conv3d. Since the
|
||||
VAE is a plain float model (not quantized), simply converting its
|
||||
weights to fp16 avoids both input-vs-weight mismatches and the need
|
||||
for any runtime dtype juggling.
|
||||
"""
|
||||
import torch
|
||||
|
||||
runner = self._runner
|
||||
for name in ("vae_encoder", "vae_decoder"):
|
||||
mod = getattr(runner, name, None)
|
||||
if mod is None:
|
||||
continue
|
||||
inner = getattr(mod, "model", mod)
|
||||
if hasattr(inner, "to"):
|
||||
inner.to(dtype=torch.float16)
|
||||
# The outer WanVAE wrapper also holds mean/inv_std/scale tensors
|
||||
# used by encode/decode (z = z/inv_std + mean). Cast them too, or
|
||||
# the first op upcasts fp16 latents back to fp32/bf16.
|
||||
for attr in ("mean", "inv_std"):
|
||||
t = getattr(mod, attr, None)
|
||||
if isinstance(t, torch.Tensor):
|
||||
setattr(mod, attr, t.to(torch.float16))
|
||||
scale = getattr(mod, "scale", None)
|
||||
if isinstance(scale, list):
|
||||
mod.scale = [
|
||||
t.to(torch.float16) if isinstance(t, torch.Tensor) else t
|
||||
for t in scale
|
||||
]
|
||||
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
|
||||
|
||||
def _patch_dit_fp32_weights_for_gguf(self) -> None:
|
||||
"""Cast leftover fp32 DIT pre/post weights to fp16.
|
||||
|
||||
GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
|
||||
of non-quantised weights (notably ``patch_embedding.pin_weight``) end
|
||||
up loaded as fp32. That breaks the first conv in the DIT forward pass
|
||||
(fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
|
||||
runs uniformly in fp16.
|
||||
"""
|
||||
import torch
|
||||
|
||||
runner = self._runner
|
||||
models = getattr(runner.model, "model", None)
|
||||
if models is None:
|
||||
return
|
||||
if not isinstance(models, (list, tuple)):
|
||||
models = [models]
|
||||
|
||||
n_cast = 0
|
||||
for m in models:
|
||||
for weights_attr in ("pre_weight", "post_weight"):
|
||||
w = getattr(m, weights_attr, None)
|
||||
if w is None:
|
||||
continue
|
||||
for sub_name in dir(w):
|
||||
if sub_name.startswith("_"):
|
||||
continue
|
||||
try:
|
||||
sub = getattr(w, sub_name)
|
||||
except Exception:
|
||||
continue
|
||||
if sub is None:
|
||||
continue
|
||||
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
|
||||
t = getattr(sub, t_name, None)
|
||||
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
|
||||
casted = t.to(torch.float16)
|
||||
# Preserve pinned-memory status on pin_* tensors so
|
||||
# move_attr_to_cuda's non-blocking H2D copy is safe.
|
||||
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
|
||||
try:
|
||||
casted = casted.pin_memory()
|
||||
except RuntimeError:
|
||||
pass
|
||||
setattr(sub, t_name, casted)
|
||||
n_cast += 1
|
||||
log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
|
||||
|
||||
# --- Weight provisioning -------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user