t5 encoder fp8 seems to be working
This commit is contained in:
@@ -53,6 +53,14 @@ RUN python3.11 -m pip install --no-cache-dir \
|
|||||||
"git+https://github.com/ModelTC/LightX2V.git" || \
|
"git+https://github.com/ModelTC/LightX2V.git" || \
|
||||||
echo "LightX2V install failed — config.video.enabled must stay false until fixed"
|
echo "LightX2V install failed — config.video.enabled must stay false until fixed"
|
||||||
#
|
#
|
||||||
|
# sgl-kernel (fp8 T5 encoder acceleration). The PyPI wheel lacks SM120
|
||||||
|
# (Blackwell) CUTLASS kernels; use SGLang's cu128 wheel index instead.
|
||||||
|
# Our wan22.py patches fp8_scaled_mm → torch._scaled_mm at runtime for
|
||||||
|
# Blackwell GPUs, but the sgl_kernel package itself must still be present.
|
||||||
|
RUN python3.11 -m pip install --no-cache-dir --no-deps \
|
||||||
|
"sgl-kernel @ https://github.com/sgl-project/whl/releases/download/v0.3.14.post1/sgl_kernel-0.3.14.post1%2Bcu128-cp310-abi3-manylinux2014_x86_64.whl" || \
|
||||||
|
echo "sgl-kernel install failed — fp8 T5 will fall back to bf16"
|
||||||
|
#
|
||||||
# MuseTalk (audio-driven lip-sync) — same story.
|
# MuseTalk (audio-driven lip-sync) — same story.
|
||||||
RUN python3.11 -m pip install --no-cache-dir \
|
RUN python3.11 -m pip install --no-cache-dir \
|
||||||
"git+https://github.com/TMElyralab/MuseTalk.git" || \
|
"git+https://github.com/TMElyralab/MuseTalk.git" || \
|
||||||
|
|||||||
+14
-7
@@ -32,16 +32,23 @@ video:
|
|||||||
casual gestures, natural lighting, soft focus background
|
casual gestures, natural lighting, soft focus background
|
||||||
prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint}
|
prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint}
|
||||||
|
|
||||||
# Model sources for the video stack. The fp8 e4m3 4-step distilled DIT
|
# Model sources for the video stack. T5/VAE/tokenizer come from the
|
||||||
# weights from lightx2v/Wan2.2-Distill-Models are ~15 GB each (vs ~28 GB
|
# Wan-AI base repo. DIT weights come from wan22_dit_repo in the format
|
||||||
# bf16) — that's the "save VRAM" path. T5/VAE/tokenizer still come from
|
# specified by wan22_dit_quant_scheme. Both repos download on first run
|
||||||
# the Wan-AI base repo. Both repos download on first run into
|
# into HF_HOME=/cache/huggingface.
|
||||||
# HF_HOME=/cache/huggingface.
|
#
|
||||||
|
# Supported dit_quant_scheme values:
|
||||||
|
# fp8-sgl — fp8 e4m3 safetensors (~15 GB/expert, from lightx2v/Wan2.2-Distill-Models)
|
||||||
|
# gguf-Q4_K_M — GGUF 4-bit (~9.65 GB/expert, from QuantStack/Wan2.2-I2V-A14B-GGUF)
|
||||||
|
# gguf-Q8_0 — GGUF 8-bit (~15.4 GB/expert)
|
||||||
|
# (any gguf-<level> supported by LightX2V — see base_model.py MM_WEIGHT_REGISTER)
|
||||||
models:
|
models:
|
||||||
wan22_base_repo: Wan-AI/Wan2.2-I2V-A14B
|
wan22_base_repo: Wan-AI/Wan2.2-I2V-A14B
|
||||||
wan22_fp8_repo: lightx2v/Wan2.2-Distill-Models
|
wan22_dit_repo: QuantStack/Wan2.2-I2V-A14B-GGUF
|
||||||
|
wan22_dit_quant_scheme: gguf-Q4_K_M
|
||||||
|
wan22_t5_quantized: true
|
||||||
wan22_model_cls: wan2.2_moe_distill
|
wan22_model_cls: wan2.2_moe_distill
|
||||||
wan22_config_json: /app/configs/lightx2v/wan22_i2v_fp8_distill.json
|
wan22_config_json: /app/configs/lightx2v/wan22_i2v_gguf_distill.json
|
||||||
musetalk_path: TMElyralab/MuseTalk
|
musetalk_path: TMElyralab/MuseTalk
|
||||||
|
|
||||||
# LoRAs applied to the fp8 base at load time via runtime switch_lora.
|
# LoRAs applied to the fp8 base at load time via runtime switch_lora.
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
{
|
||||||
|
"_comment": "Wan2.2 i2v MoE 4-step distill, GGUF quantized. Uses QuantStack/Wan2.2-I2V-A14B-GGUF checkpoints instead of fp8 safetensors. GGUF does not support block-level offload so offload_granularity is set to 'model' — the entire DIT is moved to GPU when active. With Q4_K_M (~9.65 GB per expert) this fits comfortably in 24+ GB VRAM. high_noise_quantized_ckpt / low_noise_quantized_ckpt are filled in at runtime by server/video_models/wan22.py. IMPORTANT: GGUF dequantizes to fp16, so you must set DTYPE=FP16 in the container environment.",
|
||||||
|
|
||||||
|
"infer_steps": 4,
|
||||||
|
"target_video_length": 81,
|
||||||
|
"text_len": 512,
|
||||||
|
|
||||||
|
"resize_mode": "adaptive",
|
||||||
|
"resolution": "480p",
|
||||||
|
"target_height": 480,
|
||||||
|
"target_width": 480,
|
||||||
|
"fps": 16,
|
||||||
|
|
||||||
|
"self_attn_1_type": "flash_attn3",
|
||||||
|
"cross_attn_1_type": "flash_attn3",
|
||||||
|
"cross_attn_2_type": "flash_attn3",
|
||||||
|
|
||||||
|
"sample_guide_scale": [3.5, 3.5],
|
||||||
|
"sample_shift": 5.0,
|
||||||
|
"enable_cfg": false,
|
||||||
|
|
||||||
|
"cpu_offload": true,
|
||||||
|
"offload_granularity": "model",
|
||||||
|
"t5_cpu_offload": true,
|
||||||
|
"vae_cpu_offload": false,
|
||||||
|
|
||||||
|
"use_image_encoder": false,
|
||||||
|
|
||||||
|
"boundary_step_index": 2,
|
||||||
|
"denoising_step_list": [1000, 750, 500, 250],
|
||||||
|
|
||||||
|
"dit_quantized": true,
|
||||||
|
"dit_quant_scheme": "gguf-Q4_K_M",
|
||||||
|
"t5_quantized": false
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ services:
|
|||||||
- ./configs:/app/configs:ro
|
- ./configs:/app/configs:ro
|
||||||
- ./server:/app/server:ro
|
- ./server:/app/server:ro
|
||||||
- ./static:/app/static:ro
|
- ./static:/app/static:ro
|
||||||
|
- ./tests:/app/tests
|
||||||
- ./run.py:/app/run.py:ro
|
- ./run.py:/app/run.py:ro
|
||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ pyyaml
|
|||||||
imageio[ffmpeg]>=2.34
|
imageio[ffmpeg]>=2.34
|
||||||
av>=12.0
|
av>=12.0
|
||||||
pyzmq>=25.0
|
pyzmq>=25.0
|
||||||
|
gguf>=0.6.0
|
||||||
|
# sgl-kernel: installed from SGLang's cu128 wheel index in Dockerfile
|
||||||
|
# (PyPI version lacks SM120/Blackwell CUDA kernels)
|
||||||
# LightX2V (Wan2.2-Lightning) and MuseTalk are installed from source in the
|
# LightX2V (Wan2.2-Lightning) and MuseTalk are installed from source in the
|
||||||
# Dockerfile because neither ships a stable PyPI release yet. See lines
|
# Dockerfile because neither ships a stable PyPI release yet. See lines
|
||||||
# "LightX2V from source" / "MuseTalk from source" in Dockerfile.
|
# "LightX2V from source" / "MuseTalk from source" in Dockerfile.
|
||||||
|
|||||||
+25
-9
@@ -60,13 +60,16 @@ class VideoConfig:
|
|||||||
# Model paths — can be overridden via config.yml.video.models.
|
# Model paths — can be overridden via config.yml.video.models.
|
||||||
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
|
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
|
||||||
# The bf16 DIT shards in this repo are skipped — we
|
# The bf16 DIT shards in this repo are skipped — we
|
||||||
# replace them with the fp8 files from wan22_fp8_repo.
|
# replace them with quantised files from wan22_dit_repo.
|
||||||
# wan22_fp8_repo : HF repo id (or local dir) providing the two fp8 e4m3
|
# wan22_dit_repo : HF repo id (or local dir) providing the quantised
|
||||||
# 4-step distilled DIT checkpoints (~15 GB each).
|
# DIT checkpoints (fp8 or GGUF).
|
||||||
|
# wan22_dit_quant_scheme : quantisation format, e.g. "fp8-sgl" or "gguf-Q4_K_M".
|
||||||
# wan22_config_json : path to the LightX2V inference config template the
|
# wan22_config_json : path to the LightX2V inference config template the
|
||||||
# Wan22Pipeline will fill in with absolute ckpt paths.
|
# Wan22Pipeline will fill in with absolute ckpt paths.
|
||||||
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
|
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
|
||||||
wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models"
|
wan22_dit_repo: str = "lightx2v/Wan2.2-Distill-Models"
|
||||||
|
wan22_dit_quant_scheme: str = "fp8-sgl"
|
||||||
|
wan22_t5_quantized: bool = False
|
||||||
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
||||||
wan22_model_cls: str = "wan2.2_moe_distill"
|
wan22_model_cls: str = "wan2.2_moe_distill"
|
||||||
musetalk_model_path: str = "TMElyralab/MuseTalk"
|
musetalk_model_path: str = "TMElyralab/MuseTalk"
|
||||||
@@ -121,8 +124,18 @@ class VideoConfig:
|
|||||||
wan22_base_repo=str(
|
wan22_base_repo=str(
|
||||||
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
|
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
|
||||||
),
|
),
|
||||||
wan22_fp8_repo=str(
|
wan22_dit_repo=str(
|
||||||
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models")
|
models_raw.get(
|
||||||
|
"wan22_dit_repo",
|
||||||
|
# Backwards compat: fall back to old key name.
|
||||||
|
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models"),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
wan22_dit_quant_scheme=str(
|
||||||
|
models_raw.get("wan22_dit_quant_scheme", "fp8-sgl")
|
||||||
|
),
|
||||||
|
wan22_t5_quantized=bool(
|
||||||
|
models_raw.get("wan22_t5_quantized", False)
|
||||||
),
|
),
|
||||||
wan22_config_json=str(
|
wan22_config_json=str(
|
||||||
models_raw.get(
|
models_raw.get(
|
||||||
@@ -204,16 +217,19 @@ class VideoEngine:
|
|||||||
from server.video_models.musetalk import MuseTalkEngine
|
from server.video_models.musetalk import MuseTalkEngine
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...",
|
"Loading Wan2.2 pipeline (base=%s, dit=%s, quant=%s)...",
|
||||||
self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo,
|
self.cfg.wan22_base_repo, self.cfg.wan22_dit_repo,
|
||||||
|
self.cfg.wan22_dit_quant_scheme,
|
||||||
)
|
)
|
||||||
self._wan22 = Wan22Pipeline(
|
self._wan22 = Wan22Pipeline(
|
||||||
base_repo=self.cfg.wan22_base_repo,
|
base_repo=self.cfg.wan22_base_repo,
|
||||||
fp8_repo=self.cfg.wan22_fp8_repo,
|
dit_repo=self.cfg.wan22_dit_repo,
|
||||||
config_json=self.cfg.wan22_config_json,
|
config_json=self.cfg.wan22_config_json,
|
||||||
model_cls=self.cfg.wan22_model_cls,
|
model_cls=self.cfg.wan22_model_cls,
|
||||||
resolution=self.cfg.resolution,
|
resolution=self.cfg.resolution,
|
||||||
fps=self.cfg.fps,
|
fps=self.cfg.fps,
|
||||||
|
dit_quant_scheme=self.cfg.wan22_dit_quant_scheme,
|
||||||
|
t5_quantized=self.cfg.wan22_t5_quantized,
|
||||||
)
|
)
|
||||||
if self.cfg.loras:
|
if self.cfg.loras:
|
||||||
self._wan22.load_loras(self.cfg.loras)
|
self._wan22.load_loras(self.cfg.loras)
|
||||||
|
|||||||
+200
-35
@@ -1,4 +1,4 @@
|
|||||||
"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V.
|
"""Wan2.2-Lightning image-to-video wrapper via LightX2V.
|
||||||
|
|
||||||
This wrapper targets LightX2V's actual Python entry points (verified against
|
This wrapper targets LightX2V's actual Python entry points (verified against
|
||||||
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
|
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
|
||||||
@@ -25,10 +25,12 @@ Two HuggingFace repos are consumed on first run (cached under HF_HOME):
|
|||||||
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
|
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
|
||||||
The bf16 DIT shards under high_noise_model/
|
The bf16 DIT shards under high_noise_model/
|
||||||
and low_noise_model/ are SKIPPED via
|
and low_noise_model/ are SKIPPED via
|
||||||
ignore_patterns — we replace them with fp8.
|
ignore_patterns — we replace them with
|
||||||
- lightx2v/Wan2.2-Distill-Models — exactly two safetensors files:
|
quantised checkpoints from dit_repo.
|
||||||
the fp8 e4m3 4-step distilled high/low
|
- dit_repo (configurable) — quantised DIT checkpoints. Supported
|
||||||
noise DIT checkpoints (~15 GB each).
|
formats:
|
||||||
|
* fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
|
||||||
|
* GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -47,13 +49,22 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# --- fp8 distill filenames --------------------------------------------------
|
||||||
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||||
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
|
||||||
|
|
||||||
|
# --- GGUF filenames (QuantStack layout: HighNoise/<name>.gguf) ---------------
|
||||||
|
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
|
||||||
|
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"
|
||||||
|
|
||||||
|
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
|
||||||
|
T5_FP8_REPO = "lightx2v/Encoders"
|
||||||
|
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
|
||||||
|
|
||||||
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
|
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
|
||||||
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8
|
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the
|
||||||
# files from the distill repo replace the DIT weights entirely. We must keep
|
# quantised files from dit_repo replace the DIT weights entirely. We must
|
||||||
# the config.json / index.json metadata under high_noise_model/ and
|
# keep the config.json / index.json metadata under high_noise_model/ and
|
||||||
# low_noise_model/ (LightX2V's set_config reads architecture params like
|
# low_noise_model/ (LightX2V's set_config reads architecture params like
|
||||||
# ``dim`` from them) and the tokenizer files under google/.
|
# ``dim`` from them) and the tokenizer files under google/.
|
||||||
BASE_REPO_IGNORE_PATTERNS = [
|
BASE_REPO_IGNORE_PATTERNS = [
|
||||||
@@ -66,8 +77,68 @@ BASE_REPO_IGNORE_PATTERNS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_fp8_scaled_mm_for_blackwell() -> None:
|
||||||
|
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
|
||||||
|
|
||||||
|
sgl_kernel's CUTLASS-based fp8 GEMM doesn't ship SM120 kernels yet.
|
||||||
|
PyTorch 2.8+'s native ``_scaled_mm`` works on all architectures
|
||||||
|
including Blackwell. This patch is idempotent.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import sgl_kernel # type: ignore[import-not-found]
|
||||||
|
except ImportError:
|
||||||
|
return # no sgl_kernel → fp8 T5 not in use
|
||||||
|
|
||||||
|
if getattr(sgl_kernel, "_fp8_patched_for_blackwell", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return
|
||||||
|
|
||||||
|
cap = torch.cuda.get_device_capability()
|
||||||
|
if cap[0] < 12:
|
||||||
|
return # only patch on Blackwell+
|
||||||
|
|
||||||
|
_orig = sgl_kernel.fp8_scaled_mm
|
||||||
|
|
||||||
|
def _torch_fp8_scaled_mm(
|
||||||
|
a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
a_scale: torch.Tensor,
|
||||||
|
b_scale: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# torch._scaled_mm expects (M,K) @ (N,K).t() with:
|
||||||
|
# scale_a: scalar or (M,1)
|
||||||
|
# scale_b: scalar or (1,N)
|
||||||
|
# sgl_kernel provides scale_b as (N,1) — transpose it.
|
||||||
|
if b_scale.dim() == 2 and b_scale.shape[1] == 1:
|
||||||
|
b_scale = b_scale.t()
|
||||||
|
# _scaled_mm requires B to be column-major (stride(0)==1).
|
||||||
|
bt = b.t().contiguous().t()
|
||||||
|
out = torch._scaled_mm(
|
||||||
|
a, bt,
|
||||||
|
scale_a=a_scale, scale_b=b_scale,
|
||||||
|
out_dtype=out_dtype, bias=bias,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
sgl_kernel.fp8_scaled_mm = _torch_fp8_scaled_mm
|
||||||
|
sgl_kernel._fp8_patched_for_blackwell = True
|
||||||
|
log.info("Patched sgl_kernel.fp8_scaled_mm → torch._scaled_mm for Blackwell (SM%d%d).", *cap)
|
||||||
|
|
||||||
|
|
||||||
class Wan22Pipeline:
|
class Wan22Pipeline:
|
||||||
"""Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights.
|
"""Wrapper around LightX2V's Wan2.2 MoE distill runner.
|
||||||
|
|
||||||
|
Supports two DIT quantisation formats:
|
||||||
|
* **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
|
||||||
|
``lightx2v/Wan2.2-Distill-Models``)
|
||||||
|
* **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
|
||||||
|
from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)
|
||||||
|
|
||||||
Constructor downloads (if needed) both HF repos, writes a runtime JSON
|
Constructor downloads (if needed) both HF repos, writes a runtime JSON
|
||||||
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
|
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
|
||||||
@@ -77,23 +148,34 @@ class Wan22Pipeline:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_repo: str,
|
base_repo: str,
|
||||||
fp8_repo: str,
|
dit_repo: str,
|
||||||
config_json: str,
|
config_json: str,
|
||||||
model_cls: str = "wan2.2_moe_distill",
|
model_cls: str = "wan2.2_moe_distill",
|
||||||
resolution: int = 480,
|
resolution: int = 480,
|
||||||
fps: int = 16,
|
fps: int = 16,
|
||||||
|
dit_quant_scheme: str = "fp8-sgl",
|
||||||
|
t5_quantized: bool = False,
|
||||||
):
|
):
|
||||||
self.base_repo = base_repo
|
self.base_repo = base_repo
|
||||||
self.fp8_repo = fp8_repo
|
self.dit_repo = dit_repo
|
||||||
self.config_json_template = config_json
|
self.config_json_template = config_json
|
||||||
self.model_cls = model_cls
|
self.model_cls = model_cls
|
||||||
self.resolution = resolution
|
self.resolution = resolution
|
||||||
self.fps = fps
|
self.fps = fps
|
||||||
|
self.dit_quant_scheme = dit_quant_scheme
|
||||||
|
self.t5_quantized = t5_quantized
|
||||||
self._applied_loras: list[LoRASpec] = []
|
self._applied_loras: list[LoRASpec] = []
|
||||||
|
|
||||||
# 1. Resolve / download base repo (T5/VAE/config) and fp8 DIT ckpts.
|
self._is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||||
|
|
||||||
|
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
|
||||||
self._model_root = self._ensure_base_repo(base_repo)
|
self._model_root = self._ensure_base_repo(base_repo)
|
||||||
self._fp8_high, self._fp8_low = self._ensure_fp8_checkpoints(fp8_repo)
|
self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
|
||||||
|
dit_repo, dit_quant_scheme,
|
||||||
|
)
|
||||||
|
self._t5_fp8_ckpt = (
|
||||||
|
self._ensure_t5_fp8() if t5_quantized else None
|
||||||
|
)
|
||||||
|
|
||||||
# 2. Materialize a runtime JSON config with absolute ckpt paths.
|
# 2. Materialize a runtime JSON config with absolute ckpt paths.
|
||||||
self._runtime_json_path = self._build_runtime_config()
|
self._runtime_json_path = self._build_runtime_config()
|
||||||
@@ -105,13 +187,17 @@ class Wan22Pipeline:
|
|||||||
config_json=self._runtime_json_path,
|
config_json=self._runtime_json_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. set_config → init_runner. Runner construction triggers weight load.
|
# 4. Import LightX2V (scoped here so ``import server.video_models.wan22``
|
||||||
# Imports are scoped here so ``import server.video_models.wan22``
|
# never pulls in lightx2v — tests can import this module on CPU).
|
||||||
# never pulls in lightx2v (tests can import this module on CPU).
|
|
||||||
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
|
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
|
||||||
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
|
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
|
||||||
from lightx2v.infer import init_runner # type: ignore[import-not-found]
|
from lightx2v.infer import init_runner # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
_patch_fp8_scaled_mm_for_blackwell()
|
||||||
|
|
||||||
|
# 5. Load all models under default DTYPE=BF16 so T5 (which is
|
||||||
|
# hardcoded to bf16 weights) initialises its offload buffers
|
||||||
|
# correctly. We flip to FP16 *after* init_runner completes.
|
||||||
log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
|
log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
|
||||||
model_cls, self._model_root)
|
model_cls, self._model_root)
|
||||||
self._config = set_config(args)
|
self._config = set_config(args)
|
||||||
@@ -124,6 +210,52 @@ class Wan22Pipeline:
|
|||||||
self._runner = init_runner(self._config)
|
self._runner = init_runner(self._config)
|
||||||
log.info("LightX2V runner loaded; weights resident.")
|
log.info("LightX2V runner loaded; weights resident.")
|
||||||
|
|
||||||
|
# 6. GGUF: switch global DTYPE to FP16 for inference. GGUF DIT
|
||||||
|
# dequantises to fp16, and many intermediate tensors inside the
|
||||||
|
# DIT forward pass are allocated via GET_DTYPE(). The T5 encoder
|
||||||
|
# is wrapped to temporarily restore BF16 during its forward.
|
||||||
|
if self._is_gguf:
|
||||||
|
os.environ["DTYPE"] = "FP16"
|
||||||
|
from lightx2v.utils.envs import GET_DTYPE # type: ignore[import-not-found]
|
||||||
|
GET_DTYPE.cache_clear()
|
||||||
|
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
|
||||||
|
self._patch_t5_dtype_for_gguf()
|
||||||
|
|
||||||
|
# --- GGUF dtype compatibility patch ----------------------------------------
|
||||||
|
|
||||||
|
def _patch_t5_dtype_for_gguf(self) -> None:
|
||||||
|
"""Wrap the T5 encoder so it temporarily restores DTYPE=BF16.
|
||||||
|
|
||||||
|
The T5 encoder is hardcoded to bfloat16 weights (wan_runner.py). When
|
||||||
|
the global DTYPE is FP16 (required for GGUF DIT), the T5's CPU-offload
|
||||||
|
path breaks because intermediate tensor dtypes no longer match the bf16
|
||||||
|
weights. We wrap ``run_text_encoder`` to temporarily flip GET_DTYPE()
|
||||||
|
back to bf16, then restore fp16 before the DIT runs.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import types
|
||||||
|
from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
runner = self._runner
|
||||||
|
orig_run_text_encoder = runner.run_text_encoder.__func__
|
||||||
|
|
||||||
|
def bf16_text_encoder(self_runner, *args, **kwargs):
|
||||||
|
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
|
||||||
|
os.environ["DTYPE"] = "BF16"
|
||||||
|
GET_DTYPE.cache_clear()
|
||||||
|
GET_SENSITIVE_DTYPE.cache_clear()
|
||||||
|
try:
|
||||||
|
result = orig_run_text_encoder(self_runner, *args, **kwargs)
|
||||||
|
finally:
|
||||||
|
# Restore FP16 for the DIT / rest of the pipeline.
|
||||||
|
os.environ["DTYPE"] = "FP16"
|
||||||
|
GET_DTYPE.cache_clear()
|
||||||
|
GET_SENSITIVE_DTYPE.cache_clear()
|
||||||
|
return result
|
||||||
|
|
||||||
|
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
|
||||||
|
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
|
||||||
|
|
||||||
# --- Weight provisioning -------------------------------------------------
|
# --- Weight provisioning -------------------------------------------------
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -132,7 +264,7 @@ class Wan22Pipeline:
|
|||||||
|
|
||||||
If ``base_repo`` is already a local directory, use it as-is. Otherwise
|
If ``base_repo`` is already a local directory, use it as-is. Otherwise
|
||||||
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
|
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
|
||||||
shards (they're replaced by the fp8 files).
|
shards (they're replaced by the quantised files).
|
||||||
"""
|
"""
|
||||||
if os.path.isdir(base_repo):
|
if os.path.isdir(base_repo):
|
||||||
return base_repo
|
return base_repo
|
||||||
@@ -145,42 +277,75 @@ class Wan22Pipeline:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]:
|
def _ensure_dit_checkpoints(
|
||||||
"""Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair.
|
dit_repo: str,
|
||||||
|
dit_quant_scheme: str,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""Return (high_noise_path, low_noise_path) for the DIT pair.
|
||||||
|
|
||||||
- If ``fp8_repo`` is a local directory, expect both files inside it.
|
Supports both fp8 safetensors and GGUF formats.
|
||||||
- Otherwise treat it as a HF repo id and download only the two files
|
|
||||||
we need (not the ~150 GB of other variants in that repo).
|
|
||||||
"""
|
"""
|
||||||
if not fp8_repo:
|
if not dit_repo:
|
||||||
raise ValueError("fp8_repo must be a HF repo id or local directory.")
|
raise ValueError("dit_repo must be a HF repo id or local directory.")
|
||||||
if os.path.isdir(fp8_repo):
|
|
||||||
high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE)
|
is_gguf = dit_quant_scheme.startswith("gguf-")
|
||||||
low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE)
|
|
||||||
|
if is_gguf:
|
||||||
|
# Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M"
|
||||||
|
quant = dit_quant_scheme.replace("gguf-", "")
|
||||||
|
high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
|
||||||
|
low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
|
||||||
|
else:
|
||||||
|
high_file = FP8_HIGH_NOISE_FILE
|
||||||
|
low_file = FP8_LOW_NOISE_FILE
|
||||||
|
|
||||||
|
# Local directory?
|
||||||
|
if os.path.isdir(dit_repo):
|
||||||
|
high = os.path.join(dit_repo, high_file)
|
||||||
|
low = os.path.join(dit_repo, low_file)
|
||||||
if not (os.path.isfile(high) and os.path.isfile(low)):
|
if not (os.path.isfile(high) and os.path.isfile(low)):
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"fp8 checkpoints not found in {fp8_repo}: expected "
|
f"DIT checkpoints not found in {dit_repo}: expected "
|
||||||
f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}"
|
f"{high_file} and {low_file}"
|
||||||
)
|
)
|
||||||
return high, low
|
return high, low
|
||||||
|
|
||||||
|
# HuggingFace download.
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo)
|
log.info("Downloading %s DIT checkpoints from %s ...",
|
||||||
high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE)
|
dit_quant_scheme, dit_repo)
|
||||||
low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE)
|
high = hf_hub_download(repo_id=dit_repo, filename=high_file)
|
||||||
|
low = hf_hub_download(repo_id=dit_repo, filename=low_file)
|
||||||
return high, low
|
return high, low
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_t5_fp8() -> str:
|
||||||
|
"""Download the fp8 T5 encoder from lightx2v/Encoders (if not cached).
|
||||||
|
|
||||||
|
Returns the local path to the safetensors file (~6 GB).
|
||||||
|
"""
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
log.info("Downloading fp8 T5 encoder from %s ...", T5_FP8_REPO)
|
||||||
|
return hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
|
||||||
|
|
||||||
def _build_runtime_config(self) -> str:
|
def _build_runtime_config(self) -> str:
|
||||||
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
|
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
|
||||||
with open(self.config_json_template, "r", encoding="utf-8") as f:
|
with open(self.config_json_template, "r", encoding="utf-8") as f:
|
||||||
cfg = json.load(f)
|
cfg = json.load(f)
|
||||||
# Drop editorial comments before passing to LightX2V.
|
# Drop editorial comments before passing to LightX2V.
|
||||||
cfg.pop("_comment", None)
|
cfg.pop("_comment", None)
|
||||||
cfg["high_noise_quantized_ckpt"] = self._fp8_high
|
cfg["high_noise_quantized_ckpt"] = self._dit_high
|
||||||
cfg["low_noise_quantized_ckpt"] = self._fp8_low
|
cfg["low_noise_quantized_ckpt"] = self._dit_low
|
||||||
cfg.setdefault("fps", self.fps)
|
cfg.setdefault("fps", self.fps)
|
||||||
|
|
||||||
|
# T5 fp8 quantization.
|
||||||
|
if self._t5_fp8_ckpt:
|
||||||
|
cfg["t5_quantized"] = True
|
||||||
|
cfg["t5_quant_scheme"] = "fp8-sgl"
|
||||||
|
cfg["t5_quantized_ckpt"] = self._t5_fp8_ckpt
|
||||||
|
|
||||||
tmp = tempfile.NamedTemporaryFile(
|
tmp = tempfile.NamedTemporaryFile(
|
||||||
prefix="wan22_fp8_", suffix=".json",
|
prefix="wan22_dit_", suffix=".json",
|
||||||
mode="w", delete=False, encoding="utf-8",
|
mode="w", delete=False, encoding="utf-8",
|
||||||
)
|
)
|
||||||
json.dump(cfg, tmp, indent=2)
|
json.dump(cfg, tmp, indent=2)
|
||||||
|
|||||||
Binary file not shown.
|
After Width: | Height: | Size: 62 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
@@ -1,15 +1,22 @@
|
|||||||
"""Phase 2 component test: Wan2.2-Lightning fp8 pipeline + LoRA stacking.
|
"""Phase 2 component test: Wan2.2 pipeline + LoRA stacking.
|
||||||
|
|
||||||
Verifies:
|
Verifies:
|
||||||
- ``Wan22Pipeline`` loads successfully against the fp8 distill path
|
- ``Wan22Pipeline`` loads successfully (exercises the real LightX2V
|
||||||
(exercises the real LightX2V set_config → init_runner flow).
|
set_config -> init_runner flow).
|
||||||
- ``load_loras`` / ``unload_loras`` survive with the two user LoRAs at
|
- ``load_loras`` / ``unload_loras`` survive with the two user LoRAs at
|
||||||
``/cache/loras/wan22-[HL]-e8.safetensors``.
|
``/cache/loras/wan22-[HL]-e8.safetensors``.
|
||||||
|
|
||||||
Requires GPU and a first-run download of both HF repos (base support files
|
Supports both fp8 and GGUF DIT quantisation. Set the ``DIT_QUANT``
|
||||||
~12 GB, fp8 DIT ~30 GB). If LightX2V isn't installed the test is skipped.
|
environment variable to switch (default: ``fp8-sgl``).
|
||||||
|
|
||||||
Run:
|
DIT_QUANT=gguf-Q4_K_M docker compose exec voice-chat \
|
||||||
|
python -m tests.component.test_02_wan22_loras
|
||||||
|
|
||||||
|
Requires GPU and a first-run download of both HF repos (base support files
|
||||||
|
~12 GB, DIT size depends on quant — fp8 ~30 GB, GGUF Q4_K_M ~19 GB).
|
||||||
|
If LightX2V isn't installed the test is skipped.
|
||||||
|
|
||||||
|
Run (default fp8):
|
||||||
docker compose exec voice-chat python -m tests.component.test_02_wan22_loras
|
docker compose exec voice-chat python -m tests.component.test_02_wan22_loras
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -21,7 +28,17 @@ from tests.component._common import get_logger
|
|||||||
|
|
||||||
log = get_logger("test_02")
|
log = get_logger("test_02")
|
||||||
|
|
||||||
|
# --- Quant-dependent defaults ------------------------------------------------
|
||||||
|
|
||||||
|
DIT_QUANT = os.environ.get("DIT_QUANT", "fp8-sgl")
|
||||||
|
|
||||||
|
if DIT_QUANT.startswith("gguf-"):
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
|
||||||
|
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
|
||||||
|
else:
|
||||||
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
||||||
|
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
|
||||||
|
|
||||||
LORA_HIGH = "/cache/loras/wan22-H-e8.safetensors"
|
LORA_HIGH = "/cache/loras/wan22-H-e8.safetensors"
|
||||||
LORA_LOW = "/cache/loras/wan22-L-e8.safetensors"
|
LORA_LOW = "/cache/loras/wan22-L-e8.safetensors"
|
||||||
|
|
||||||
@@ -37,15 +54,16 @@ def run():
|
|||||||
from server.video import LoRASpec
|
from server.video import LoRASpec
|
||||||
|
|
||||||
log.info("[case 1] Instantiate Wan22Pipeline "
|
log.info("[case 1] Instantiate Wan22Pipeline "
|
||||||
"(first run downloads ~42 GB total)...")
|
"(quant=%s, dit_repo=%s)...", DIT_QUANT, DIT_REPO)
|
||||||
try:
|
try:
|
||||||
pipe = Wan22Pipeline(
|
pipe = Wan22Pipeline(
|
||||||
base_repo="Wan-AI/Wan2.2-I2V-A14B",
|
base_repo="Wan-AI/Wan2.2-I2V-A14B",
|
||||||
fp8_repo="lightx2v/Wan2.2-Distill-Models",
|
dit_repo=DIT_REPO,
|
||||||
config_json=CONFIG_JSON,
|
config_json=CONFIG_JSON,
|
||||||
model_cls="wan2.2_moe_distill",
|
model_cls="wan2.2_moe_distill",
|
||||||
resolution=480,
|
resolution=480,
|
||||||
fps=16,
|
fps=16,
|
||||||
|
dit_quant_scheme=DIT_QUANT,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("FAIL: Wan22Pipeline construction raised: %s", e)
|
log.error("FAIL: Wan22Pipeline construction raised: %s", e)
|
||||||
@@ -56,7 +74,7 @@ def run():
|
|||||||
log.info(" PASS: pipeline constructed")
|
log.info(" PASS: pipeline constructed")
|
||||||
|
|
||||||
# --- LoRAs ---
|
# --- LoRAs ---
|
||||||
log.info("[case 2] load_loras with empty list → no-op")
|
log.info("[case 2] load_loras with empty list -> no-op")
|
||||||
pipe.load_loras([])
|
pipe.load_loras([])
|
||||||
log.info(" PASS")
|
log.info(" PASS")
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,83 @@
|
|||||||
|
"""Quick smoke test: generate a video clip with the GGUF pipeline.
|
||||||
|
|
||||||
|
Calls Wan22Pipeline.generate_i2v directly (no MuseTalk, no VideoEngine)
|
||||||
|
and writes the result to tests/component/_out/phase9_gguf.mp4.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||||
|
python -m tests.component.test_09_gguf_generate
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
|
||||||
|
|
||||||
|
log = get_logger("test_09")
|
||||||
|
|
||||||
|
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
|
||||||
|
|
||||||
|
if DIT_QUANT.startswith("gguf-"):
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
|
||||||
|
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
|
||||||
|
else:
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
||||||
|
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
from server.video_models.wan22 import Wan22Pipeline
|
||||||
|
except ImportError as e:
|
||||||
|
log.error("Import failed: %s", e)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
avatar = ensure_sample_avatar()
|
||||||
|
log.info("Avatar: %s", avatar)
|
||||||
|
|
||||||
|
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
|
||||||
|
pipe = Wan22Pipeline(
|
||||||
|
base_repo="Wan-AI/Wan2.2-I2V-A14B",
|
||||||
|
dit_repo=DIT_REPO,
|
||||||
|
config_json=CONFIG_JSON,
|
||||||
|
model_cls="wan2.2_moe_distill",
|
||||||
|
resolution=480,
|
||||||
|
fps=16,
|
||||||
|
dit_quant_scheme=DIT_QUANT,
|
||||||
|
t5_quantized=True,
|
||||||
|
)
|
||||||
|
log.info("Pipeline ready.")
|
||||||
|
|
||||||
|
# Debug: verify DTYPE is set correctly for GGUF
|
||||||
|
from lightx2v.utils.envs import GET_DTYPE
|
||||||
|
log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))
|
||||||
|
|
||||||
|
log.info("Generating 3-second i2v clip...")
|
||||||
|
frames = pipe.generate_i2v(
|
||||||
|
image_path=avatar,
|
||||||
|
prompt="a person looking at the camera, natural lighting, soft focus background",
|
||||||
|
seconds=3,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
log.info("Got frames: shape=%s dtype=%s", frames.shape, frames.dtype)
|
||||||
|
|
||||||
|
# Encode to MP4
|
||||||
|
import imageio.v3 as iio
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
|
||||||
|
tmp = tf.name
|
||||||
|
try:
|
||||||
|
iio.imwrite(tmp, frames, fps=16, codec="libx264")
|
||||||
|
with open(tmp, "rb") as f:
|
||||||
|
mp4_bytes = f.read()
|
||||||
|
finally:
|
||||||
|
os.remove(tmp)
|
||||||
|
|
||||||
|
out = write_bytes("phase9_gguf.mp4", mp4_bytes)
|
||||||
|
log.info("PASS: video written to %s (%d bytes, %d frames)", out, len(mp4_bytes), frames.shape[0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
"""Smoke test: T5 text encoding under GGUF pipeline.
|
||||||
|
|
||||||
|
Builds the Wan22Pipeline (loads all weights including DIT) but only
|
||||||
|
exercises the T5 encoder — no image-to-video generation. Validates that
|
||||||
|
the DTYPE=FP16 ↔ BF16 patching lets T5 encode a prompt successfully.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||||
|
python -m tests.component.test_10_t5_encode
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tests.component._common import get_logger
|
||||||
|
|
||||||
|
log = get_logger("test_10")
|
||||||
|
|
||||||
|
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
|
||||||
|
|
||||||
|
if DIT_QUANT.startswith("gguf-"):
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
|
||||||
|
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
|
||||||
|
else:
|
||||||
|
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
||||||
|
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
from server.video_models.wan22 import Wan22Pipeline
|
||||||
|
except ImportError as e:
|
||||||
|
log.error("Import failed: %s", e)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT)
|
||||||
|
pipe = Wan22Pipeline(
|
||||||
|
base_repo="Wan-AI/Wan2.2-I2V-A14B",
|
||||||
|
dit_repo=DIT_REPO,
|
||||||
|
config_json=CONFIG_JSON,
|
||||||
|
model_cls="wan2.2_moe_distill",
|
||||||
|
resolution=480,
|
||||||
|
fps=16,
|
||||||
|
dit_quant_scheme=DIT_QUANT,
|
||||||
|
t5_quantized=True,
|
||||||
|
)
|
||||||
|
log.info("Pipeline ready.")
|
||||||
|
|
||||||
|
# Check DTYPE state after init
|
||||||
|
from lightx2v.utils.envs import GET_DTYPE
|
||||||
|
log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))
|
||||||
|
|
||||||
|
# Run only the T5 text encoder
|
||||||
|
runner = pipe._runner
|
||||||
|
prompt = "a person looking at the camera, natural lighting"
|
||||||
|
|
||||||
|
log.info("Running T5 text encoder on prompt: %r", prompt)
|
||||||
|
import copy
|
||||||
|
input_info = copy.deepcopy(pipe._input_info_template)
|
||||||
|
input_info.prompt = prompt
|
||||||
|
|
||||||
|
runner.run_text_encoder(input_info)
|
||||||
|
log.info("T5 encode complete.")
|
||||||
|
|
||||||
|
# Inspect output — check all dataclass fields for tensor results
|
||||||
|
import torch
|
||||||
|
for attr in vars(input_info):
|
||||||
|
val = getattr(input_info, attr)
|
||||||
|
if isinstance(val, torch.Tensor):
|
||||||
|
log.info(" %s: shape=%s dtype=%s device=%s", attr, val.shape, val.dtype, val.device)
|
||||||
|
|
||||||
|
# Verify DTYPE is back to FP16 after T5 runs (if GGUF)
|
||||||
|
if DIT_QUANT.startswith("gguf-"):
|
||||||
|
current = GET_DTYPE()
|
||||||
|
import torch
|
||||||
|
expected = torch.float16
|
||||||
|
if current == expected:
|
||||||
|
log.info("PASS: DTYPE correctly restored to FP16 after T5 encode.")
|
||||||
|
else:
|
||||||
|
log.error("FAIL: DTYPE is %s after T5, expected %s", current, expected)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
log.info("PASS: T5 encoding succeeded under %s pipeline.", DIT_QUANT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
@@ -105,7 +105,8 @@ def test_models_section_override():
|
|||||||
{
|
{
|
||||||
"models": {
|
"models": {
|
||||||
"wan22_base_repo": "/local/weights/wan22",
|
"wan22_base_repo": "/local/weights/wan22",
|
||||||
"wan22_fp8_repo": "/local/weights/wan22-fp8",
|
"wan22_dit_repo": "/local/weights/wan22-dit",
|
||||||
|
"wan22_dit_quant_scheme": "gguf-Q4_K_M",
|
||||||
"wan22_config_json": "/local/cfg/fp8.json",
|
"wan22_config_json": "/local/cfg/fp8.json",
|
||||||
"wan22_model_cls": "wan2.2_moe",
|
"wan22_model_cls": "wan2.2_moe",
|
||||||
"musetalk_path": "/local/weights/musetalk",
|
"musetalk_path": "/local/weights/musetalk",
|
||||||
@@ -113,7 +114,20 @@ def test_models_section_override():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
assert cfg.wan22_base_repo == "/local/weights/wan22"
|
assert cfg.wan22_base_repo == "/local/weights/wan22"
|
||||||
assert cfg.wan22_fp8_repo == "/local/weights/wan22-fp8"
|
assert cfg.wan22_dit_repo == "/local/weights/wan22-dit"
|
||||||
|
assert cfg.wan22_dit_quant_scheme == "gguf-Q4_K_M"
|
||||||
assert cfg.wan22_config_json == "/local/cfg/fp8.json"
|
assert cfg.wan22_config_json == "/local/cfg/fp8.json"
|
||||||
assert cfg.wan22_model_cls == "wan2.2_moe"
|
assert cfg.wan22_model_cls == "wan2.2_moe"
|
||||||
assert cfg.musetalk_model_path == "/local/weights/musetalk"
|
assert cfg.musetalk_model_path == "/local/weights/musetalk"
|
||||||
|
|
||||||
|
|
||||||
|
def test_models_section_backwards_compat_fp8_repo():
|
||||||
|
"""Old config key wan22_fp8_repo still works via fallback."""
|
||||||
|
cfg = VideoConfig.from_dict(
|
||||||
|
{
|
||||||
|
"models": {
|
||||||
|
"wan22_fp8_repo": "/local/weights/wan22-fp8",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert cfg.wan22_dit_repo == "/local/weights/wan22-fp8"
|
||||||
|
|||||||
Reference in New Issue
Block a user