t5 encoder fp8 seems to be working

This commit is contained in:
2026-04-12 13:50:34 -04:00
parent 2818b41004
commit fcf0be38bc
13 changed files with 505 additions and 67 deletions
+8
View File
@@ -53,6 +53,14 @@ RUN python3.11 -m pip install --no-cache-dir \
"git+https://github.com/ModelTC/LightX2V.git" || \ "git+https://github.com/ModelTC/LightX2V.git" || \
echo "LightX2V install failed — config.video.enabled must stay false until fixed" echo "LightX2V install failed — config.video.enabled must stay false until fixed"
# #
# sgl-kernel (fp8 T5 encoder acceleration). The PyPI wheel lacks SM120
# (Blackwell) CUTLASS kernels; use SGLang's cu128 wheel index instead.
# Our wan22.py patches fp8_scaled_mm → torch._scaled_mm at runtime for
# Blackwell GPUs, but the sgl_kernel package itself must still be present.
RUN python3.11 -m pip install --no-cache-dir --no-deps \
"sgl-kernel @ https://github.com/sgl-project/whl/releases/download/v0.3.14.post1/sgl_kernel-0.3.14.post1%2Bcu128-cp310-abi3-manylinux2014_x86_64.whl" || \
echo "sgl-kernel install failed — fp8 T5 will fall back to bf16"
#
# MuseTalk (audio-driven lip-sync) — same story. # MuseTalk (audio-driven lip-sync) — same story.
RUN python3.11 -m pip install --no-cache-dir \ RUN python3.11 -m pip install --no-cache-dir \
"git+https://github.com/TMElyralab/MuseTalk.git" || \ "git+https://github.com/TMElyralab/MuseTalk.git" || \
+14 -7
View File
@@ -32,16 +32,23 @@ video:
casual gestures, natural lighting, soft focus background casual gestures, natural lighting, soft focus background
prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint} prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint}
# Model sources for the video stack. The fp8 e4m3 4-step distilled DIT # Model sources for the video stack. T5/VAE/tokenizer come from the
# weights from lightx2v/Wan2.2-Distill-Models are ~15 GB each (vs ~28 GB # Wan-AI base repo. DIT weights come from wan22_dit_repo in the format
# bf16) — that's the "save VRAM" path. T5/VAE/tokenizer still come from # specified by wan22_dit_quant_scheme. Both repos download on first run
# the Wan-AI base repo. Both repos download on first run into # into HF_HOME=/cache/huggingface.
# HF_HOME=/cache/huggingface. #
# Supported dit_quant_scheme values:
# fp8-sgl — fp8 e4m3 safetensors (~15 GB/expert, from lightx2v/Wan2.2-Distill-Models)
# gguf-Q4_K_M — GGUF 4-bit (~9.65 GB/expert, from QuantStack/Wan2.2-I2V-A14B-GGUF)
# gguf-Q8_0 — GGUF 8-bit (~15.4 GB/expert)
# (any gguf-<level> supported by LightX2V — see base_model.py MM_WEIGHT_REGISTER)
models: models:
wan22_base_repo: Wan-AI/Wan2.2-I2V-A14B wan22_base_repo: Wan-AI/Wan2.2-I2V-A14B
wan22_fp8_repo: lightx2v/Wan2.2-Distill-Models wan22_dit_repo: QuantStack/Wan2.2-I2V-A14B-GGUF
wan22_dit_quant_scheme: gguf-Q4_K_M
wan22_t5_quantized: true
wan22_model_cls: wan2.2_moe_distill wan22_model_cls: wan2.2_moe_distill
wan22_config_json: /app/configs/lightx2v/wan22_i2v_fp8_distill.json wan22_config_json: /app/configs/lightx2v/wan22_i2v_gguf_distill.json
musetalk_path: TMElyralab/MuseTalk musetalk_path: TMElyralab/MuseTalk
# LoRAs applied to the fp8 base at load time via runtime switch_lora. # LoRAs applied to the fp8 base at load time via runtime switch_lora.
@@ -0,0 +1,35 @@
{
"_comment": "Wan2.2 i2v MoE 4-step distill, GGUF quantized. Uses QuantStack/Wan2.2-I2V-A14B-GGUF checkpoints instead of fp8 safetensors. GGUF does not support block-level offload so offload_granularity is set to 'model' — the entire DIT is moved to GPU when active. With Q4_K_M (~9.65 GB per expert) this fits comfortably in 24+ GB VRAM. high_noise_quantized_ckpt / low_noise_quantized_ckpt are filled in at runtime by server/video_models/wan22.py. IMPORTANT: GGUF dequantizes to fp16, so you must set DTYPE=FP16 in the container environment.",
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"resize_mode": "adaptive",
"resolution": "480p",
"target_height": 480,
"target_width": 480,
"fps": 16,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "model",
"t5_cpu_offload": true,
"vae_cpu_offload": false,
"use_image_encoder": false,
"boundary_step_index": 2,
"denoising_step_list": [1000, 750, 500, 250],
"dit_quantized": true,
"dit_quant_scheme": "gguf-Q4_K_M",
"t5_quantized": false
}
+1
View File
@@ -16,6 +16,7 @@ services:
- ./configs:/app/configs:ro - ./configs:/app/configs:ro
- ./server:/app/server:ro - ./server:/app/server:ro
- ./static:/app/static:ro - ./static:/app/static:ro
- ./tests:/app/tests
- ./run.py:/app/run.py:ro - ./run.py:/app/run.py:ro
deploy: deploy:
resources: resources:
+3
View File
@@ -20,6 +20,9 @@ pyyaml
imageio[ffmpeg]>=2.34 imageio[ffmpeg]>=2.34
av>=12.0 av>=12.0
pyzmq>=25.0 pyzmq>=25.0
gguf>=0.6.0
# sgl-kernel: installed from SGLang's cu128 wheel index in Dockerfile
# (PyPI version lacks SM120/Blackwell CUDA kernels)
# LightX2V (Wan2.2-Lightning) and MuseTalk are installed from source in the # LightX2V (Wan2.2-Lightning) and MuseTalk are installed from source in the
# Dockerfile because neither ships a stable PyPI release yet. See lines # Dockerfile because neither ships a stable PyPI release yet. See lines
# "LightX2V from source" / "MuseTalk from source" in Dockerfile. # "LightX2V from source" / "MuseTalk from source" in Dockerfile.
+25 -9
View File
@@ -60,13 +60,16 @@ class VideoConfig:
# Model paths — can be overridden via config.yml.video.models. # Model paths — can be overridden via config.yml.video.models.
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer. # wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
# The bf16 DIT shards in this repo are skipped — we # The bf16 DIT shards in this repo are skipped — we
# replace them with the fp8 files from wan22_fp8_repo. # replace them with quantised files from wan22_dit_repo.
# wan22_fp8_repo : HF repo id (or local dir) providing the two fp8 e4m3 # wan22_dit_repo : HF repo id (or local dir) providing the quantised
# 4-step distilled DIT checkpoints (~15 GB each). # DIT checkpoints (fp8 or GGUF).
# wan22_dit_quant_scheme : quantisation format, e.g. "fp8-sgl" or "gguf-Q4_K_M".
# wan22_config_json : path to the LightX2V inference config template the # wan22_config_json : path to the LightX2V inference config template the
# Wan22Pipeline will fill in with absolute ckpt paths. # Wan22Pipeline will fill in with absolute ckpt paths.
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B" wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models" wan22_dit_repo: str = "lightx2v/Wan2.2-Distill-Models"
wan22_dit_quant_scheme: str = "fp8-sgl"
wan22_t5_quantized: bool = False
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json" wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
wan22_model_cls: str = "wan2.2_moe_distill" wan22_model_cls: str = "wan2.2_moe_distill"
musetalk_model_path: str = "TMElyralab/MuseTalk" musetalk_model_path: str = "TMElyralab/MuseTalk"
@@ -121,8 +124,18 @@ class VideoConfig:
wan22_base_repo=str( wan22_base_repo=str(
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B") models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
), ),
wan22_fp8_repo=str( wan22_dit_repo=str(
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models") models_raw.get(
"wan22_dit_repo",
# Backwards compat: fall back to old key name.
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models"),
)
),
wan22_dit_quant_scheme=str(
models_raw.get("wan22_dit_quant_scheme", "fp8-sgl")
),
wan22_t5_quantized=bool(
models_raw.get("wan22_t5_quantized", False)
), ),
wan22_config_json=str( wan22_config_json=str(
models_raw.get( models_raw.get(
@@ -204,16 +217,19 @@ class VideoEngine:
from server.video_models.musetalk import MuseTalkEngine from server.video_models.musetalk import MuseTalkEngine
log.info( log.info(
"Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...", "Loading Wan2.2 pipeline (base=%s, dit=%s, quant=%s)...",
self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo, self.cfg.wan22_base_repo, self.cfg.wan22_dit_repo,
self.cfg.wan22_dit_quant_scheme,
) )
self._wan22 = Wan22Pipeline( self._wan22 = Wan22Pipeline(
base_repo=self.cfg.wan22_base_repo, base_repo=self.cfg.wan22_base_repo,
fp8_repo=self.cfg.wan22_fp8_repo, dit_repo=self.cfg.wan22_dit_repo,
config_json=self.cfg.wan22_config_json, config_json=self.cfg.wan22_config_json,
model_cls=self.cfg.wan22_model_cls, model_cls=self.cfg.wan22_model_cls,
resolution=self.cfg.resolution, resolution=self.cfg.resolution,
fps=self.cfg.fps, fps=self.cfg.fps,
dit_quant_scheme=self.cfg.wan22_dit_quant_scheme,
t5_quantized=self.cfg.wan22_t5_quantized,
) )
if self.cfg.loras: if self.cfg.loras:
self._wan22.load_loras(self.cfg.loras) self._wan22.load_loras(self.cfg.loras)
+200 -35
View File
@@ -1,4 +1,4 @@
"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V. """Wan2.2-Lightning image-to-video wrapper via LightX2V.
This wrapper targets LightX2V's actual Python entry points (verified against This wrapper targets LightX2V's actual Python entry points (verified against
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main): the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
@@ -25,10 +25,12 @@ Two HuggingFace repos are consumed on first run (cached under HF_HOME):
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only. - Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
The bf16 DIT shards under high_noise_model/ The bf16 DIT shards under high_noise_model/
and low_noise_model/ are SKIPPED via and low_noise_model/ are SKIPPED via
ignore_patterns — we replace them with fp8. ignore_patterns — we replace them with
- lightx2v/Wan2.2-Distill-Models — exactly two safetensors files: quantised checkpoints from dit_repo.
the fp8 e4m3 4-step distilled high/low - dit_repo (configurable) — quantised DIT checkpoints. Supported
noise DIT checkpoints (~15 GB each). formats:
* fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
* GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)
""" """
from __future__ import annotations from __future__ import annotations
@@ -47,13 +49,22 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# --- fp8 distill filenames --------------------------------------------------
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
# --- GGUF filenames (QuantStack layout: HighNoise/<name>.gguf) ---------------
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the # The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8 # T5/VAE/tokenizer support files (~12 GB). We only need the latter — the
# files from the distill repo replace the DIT weights entirely. We must keep # quantised files from dit_repo replace the DIT weights entirely. We must
# the config.json / index.json metadata under high_noise_model/ and # keep the config.json / index.json metadata under high_noise_model/ and
# low_noise_model/ (LightX2V's set_config reads architecture params like # low_noise_model/ (LightX2V's set_config reads architecture params like
# ``dim`` from them) and the tokenizer files under google/. # ``dim`` from them) and the tokenizer files under google/.
BASE_REPO_IGNORE_PATTERNS = [ BASE_REPO_IGNORE_PATTERNS = [
@@ -66,8 +77,68 @@ BASE_REPO_IGNORE_PATTERNS = [
] ]
def _patch_fp8_scaled_mm_for_blackwell() -> None:
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
sgl_kernel's CUTLASS-based fp8 GEMM doesn't ship SM120 kernels yet.
PyTorch 2.8+'s native ``_scaled_mm`` works on all architectures
including Blackwell. This patch is idempotent.
"""
try:
import sgl_kernel # type: ignore[import-not-found]
except ImportError:
return # no sgl_kernel → fp8 T5 not in use
if getattr(sgl_kernel, "_fp8_patched_for_blackwell", False):
return
import torch
if not torch.cuda.is_available():
return
cap = torch.cuda.get_device_capability()
if cap[0] < 12:
return # only patch on Blackwell+
_orig = sgl_kernel.fp8_scaled_mm
def _torch_fp8_scaled_mm(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# torch._scaled_mm expects (M,K) @ (N,K).t() with:
# scale_a: scalar or (M,1)
# scale_b: scalar or (1,N)
# sgl_kernel provides scale_b as (N,1) — transpose it.
if b_scale.dim() == 2 and b_scale.shape[1] == 1:
b_scale = b_scale.t()
# _scaled_mm requires B to be column-major (stride(0)==1).
bt = b.t().contiguous().t()
out = torch._scaled_mm(
a, bt,
scale_a=a_scale, scale_b=b_scale,
out_dtype=out_dtype, bias=bias,
)
return out
sgl_kernel.fp8_scaled_mm = _torch_fp8_scaled_mm
sgl_kernel._fp8_patched_for_blackwell = True
log.info("Patched sgl_kernel.fp8_scaled_mm → torch._scaled_mm for Blackwell (SM%d%d).", *cap)
class Wan22Pipeline: class Wan22Pipeline:
"""Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights. """Wrapper around LightX2V's Wan2.2 MoE distill runner.
Supports two DIT quantisation formats:
* **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
``lightx2v/Wan2.2-Distill-Models``)
* **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)
Constructor downloads (if needed) both HF repos, writes a runtime JSON Constructor downloads (if needed) both HF repos, writes a runtime JSON
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``. config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
@@ -77,23 +148,34 @@ class Wan22Pipeline:
def __init__( def __init__(
self, self,
base_repo: str, base_repo: str,
fp8_repo: str, dit_repo: str,
config_json: str, config_json: str,
model_cls: str = "wan2.2_moe_distill", model_cls: str = "wan2.2_moe_distill",
resolution: int = 480, resolution: int = 480,
fps: int = 16, fps: int = 16,
dit_quant_scheme: str = "fp8-sgl",
t5_quantized: bool = False,
): ):
self.base_repo = base_repo self.base_repo = base_repo
self.fp8_repo = fp8_repo self.dit_repo = dit_repo
self.config_json_template = config_json self.config_json_template = config_json
self.model_cls = model_cls self.model_cls = model_cls
self.resolution = resolution self.resolution = resolution
self.fps = fps self.fps = fps
self.dit_quant_scheme = dit_quant_scheme
self.t5_quantized = t5_quantized
self._applied_loras: list[LoRASpec] = [] self._applied_loras: list[LoRASpec] = []
# 1. Resolve / download base repo (T5/VAE/config) and fp8 DIT ckpts. self._is_gguf = dit_quant_scheme.startswith("gguf-")
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
self._model_root = self._ensure_base_repo(base_repo) self._model_root = self._ensure_base_repo(base_repo)
self._fp8_high, self._fp8_low = self._ensure_fp8_checkpoints(fp8_repo) self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
dit_repo, dit_quant_scheme,
)
self._t5_fp8_ckpt = (
self._ensure_t5_fp8() if t5_quantized else None
)
# 2. Materialize a runtime JSON config with absolute ckpt paths. # 2. Materialize a runtime JSON config with absolute ckpt paths.
self._runtime_json_path = self._build_runtime_config() self._runtime_json_path = self._build_runtime_config()
@@ -105,13 +187,17 @@ class Wan22Pipeline:
config_json=self._runtime_json_path, config_json=self._runtime_json_path,
) )
# 4. set_config → init_runner. Runner construction triggers weight load. # 4. Import LightX2V (scoped here so ``import server.video_models.wan22``
# Imports are scoped here so ``import server.video_models.wan22`` # never pulls in lightx2v — tests can import this module on CPU).
# never pulls in lightx2v (tests can import this module on CPU).
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found] from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found] from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
from lightx2v.infer import init_runner # type: ignore[import-not-found] from lightx2v.infer import init_runner # type: ignore[import-not-found]
_patch_fp8_scaled_mm_for_blackwell()
# 5. Load all models under default DTYPE=BF16 so T5 (which is
# hardcoded to bf16 weights) initialises its offload buffers
# correctly. We flip to FP16 *after* init_runner completes.
log.info("LightX2V set_config (model_cls=%s, model_path=%s)", log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
model_cls, self._model_root) model_cls, self._model_root)
self._config = set_config(args) self._config = set_config(args)
@@ -124,6 +210,52 @@ class Wan22Pipeline:
self._runner = init_runner(self._config) self._runner = init_runner(self._config)
log.info("LightX2V runner loaded; weights resident.") log.info("LightX2V runner loaded; weights resident.")
# 6. GGUF: switch global DTYPE to FP16 for inference. GGUF DIT
# dequantises to fp16, and many intermediate tensors inside the
# DIT forward pass are allocated via GET_DTYPE(). The T5 encoder
# is wrapped to temporarily restore BF16 during its forward.
if self._is_gguf:
os.environ["DTYPE"] = "FP16"
from lightx2v.utils.envs import GET_DTYPE # type: ignore[import-not-found]
GET_DTYPE.cache_clear()
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
self._patch_t5_dtype_for_gguf()
# --- GGUF dtype compatibility patch ----------------------------------------
def _patch_t5_dtype_for_gguf(self) -> None:
"""Wrap the T5 encoder so it temporarily restores DTYPE=BF16.
The T5 encoder is hardcoded to bfloat16 weights (wan_runner.py). When
the global DTYPE is FP16 (required for GGUF DIT), the T5's CPU-offload
path breaks because intermediate tensor dtypes no longer match the bf16
weights. We wrap ``run_text_encoder`` to temporarily flip GET_DTYPE()
back to bf16, then restore fp16 before the DIT runs.
"""
import os
import types
from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE # type: ignore[import-not-found]
runner = self._runner
orig_run_text_encoder = runner.run_text_encoder.__func__
def bf16_text_encoder(self_runner, *args, **kwargs):
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
os.environ["DTYPE"] = "BF16"
GET_DTYPE.cache_clear()
GET_SENSITIVE_DTYPE.cache_clear()
try:
result = orig_run_text_encoder(self_runner, *args, **kwargs)
finally:
# Restore FP16 for the DIT / rest of the pipeline.
os.environ["DTYPE"] = "FP16"
GET_DTYPE.cache_clear()
GET_SENSITIVE_DTYPE.cache_clear()
return result
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
# --- Weight provisioning ------------------------------------------------- # --- Weight provisioning -------------------------------------------------
@staticmethod @staticmethod
@@ -132,7 +264,7 @@ class Wan22Pipeline:
If ``base_repo`` is already a local directory, use it as-is. Otherwise If ``base_repo`` is already a local directory, use it as-is. Otherwise
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
shards (they're replaced by the fp8 files). shards (they're replaced by the quantised files).
""" """
if os.path.isdir(base_repo): if os.path.isdir(base_repo):
return base_repo return base_repo
@@ -145,42 +277,75 @@ class Wan22Pipeline:
) )
@staticmethod @staticmethod
def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]: def _ensure_dit_checkpoints(
"""Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair. dit_repo: str,
dit_quant_scheme: str,
) -> tuple[str, str]:
"""Return (high_noise_path, low_noise_path) for the DIT pair.
- If ``fp8_repo`` is a local directory, expect both files inside it. Supports both fp8 safetensors and GGUF formats.
- Otherwise treat it as a HF repo id and download only the two files
we need (not the ~150 GB of other variants in that repo).
""" """
if not fp8_repo: if not dit_repo:
raise ValueError("fp8_repo must be a HF repo id or local directory.") raise ValueError("dit_repo must be a HF repo id or local directory.")
if os.path.isdir(fp8_repo):
high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE) is_gguf = dit_quant_scheme.startswith("gguf-")
low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE)
if is_gguf:
# Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M"
quant = dit_quant_scheme.replace("gguf-", "")
high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
else:
high_file = FP8_HIGH_NOISE_FILE
low_file = FP8_LOW_NOISE_FILE
# Local directory?
if os.path.isdir(dit_repo):
high = os.path.join(dit_repo, high_file)
low = os.path.join(dit_repo, low_file)
if not (os.path.isfile(high) and os.path.isfile(low)): if not (os.path.isfile(high) and os.path.isfile(low)):
raise FileNotFoundError( raise FileNotFoundError(
f"fp8 checkpoints not found in {fp8_repo}: expected " f"DIT checkpoints not found in {dit_repo}: expected "
f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}" f"{high_file} and {low_file}"
) )
return high, low return high, low
# HuggingFace download.
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo) log.info("Downloading %s DIT checkpoints from %s ...",
high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE) dit_quant_scheme, dit_repo)
low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE) high = hf_hub_download(repo_id=dit_repo, filename=high_file)
low = hf_hub_download(repo_id=dit_repo, filename=low_file)
return high, low return high, low
@staticmethod
def _ensure_t5_fp8() -> str:
"""Download the fp8 T5 encoder from lightx2v/Encoders (if not cached).
Returns the local path to the safetensors file (~6 GB).
"""
from huggingface_hub import hf_hub_download
log.info("Downloading fp8 T5 encoder from %s ...", T5_FP8_REPO)
return hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
def _build_runtime_config(self) -> str: def _build_runtime_config(self) -> str:
"""Load the template JSON, inject absolute ckpt paths, persist to temp.""" """Load the template JSON, inject absolute ckpt paths, persist to temp."""
with open(self.config_json_template, "r", encoding="utf-8") as f: with open(self.config_json_template, "r", encoding="utf-8") as f:
cfg = json.load(f) cfg = json.load(f)
# Drop editorial comments before passing to LightX2V. # Drop editorial comments before passing to LightX2V.
cfg.pop("_comment", None) cfg.pop("_comment", None)
cfg["high_noise_quantized_ckpt"] = self._fp8_high cfg["high_noise_quantized_ckpt"] = self._dit_high
cfg["low_noise_quantized_ckpt"] = self._fp8_low cfg["low_noise_quantized_ckpt"] = self._dit_low
cfg.setdefault("fps", self.fps) cfg.setdefault("fps", self.fps)
# T5 fp8 quantization.
if self._t5_fp8_ckpt:
cfg["t5_quantized"] = True
cfg["t5_quant_scheme"] = "fp8-sgl"
cfg["t5_quantized_ckpt"] = self._t5_fp8_ckpt
tmp = tempfile.NamedTemporaryFile( tmp = tempfile.NamedTemporaryFile(
prefix="wan22_fp8_", suffix=".json", prefix="wan22_dit_", suffix=".json",
mode="w", delete=False, encoding="utf-8", mode="w", delete=False, encoding="utf-8",
) )
json.dump(cfg, tmp, indent=2) json.dump(cfg, tmp, indent=2)
Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

+27 -9
View File
@@ -1,15 +1,22 @@
"""Phase 2 component test: Wan2.2-Lightning fp8 pipeline + LoRA stacking. """Phase 2 component test: Wan2.2 pipeline + LoRA stacking.
Verifies: Verifies:
- ``Wan22Pipeline`` loads successfully against the fp8 distill path - ``Wan22Pipeline`` loads successfully (exercises the real LightX2V
(exercises the real LightX2V set_config init_runner flow). set_config -> init_runner flow).
- ``load_loras`` / ``unload_loras`` survive with the two user LoRAs at - ``load_loras`` / ``unload_loras`` survive with the two user LoRAs at
``/cache/loras/wan22-[HL]-e8.safetensors``. ``/cache/loras/wan22-[HL]-e8.safetensors``.
Requires GPU and a first-run download of both HF repos (base support files Supports both fp8 and GGUF DIT quantisation. Set the ``DIT_QUANT``
~12 GB, fp8 DIT ~30 GB). If LightX2V isn't installed the test is skipped. environment variable to switch (default: ``fp8-sgl``).
Run: DIT_QUANT=gguf-Q4_K_M docker compose exec voice-chat \
python -m tests.component.test_02_wan22_loras
Requires GPU and a first-run download of both HF repos (base support files
~12 GB, DIT size depends on quant — fp8 ~30 GB, GGUF Q4_K_M ~19 GB).
If LightX2V isn't installed the test is skipped.
Run (default fp8):
docker compose exec voice-chat python -m tests.component.test_02_wan22_loras docker compose exec voice-chat python -m tests.component.test_02_wan22_loras
""" """
from __future__ import annotations from __future__ import annotations
@@ -21,7 +28,17 @@ from tests.component._common import get_logger
log = get_logger("test_02") log = get_logger("test_02")
# --- Quant-dependent defaults ------------------------------------------------
DIT_QUANT = os.environ.get("DIT_QUANT", "fp8-sgl")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json" CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
LORA_HIGH = "/cache/loras/wan22-H-e8.safetensors" LORA_HIGH = "/cache/loras/wan22-H-e8.safetensors"
LORA_LOW = "/cache/loras/wan22-L-e8.safetensors" LORA_LOW = "/cache/loras/wan22-L-e8.safetensors"
@@ -37,15 +54,16 @@ def run():
from server.video import LoRASpec from server.video import LoRASpec
log.info("[case 1] Instantiate Wan22Pipeline " log.info("[case 1] Instantiate Wan22Pipeline "
"(first run downloads ~42 GB total)...") "(quant=%s, dit_repo=%s)...", DIT_QUANT, DIT_REPO)
try: try:
pipe = Wan22Pipeline( pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B", base_repo="Wan-AI/Wan2.2-I2V-A14B",
fp8_repo="lightx2v/Wan2.2-Distill-Models", dit_repo=DIT_REPO,
config_json=CONFIG_JSON, config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill", model_cls="wan2.2_moe_distill",
resolution=480, resolution=480,
fps=16, fps=16,
dit_quant_scheme=DIT_QUANT,
) )
except Exception as e: except Exception as e:
log.error("FAIL: Wan22Pipeline construction raised: %s", e) log.error("FAIL: Wan22Pipeline construction raised: %s", e)
@@ -56,7 +74,7 @@ def run():
log.info(" PASS: pipeline constructed") log.info(" PASS: pipeline constructed")
# --- LoRAs --- # --- LoRAs ---
log.info("[case 2] load_loras with empty list no-op") log.info("[case 2] load_loras with empty list -> no-op")
pipe.load_loras([]) pipe.load_loras([])
log.info(" PASS") log.info(" PASS")
+83
View File
@@ -0,0 +1,83 @@
"""Quick smoke test: generate a video clip with the GGUF pipeline.
Calls Wan22Pipeline.generate_i2v directly (no MuseTalk, no VideoEngine)
and writes the result to tests/component/_out/phase9_gguf.mp4.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_09_gguf_generate
"""
from __future__ import annotations
import os
import sys
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
log = get_logger("test_09")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
avatar = ensure_sample_avatar()
log.info("Avatar: %s", avatar)
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
# Debug: verify DTYPE is set correctly for GGUF
from lightx2v.utils.envs import GET_DTYPE
log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))
log.info("Generating 3-second i2v clip...")
frames = pipe.generate_i2v(
image_path=avatar,
prompt="a person looking at the camera, natural lighting, soft focus background",
seconds=3,
seed=42,
)
log.info("Got frames: shape=%s dtype=%s", frames.shape, frames.dtype)
# Encode to MP4
import imageio.v3 as iio
import tempfile
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
tmp = tf.name
try:
iio.imwrite(tmp, frames, fps=16, codec="libx264")
with open(tmp, "rb") as f:
mp4_bytes = f.read()
finally:
os.remove(tmp)
out = write_bytes("phase9_gguf.mp4", mp4_bytes)
log.info("PASS: video written to %s (%d bytes, %d frames)", out, len(mp4_bytes), frames.shape[0])
if __name__ == "__main__":
run()
+88
View File
@@ -0,0 +1,88 @@
"""Smoke test: T5 text encoding under GGUF pipeline.
Builds the Wan22Pipeline (loads all weights including DIT) but only
exercises the T5 encoder — no image-to-video generation. Validates that
the DTYPE=FP16 ↔ BF16 patching lets T5 encode a prompt successfully.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_10_t5_encode
"""
from __future__ import annotations
import os
import sys
from tests.component._common import get_logger
log = get_logger("test_10")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
# Check DTYPE state after init
from lightx2v.utils.envs import GET_DTYPE
log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))
# Run only the T5 text encoder
runner = pipe._runner
prompt = "a person looking at the camera, natural lighting"
log.info("Running T5 text encoder on prompt: %r", prompt)
import copy
input_info = copy.deepcopy(pipe._input_info_template)
input_info.prompt = prompt
runner.run_text_encoder(input_info)
log.info("T5 encode complete.")
# Inspect output — check all dataclass fields for tensor results
import torch
for attr in vars(input_info):
val = getattr(input_info, attr)
if isinstance(val, torch.Tensor):
log.info(" %s: shape=%s dtype=%s device=%s", attr, val.shape, val.dtype, val.device)
# Verify DTYPE is back to FP16 after T5 runs (if GGUF)
if DIT_QUANT.startswith("gguf-"):
current = GET_DTYPE()
import torch
expected = torch.float16
if current == expected:
log.info("PASS: DTYPE correctly restored to FP16 after T5 encode.")
else:
log.error("FAIL: DTYPE is %s after T5, expected %s", current, expected)
sys.exit(1)
log.info("PASS: T5 encoding succeeded under %s pipeline.", DIT_QUANT)
if __name__ == "__main__":
run()
+16 -2
View File
@@ -105,7 +105,8 @@ def test_models_section_override():
{ {
"models": { "models": {
"wan22_base_repo": "/local/weights/wan22", "wan22_base_repo": "/local/weights/wan22",
"wan22_fp8_repo": "/local/weights/wan22-fp8", "wan22_dit_repo": "/local/weights/wan22-dit",
"wan22_dit_quant_scheme": "gguf-Q4_K_M",
"wan22_config_json": "/local/cfg/fp8.json", "wan22_config_json": "/local/cfg/fp8.json",
"wan22_model_cls": "wan2.2_moe", "wan22_model_cls": "wan2.2_moe",
"musetalk_path": "/local/weights/musetalk", "musetalk_path": "/local/weights/musetalk",
@@ -113,7 +114,20 @@ def test_models_section_override():
} }
) )
assert cfg.wan22_base_repo == "/local/weights/wan22" assert cfg.wan22_base_repo == "/local/weights/wan22"
assert cfg.wan22_fp8_repo == "/local/weights/wan22-fp8" assert cfg.wan22_dit_repo == "/local/weights/wan22-dit"
assert cfg.wan22_dit_quant_scheme == "gguf-Q4_K_M"
assert cfg.wan22_config_json == "/local/cfg/fp8.json" assert cfg.wan22_config_json == "/local/cfg/fp8.json"
assert cfg.wan22_model_cls == "wan2.2_moe" assert cfg.wan22_model_cls == "wan2.2_moe"
assert cfg.musetalk_model_path == "/local/weights/musetalk" assert cfg.musetalk_model_path == "/local/weights/musetalk"
def test_models_section_backwards_compat_fp8_repo():
"""Old config key wan22_fp8_repo still works via fallback."""
cfg = VideoConfig.from_dict(
{
"models": {
"wan22_fp8_repo": "/local/weights/wan22-fp8",
}
}
)
assert cfg.wan22_dit_repo == "/local/weights/wan22-fp8"