t5 encoder fp8 seems to be working
This commit is contained in:
+29
-13
@@ -58,15 +58,18 @@ class VideoConfig:
|
||||
loras: list[LoRASpec] = field(default_factory=list)
|
||||
|
||||
# Model paths — can be overridden via config.yml.video.models.
|
||||
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
|
||||
# The bf16 DIT shards in this repo are skipped — we
|
||||
# replace them with the fp8 files from wan22_fp8_repo.
|
||||
# wan22_fp8_repo : HF repo id (or local dir) providing the two fp8 e4m3
|
||||
# 4-step distilled DIT checkpoints (~15 GB each).
|
||||
# wan22_config_json: path to the LightX2V inference config template the
|
||||
# Wan22Pipeline will fill in with absolute ckpt paths.
|
||||
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
|
||||
# The bf16 DIT shards in this repo are skipped — we
|
||||
# replace them with quantised files from wan22_dit_repo.
|
||||
# wan22_dit_repo : HF repo id (or local dir) providing the quantised
|
||||
# DIT checkpoints (fp8 or GGUF).
|
||||
# wan22_dit_quant_scheme : quantisation format, e.g. "fp8-sgl" or "gguf-Q4_K_M".
|
||||
# wan22_config_json : path to the LightX2V inference config template the
|
||||
# Wan22Pipeline will fill in with absolute ckpt paths.
|
||||
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
|
||||
wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models"
|
||||
wan22_dit_repo: str = "lightx2v/Wan2.2-Distill-Models"
|
||||
wan22_dit_quant_scheme: str = "fp8-sgl"
|
||||
wan22_t5_quantized: bool = False
|
||||
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
||||
wan22_model_cls: str = "wan2.2_moe_distill"
|
||||
musetalk_model_path: str = "TMElyralab/MuseTalk"
|
||||
@@ -121,8 +124,18 @@ class VideoConfig:
|
||||
wan22_base_repo=str(
|
||||
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
|
||||
),
|
||||
wan22_fp8_repo=str(
|
||||
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models")
|
||||
wan22_dit_repo=str(
|
||||
models_raw.get(
|
||||
"wan22_dit_repo",
|
||||
# Backwards compat: fall back to old key name.
|
||||
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models"),
|
||||
)
|
||||
),
|
||||
wan22_dit_quant_scheme=str(
|
||||
models_raw.get("wan22_dit_quant_scheme", "fp8-sgl")
|
||||
),
|
||||
wan22_t5_quantized=bool(
|
||||
models_raw.get("wan22_t5_quantized", False)
|
||||
),
|
||||
wan22_config_json=str(
|
||||
models_raw.get(
|
||||
@@ -204,16 +217,19 @@ class VideoEngine:
|
||||
from server.video_models.musetalk import MuseTalkEngine
|
||||
|
||||
log.info(
|
||||
"Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...",
|
||||
self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo,
|
||||
"Loading Wan2.2 pipeline (base=%s, dit=%s, quant=%s)...",
|
||||
self.cfg.wan22_base_repo, self.cfg.wan22_dit_repo,
|
||||
self.cfg.wan22_dit_quant_scheme,
|
||||
)
|
||||
self._wan22 = Wan22Pipeline(
|
||||
base_repo=self.cfg.wan22_base_repo,
|
||||
fp8_repo=self.cfg.wan22_fp8_repo,
|
||||
dit_repo=self.cfg.wan22_dit_repo,
|
||||
config_json=self.cfg.wan22_config_json,
|
||||
model_cls=self.cfg.wan22_model_cls,
|
||||
resolution=self.cfg.resolution,
|
||||
fps=self.cfg.fps,
|
||||
dit_quant_scheme=self.cfg.wan22_dit_quant_scheme,
|
||||
t5_quantized=self.cfg.wan22_t5_quantized,
|
||||
)
|
||||
if self.cfg.loras:
|
||||
self._wan22.load_loras(self.cfg.loras)
|
||||
|
||||
+200
-35
@@ -1,4 +1,4 @@
|
||||
"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V.
|
||||
"""Wan2.2-Lightning image-to-video wrapper via LightX2V.
|
||||
|
||||
This wrapper targets LightX2V's actual Python entry points (verified against
|
||||
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
|
||||
@@ -25,10 +25,12 @@ Two HuggingFace repos are consumed on first run (cached under HF_HOME):
|
||||
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
|
||||
The bf16 DIT shards under high_noise_model/
|
||||
and low_noise_model/ are SKIPPED via
|
||||
ignore_patterns — we replace them with fp8.
|
||||
- lightx2v/Wan2.2-Distill-Models — exactly two safetensors files:
|
||||
the fp8 e4m3 4-step distilled high/low
|
||||
noise DIT checkpoints (~15 GB each).
|
||||
ignore_patterns — we replace them with
|
||||
quantised checkpoints from dit_repo.
|
||||
- dit_repo (configurable) — quantised DIT checkpoints. Supported
|
||||
formats:
|
||||
* fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
|
||||
* GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -47,13 +49,22 @@ if TYPE_CHECKING:
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# --- fp8 distill filenames --------------------------------------------------
|
||||
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||
|
||||
# --- GGUF filenames (QuantStack layout: HighNoise/<name>.gguf) ---------------
|
||||
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
|
||||
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"
|
||||
|
||||
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
|
||||
T5_FP8_REPO = "lightx2v/Encoders"
|
||||
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
|
||||
|
||||
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
|
||||
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8
|
||||
# files from the distill repo replace the DIT weights entirely. We must keep
|
||||
# the config.json / index.json metadata under high_noise_model/ and
|
||||
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the
|
||||
# quantised files from dit_repo replace the DIT weights entirely. We must
|
||||
# keep the config.json / index.json metadata under high_noise_model/ and
|
||||
# low_noise_model/ (LightX2V's set_config reads architecture params like
|
||||
# ``dim`` from them) and the tokenizer files under google/.
|
||||
BASE_REPO_IGNORE_PATTERNS = [
|
||||
@@ -66,8 +77,68 @@ BASE_REPO_IGNORE_PATTERNS = [
|
||||
]
|
||||
|
||||
|
||||
def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
|
||||
|
||||
sgl_kernel's CUTLASS-based fp8 GEMM doesn't ship SM120 kernels yet.
|
||||
PyTorch 2.8+'s native ``_scaled_mm`` works on all architectures
|
||||
including Blackwell. This patch is idempotent.
|
||||
"""
|
||||
try:
|
||||
import sgl_kernel # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
return # no sgl_kernel → fp8 T5 not in use
|
||||
|
||||
if getattr(sgl_kernel, "_fp8_patched_for_blackwell", False):
|
||||
return
|
||||
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
cap = torch.cuda.get_device_capability()
|
||||
if cap[0] < 12:
|
||||
return # only patch on Blackwell+
|
||||
|
||||
_orig = sgl_kernel.fp8_scaled_mm
|
||||
|
||||
def _torch_fp8_scaled_mm(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
b_scale: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# torch._scaled_mm expects (M,K) @ (N,K).t() with:
|
||||
# scale_a: scalar or (M,1)
|
||||
# scale_b: scalar or (1,N)
|
||||
# sgl_kernel provides scale_b as (N,1) — transpose it.
|
||||
if b_scale.dim() == 2 and b_scale.shape[1] == 1:
|
||||
b_scale = b_scale.t()
|
||||
# _scaled_mm requires B to be column-major (stride(0)==1).
|
||||
bt = b.t().contiguous().t()
|
||||
out = torch._scaled_mm(
|
||||
a, bt,
|
||||
scale_a=a_scale, scale_b=b_scale,
|
||||
out_dtype=out_dtype, bias=bias,
|
||||
)
|
||||
return out
|
||||
|
||||
sgl_kernel.fp8_scaled_mm = _torch_fp8_scaled_mm
|
||||
sgl_kernel._fp8_patched_for_blackwell = True
|
||||
log.info("Patched sgl_kernel.fp8_scaled_mm → torch._scaled_mm for Blackwell (SM%d%d).", *cap)
|
||||
|
||||
|
||||
class Wan22Pipeline:
|
||||
"""Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights.
|
||||
"""Wrapper around LightX2V's Wan2.2 MoE distill runner.
|
||||
|
||||
Supports two DIT quantisation formats:
|
||||
* **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
|
||||
``lightx2v/Wan2.2-Distill-Models``)
|
||||
* **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
|
||||
from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)
|
||||
|
||||
Constructor downloads (if needed) both HF repos, writes a runtime JSON
|
||||
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
|
||||
@@ -77,23 +148,34 @@ class Wan22Pipeline:
|
||||
def __init__(
|
||||
self,
|
||||
base_repo: str,
|
||||
fp8_repo: str,
|
||||
dit_repo: str,
|
||||
config_json: str,
|
||||
model_cls: str = "wan2.2_moe_distill",
|
||||
resolution: int = 480,
|
||||
fps: int = 16,
|
||||
dit_quant_scheme: str = "fp8-sgl",
|
||||
t5_quantized: bool = False,
|
||||
):
|
||||
self.base_repo = base_repo
|
||||
self.fp8_repo = fp8_repo
|
||||
self.dit_repo = dit_repo
|
||||
self.config_json_template = config_json
|
||||
self.model_cls = model_cls
|
||||
self.resolution = resolution
|
||||
self.fps = fps
|
||||
self.dit_quant_scheme = dit_quant_scheme
|
||||
self.t5_quantized = t5_quantized
|
||||
self._applied_loras: list[LoRASpec] = []
|
||||
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and fp8 DIT ckpts.
|
||||
self._is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||
|
||||
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
|
||||
self._model_root = self._ensure_base_repo(base_repo)
|
||||
self._fp8_high, self._fp8_low = self._ensure_fp8_checkpoints(fp8_repo)
|
||||
self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
|
||||
dit_repo, dit_quant_scheme,
|
||||
)
|
||||
self._t5_fp8_ckpt = (
|
||||
self._ensure_t5_fp8() if t5_quantized else None
|
||||
)
|
||||
|
||||
# 2. Materialize a runtime JSON config with absolute ckpt paths.
|
||||
self._runtime_json_path = self._build_runtime_config()
|
||||
@@ -105,13 +187,17 @@ class Wan22Pipeline:
|
||||
config_json=self._runtime_json_path,
|
||||
)
|
||||
|
||||
# 4. set_config → init_runner. Runner construction triggers weight load.
|
||||
# Imports are scoped here so ``import server.video_models.wan22``
|
||||
# never pulls in lightx2v (tests can import this module on CPU).
|
||||
# 4. Import LightX2V (scoped here so ``import server.video_models.wan22``
|
||||
# never pulls in lightx2v — tests can import this module on CPU).
|
||||
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
|
||||
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
|
||||
from lightx2v.infer import init_runner # type: ignore[import-not-found]
|
||||
|
||||
_patch_fp8_scaled_mm_for_blackwell()
|
||||
|
||||
# 5. Load all models under default DTYPE=BF16 so T5 (which is
|
||||
# hardcoded to bf16 weights) initialises its offload buffers
|
||||
# correctly. We flip to FP16 *after* init_runner completes.
|
||||
log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
|
||||
model_cls, self._model_root)
|
||||
self._config = set_config(args)
|
||||
@@ -124,6 +210,52 @@ class Wan22Pipeline:
|
||||
self._runner = init_runner(self._config)
|
||||
log.info("LightX2V runner loaded; weights resident.")
|
||||
|
||||
# 6. GGUF: switch global DTYPE to FP16 for inference. GGUF DIT
|
||||
# dequantises to fp16, and many intermediate tensors inside the
|
||||
# DIT forward pass are allocated via GET_DTYPE(). The T5 encoder
|
||||
# is wrapped to temporarily restore BF16 during its forward.
|
||||
if self._is_gguf:
|
||||
os.environ["DTYPE"] = "FP16"
|
||||
from lightx2v.utils.envs import GET_DTYPE # type: ignore[import-not-found]
|
||||
GET_DTYPE.cache_clear()
|
||||
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
|
||||
self._patch_t5_dtype_for_gguf()
|
||||
|
||||
# --- GGUF dtype compatibility patch ----------------------------------------
|
||||
|
||||
def _patch_t5_dtype_for_gguf(self) -> None:
|
||||
"""Wrap the T5 encoder so it temporarily restores DTYPE=BF16.
|
||||
|
||||
The T5 encoder is hardcoded to bfloat16 weights (wan_runner.py). When
|
||||
the global DTYPE is FP16 (required for GGUF DIT), the T5's CPU-offload
|
||||
path breaks because intermediate tensor dtypes no longer match the bf16
|
||||
weights. We wrap ``run_text_encoder`` to temporarily flip GET_DTYPE()
|
||||
back to bf16, then restore fp16 before the DIT runs.
|
||||
"""
|
||||
import os
|
||||
import types
|
||||
from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE # type: ignore[import-not-found]
|
||||
|
||||
runner = self._runner
|
||||
orig_run_text_encoder = runner.run_text_encoder.__func__
|
||||
|
||||
def bf16_text_encoder(self_runner, *args, **kwargs):
|
||||
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
|
||||
os.environ["DTYPE"] = "BF16"
|
||||
GET_DTYPE.cache_clear()
|
||||
GET_SENSITIVE_DTYPE.cache_clear()
|
||||
try:
|
||||
result = orig_run_text_encoder(self_runner, *args, **kwargs)
|
||||
finally:
|
||||
# Restore FP16 for the DIT / rest of the pipeline.
|
||||
os.environ["DTYPE"] = "FP16"
|
||||
GET_DTYPE.cache_clear()
|
||||
GET_SENSITIVE_DTYPE.cache_clear()
|
||||
return result
|
||||
|
||||
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
|
||||
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
|
||||
|
||||
# --- Weight provisioning -------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
@@ -132,7 +264,7 @@ class Wan22Pipeline:
|
||||
|
||||
If ``base_repo`` is already a local directory, use it as-is. Otherwise
|
||||
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
|
||||
shards (they're replaced by the fp8 files).
|
||||
shards (they're replaced by the quantised files).
|
||||
"""
|
||||
if os.path.isdir(base_repo):
|
||||
return base_repo
|
||||
@@ -145,42 +277,75 @@ class Wan22Pipeline:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]:
|
||||
"""Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair.
|
||||
def _ensure_dit_checkpoints(
|
||||
dit_repo: str,
|
||||
dit_quant_scheme: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Return (high_noise_path, low_noise_path) for the DIT pair.
|
||||
|
||||
- If ``fp8_repo`` is a local directory, expect both files inside it.
|
||||
- Otherwise treat it as a HF repo id and download only the two files
|
||||
we need (not the ~150 GB of other variants in that repo).
|
||||
Supports both fp8 safetensors and GGUF formats.
|
||||
"""
|
||||
if not fp8_repo:
|
||||
raise ValueError("fp8_repo must be a HF repo id or local directory.")
|
||||
if os.path.isdir(fp8_repo):
|
||||
high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE)
|
||||
low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE)
|
||||
if not dit_repo:
|
||||
raise ValueError("dit_repo must be a HF repo id or local directory.")
|
||||
|
||||
is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||
|
||||
if is_gguf:
|
||||
# Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M"
|
||||
quant = dit_quant_scheme.replace("gguf-", "")
|
||||
high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
|
||||
low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
|
||||
else:
|
||||
high_file = FP8_HIGH_NOISE_FILE
|
||||
low_file = FP8_LOW_NOISE_FILE
|
||||
|
||||
# Local directory?
|
||||
if os.path.isdir(dit_repo):
|
||||
high = os.path.join(dit_repo, high_file)
|
||||
low = os.path.join(dit_repo, low_file)
|
||||
if not (os.path.isfile(high) and os.path.isfile(low)):
|
||||
raise FileNotFoundError(
|
||||
f"fp8 checkpoints not found in {fp8_repo}: expected "
|
||||
f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}"
|
||||
f"DIT checkpoints not found in {dit_repo}: expected "
|
||||
f"{high_file} and {low_file}"
|
||||
)
|
||||
return high, low
|
||||
|
||||
# HuggingFace download.
|
||||
from huggingface_hub import hf_hub_download
|
||||
log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo)
|
||||
high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE)
|
||||
low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE)
|
||||
log.info("Downloading %s DIT checkpoints from %s ...",
|
||||
dit_quant_scheme, dit_repo)
|
||||
high = hf_hub_download(repo_id=dit_repo, filename=high_file)
|
||||
low = hf_hub_download(repo_id=dit_repo, filename=low_file)
|
||||
return high, low
|
||||
|
||||
@staticmethod
|
||||
def _ensure_t5_fp8() -> str:
|
||||
"""Download the fp8 T5 encoder from lightx2v/Encoders (if not cached).
|
||||
|
||||
Returns the local path to the safetensors file (~6 GB).
|
||||
"""
|
||||
from huggingface_hub import hf_hub_download
|
||||
log.info("Downloading fp8 T5 encoder from %s ...", T5_FP8_REPO)
|
||||
return hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
|
||||
|
||||
def _build_runtime_config(self) -> str:
|
||||
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
|
||||
with open(self.config_json_template, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
# Drop editorial comments before passing to LightX2V.
|
||||
cfg.pop("_comment", None)
|
||||
cfg["high_noise_quantized_ckpt"] = self._fp8_high
|
||||
cfg["low_noise_quantized_ckpt"] = self._fp8_low
|
||||
cfg["high_noise_quantized_ckpt"] = self._dit_high
|
||||
cfg["low_noise_quantized_ckpt"] = self._dit_low
|
||||
cfg.setdefault("fps", self.fps)
|
||||
|
||||
# T5 fp8 quantization.
|
||||
if self._t5_fp8_ckpt:
|
||||
cfg["t5_quantized"] = True
|
||||
cfg["t5_quant_scheme"] = "fp8-sgl"
|
||||
cfg["t5_quantized_ckpt"] = self._t5_fp8_ckpt
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(
|
||||
prefix="wan22_fp8_", suffix=".json",
|
||||
prefix="wan22_dit_", suffix=".json",
|
||||
mode="w", delete=False, encoding="utf-8",
|
||||
)
|
||||
json.dump(cfg, tmp, indent=2)
|
||||
|
||||
Reference in New Issue
Block a user