diff --git a/Dockerfile b/Dockerfile index 1dd8ec3..1e8df1f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -53,6 +53,14 @@ RUN python3.11 -m pip install --no-cache-dir \ "git+https://github.com/ModelTC/LightX2V.git" || \ echo "LightX2V install failed — config.video.enabled must stay false until fixed" # +# sgl-kernel (fp8 T5 encoder acceleration). The PyPI wheel lacks SM120 +# (Blackwell) CUTLASS kernels; use SGLang's cu128 wheel index instead. +# Our wan22.py patches fp8_scaled_mm → torch._scaled_mm at runtime for +# Blackwell GPUs, but the sgl_kernel package itself must still be present. +RUN python3.11 -m pip install --no-cache-dir --no-deps \ + "sgl-kernel @ https://github.com/sgl-project/whl/releases/download/v0.3.14.post1/sgl_kernel-0.3.14.post1%2Bcu128-cp310-abi3-manylinux2014_x86_64.whl" || \ + echo "sgl-kernel install failed — fp8 T5 will fall back to bf16" +# # MuseTalk (audio-driven lip-sync) — same story. RUN python3.11 -m pip install --no-cache-dir \ "git+https://github.com/TMElyralab/MuseTalk.git" || \ diff --git a/config.yml b/config.yml index 597e5e8..ebafaa2 100644 --- a/config.yml +++ b/config.yml @@ -32,16 +32,23 @@ video: casual gestures, natural lighting, soft focus background prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint} - # Model sources for the video stack. The fp8 e4m3 4-step distilled DIT - # weights from lightx2v/Wan2.2-Distill-Models are ~15 GB each (vs ~28 GB - # bf16) — that's the "save VRAM" path. T5/VAE/tokenizer still come from - # the Wan-AI base repo. Both repos download on first run into - # HF_HOME=/cache/huggingface. + # Model sources for the video stack. T5/VAE/tokenizer come from the + # Wan-AI base repo. DIT weights come from wan22_dit_repo in the format + # specified by wan22_dit_quant_scheme. Both repos download on first run + # into HF_HOME=/cache/huggingface. + # + # Supported dit_quant_scheme values: + # fp8-sgl — fp8 e4m3 safetensors (~15 GB/expert, from lightx2v/Wan2.2-Distill-Models) + # gguf-Q4_K_M — GGUF 4-bit (~9.65 GB/expert, from QuantStack/Wan2.2-I2V-A14B-GGUF) + # gguf-Q8_0 — GGUF 8-bit (~15.4 GB/expert) + # (any gguf- supported by LightX2V — see base_model.py MM_WEIGHT_REGISTER) models: wan22_base_repo: Wan-AI/Wan2.2-I2V-A14B - wan22_fp8_repo: lightx2v/Wan2.2-Distill-Models + wan22_dit_repo: QuantStack/Wan2.2-I2V-A14B-GGUF + wan22_dit_quant_scheme: gguf-Q4_K_M + wan22_t5_quantized: true wan22_model_cls: wan2.2_moe_distill - wan22_config_json: /app/configs/lightx2v/wan22_i2v_fp8_distill.json + wan22_config_json: /app/configs/lightx2v/wan22_i2v_gguf_distill.json musetalk_path: TMElyralab/MuseTalk # LoRAs applied to the fp8 base at load time via runtime switch_lora. diff --git a/configs/lightx2v/wan22_i2v_gguf_distill.json b/configs/lightx2v/wan22_i2v_gguf_distill.json new file mode 100644 index 0000000..f8730cf --- /dev/null +++ b/configs/lightx2v/wan22_i2v_gguf_distill.json @@ -0,0 +1,35 @@ +{ + "_comment": "Wan2.2 i2v MoE 4-step distill, GGUF quantized. Uses QuantStack/Wan2.2-I2V-A14B-GGUF checkpoints instead of fp8 safetensors. GGUF does not support block-level offload so offload_granularity is set to 'model' — the entire DIT is moved to GPU when active. With Q4_K_M (~9.65 GB per expert) this fits comfortably in 24+ GB VRAM. high_noise_quantized_ckpt / low_noise_quantized_ckpt are filled in at runtime by server/video_models/wan22.py. IMPORTANT: GGUF dequantizes to fp16, so you must set DTYPE=FP16 in the container environment.", + + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + + "resize_mode": "adaptive", + "resolution": "480p", + "target_height": 480, + "target_width": 480, + "fps": 16, + + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + + "sample_guide_scale": [3.5, 3.5], + "sample_shift": 5.0, + "enable_cfg": false, + + "cpu_offload": true, + "offload_granularity": "model", + "t5_cpu_offload": true, + "vae_cpu_offload": false, + + "use_image_encoder": false, + + "boundary_step_index": 2, + "denoising_step_list": [1000, 750, 500, 250], + + "dit_quantized": true, + "dit_quant_scheme": "gguf-Q4_K_M", + "t5_quantized": false +} diff --git a/docker-compose.yml b/docker-compose.yml index e10d168..f57fa6a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,6 +16,7 @@ services: - ./configs:/app/configs:ro - ./server:/app/server:ro - ./static:/app/static:ro + - ./tests:/app/tests - ./run.py:/app/run.py:ro deploy: resources: diff --git a/requirements.txt b/requirements.txt index 7bdba86..8108a4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,9 @@ pyyaml imageio[ffmpeg]>=2.34 av>=12.0 pyzmq>=25.0 +gguf>=0.6.0 +# sgl-kernel: installed from SGLang's cu128 wheel index in Dockerfile +# (PyPI version lacks SM120/Blackwell CUDA kernels) # LightX2V (Wan2.2-Lightning) and MuseTalk are installed from source in the # Dockerfile because neither ships a stable PyPI release yet. See lines # "LightX2V from source" / "MuseTalk from source" in Dockerfile. diff --git a/server/video.py b/server/video.py index 9a98b69..4f36bca 100644 --- a/server/video.py +++ b/server/video.py @@ -58,15 +58,18 @@ class VideoConfig: loras: list[LoRASpec] = field(default_factory=list) # Model paths — can be overridden via config.yml.video.models. - # wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer. - # The bf16 DIT shards in this repo are skipped — we - # replace them with the fp8 files from wan22_fp8_repo. - # wan22_fp8_repo : HF repo id (or local dir) providing the two fp8 e4m3 - # 4-step distilled DIT checkpoints (~15 GB each). - # wan22_config_json: path to the LightX2V inference config template the - # Wan22Pipeline will fill in with absolute ckpt paths. + # wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer. + # The bf16 DIT shards in this repo are skipped — we + # replace them with quantised files from wan22_dit_repo. + # wan22_dit_repo : HF repo id (or local dir) providing the quantised + # DIT checkpoints (fp8 or GGUF). + # wan22_dit_quant_scheme : quantisation format, e.g. "fp8-sgl" or "gguf-Q4_K_M". + # wan22_config_json : path to the LightX2V inference config template the + # Wan22Pipeline will fill in with absolute ckpt paths. wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B" - wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models" + wan22_dit_repo: str = "lightx2v/Wan2.2-Distill-Models" + wan22_dit_quant_scheme: str = "fp8-sgl" + wan22_t5_quantized: bool = False wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json" wan22_model_cls: str = "wan2.2_moe_distill" musetalk_model_path: str = "TMElyralab/MuseTalk" @@ -121,8 +124,18 @@ class VideoConfig: wan22_base_repo=str( models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B") ), - wan22_fp8_repo=str( - models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models") + wan22_dit_repo=str( + models_raw.get( + "wan22_dit_repo", + # Backwards compat: fall back to old key name. + models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models"), + ) + ), + wan22_dit_quant_scheme=str( + models_raw.get("wan22_dit_quant_scheme", "fp8-sgl") + ), + wan22_t5_quantized=bool( + models_raw.get("wan22_t5_quantized", False) ), wan22_config_json=str( models_raw.get( @@ -204,16 +217,19 @@ class VideoEngine: from server.video_models.musetalk import MuseTalkEngine log.info( - "Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...", - self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo, + "Loading Wan2.2 pipeline (base=%s, dit=%s, quant=%s)...", + self.cfg.wan22_base_repo, self.cfg.wan22_dit_repo, + self.cfg.wan22_dit_quant_scheme, ) self._wan22 = Wan22Pipeline( base_repo=self.cfg.wan22_base_repo, - fp8_repo=self.cfg.wan22_fp8_repo, + dit_repo=self.cfg.wan22_dit_repo, config_json=self.cfg.wan22_config_json, model_cls=self.cfg.wan22_model_cls, resolution=self.cfg.resolution, fps=self.cfg.fps, + dit_quant_scheme=self.cfg.wan22_dit_quant_scheme, + t5_quantized=self.cfg.wan22_t5_quantized, ) if self.cfg.loras: self._wan22.load_loras(self.cfg.loras) diff --git a/server/video_models/wan22.py b/server/video_models/wan22.py index 327bbae..3ecf5bd 100644 --- a/server/video_models/wan22.py +++ b/server/video_models/wan22.py @@ -1,4 +1,4 @@ -"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V. +"""Wan2.2-Lightning image-to-video wrapper via LightX2V. This wrapper targets LightX2V's actual Python entry points (verified against the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main): @@ -25,10 +25,12 @@ Two HuggingFace repos are consumed on first run (cached under HF_HOME): - Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only. The bf16 DIT shards under high_noise_model/ and low_noise_model/ are SKIPPED via - ignore_patterns — we replace them with fp8. - - lightx2v/Wan2.2-Distill-Models — exactly two safetensors files: - the fp8 e4m3 4-step distilled high/low - noise DIT checkpoints (~15 GB each). + ignore_patterns — we replace them with + quantised checkpoints from dit_repo. + - dit_repo (configurable) — quantised DIT checkpoints. Supported + formats: + * fp8 safetensors (lightx2v/Wan2.2-Distill-Models) + * GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF) """ from __future__ import annotations @@ -47,13 +49,22 @@ if TYPE_CHECKING: log = logging.getLogger(__name__) +# --- fp8 distill filenames -------------------------------------------------- FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" +# --- GGUF filenames (QuantStack layout: HighNoise/.gguf) --------------- +GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf" +GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf" + +# --- fp8 T5 encoder (lightx2v/Encoders repo) -------------------------------- +T5_FP8_REPO = "lightx2v/Encoders" +T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors" + # The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the -# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8 -# files from the distill repo replace the DIT weights entirely. We must keep -# the config.json / index.json metadata under high_noise_model/ and +# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the +# quantised files from dit_repo replace the DIT weights entirely. We must +# keep the config.json / index.json metadata under high_noise_model/ and # low_noise_model/ (LightX2V's set_config reads architecture params like # ``dim`` from them) and the tokenizer files under google/. BASE_REPO_IGNORE_PATTERNS = [ @@ -66,8 +77,68 @@ BASE_REPO_IGNORE_PATTERNS = [ ] +def _patch_fp8_scaled_mm_for_blackwell() -> None: + """Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell. + + sgl_kernel's CUTLASS-based fp8 GEMM doesn't ship SM120 kernels yet. + PyTorch 2.8+'s native ``_scaled_mm`` works on all architectures + including Blackwell. This patch is idempotent. + """ + try: + import sgl_kernel # type: ignore[import-not-found] + except ImportError: + return # no sgl_kernel → fp8 T5 not in use + + if getattr(sgl_kernel, "_fp8_patched_for_blackwell", False): + return + + import torch + + if not torch.cuda.is_available(): + return + + cap = torch.cuda.get_device_capability() + if cap[0] < 12: + return # only patch on Blackwell+ + + _orig = sgl_kernel.fp8_scaled_mm + + def _torch_fp8_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + out_dtype: torch.dtype, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + # torch._scaled_mm expects (M,K) @ (N,K).t() with: + # scale_a: scalar or (M,1) + # scale_b: scalar or (1,N) + # sgl_kernel provides scale_b as (N,1) — transpose it. + if b_scale.dim() == 2 and b_scale.shape[1] == 1: + b_scale = b_scale.t() + # _scaled_mm requires B to be column-major (stride(0)==1). + bt = b.t().contiguous().t() + out = torch._scaled_mm( + a, bt, + scale_a=a_scale, scale_b=b_scale, + out_dtype=out_dtype, bias=bias, + ) + return out + + sgl_kernel.fp8_scaled_mm = _torch_fp8_scaled_mm + sgl_kernel._fp8_patched_for_blackwell = True + log.info("Patched sgl_kernel.fp8_scaled_mm → torch._scaled_mm for Blackwell (SM%d%d).", *cap) + + class Wan22Pipeline: - """Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights. + """Wrapper around LightX2V's Wan2.2 MoE distill runner. + + Supports two DIT quantisation formats: + * **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from + ``lightx2v/Wan2.2-Distill-Models``) + * **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level, + from ``QuantStack/Wan2.2-I2V-A14B-GGUF``) Constructor downloads (if needed) both HF repos, writes a runtime JSON config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``. @@ -77,23 +148,34 @@ class Wan22Pipeline: def __init__( self, base_repo: str, - fp8_repo: str, + dit_repo: str, config_json: str, model_cls: str = "wan2.2_moe_distill", resolution: int = 480, fps: int = 16, + dit_quant_scheme: str = "fp8-sgl", + t5_quantized: bool = False, ): self.base_repo = base_repo - self.fp8_repo = fp8_repo + self.dit_repo = dit_repo self.config_json_template = config_json self.model_cls = model_cls self.resolution = resolution self.fps = fps + self.dit_quant_scheme = dit_quant_scheme + self.t5_quantized = t5_quantized self._applied_loras: list[LoRASpec] = [] - # 1. Resolve / download base repo (T5/VAE/config) and fp8 DIT ckpts. + self._is_gguf = dit_quant_scheme.startswith("gguf-") + + # 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts. self._model_root = self._ensure_base_repo(base_repo) - self._fp8_high, self._fp8_low = self._ensure_fp8_checkpoints(fp8_repo) + self._dit_high, self._dit_low = self._ensure_dit_checkpoints( + dit_repo, dit_quant_scheme, + ) + self._t5_fp8_ckpt = ( + self._ensure_t5_fp8() if t5_quantized else None + ) # 2. Materialize a runtime JSON config with absolute ckpt paths. self._runtime_json_path = self._build_runtime_config() @@ -105,13 +187,17 @@ class Wan22Pipeline: config_json=self._runtime_json_path, ) - # 4. set_config → init_runner. Runner construction triggers weight load. - # Imports are scoped here so ``import server.video_models.wan22`` - # never pulls in lightx2v (tests can import this module on CPU). + # 4. Import LightX2V (scoped here so ``import server.video_models.wan22`` + # never pulls in lightx2v — tests can import this module on CPU). from lightx2v.utils.set_config import set_config # type: ignore[import-not-found] from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found] from lightx2v.infer import init_runner # type: ignore[import-not-found] + _patch_fp8_scaled_mm_for_blackwell() + + # 5. Load all models under default DTYPE=BF16 so T5 (which is + # hardcoded to bf16 weights) initialises its offload buffers + # correctly. We flip to FP16 *after* init_runner completes. log.info("LightX2V set_config (model_cls=%s, model_path=%s)", model_cls, self._model_root) self._config = set_config(args) @@ -124,6 +210,52 @@ class Wan22Pipeline: self._runner = init_runner(self._config) log.info("LightX2V runner loaded; weights resident.") + # 6. GGUF: switch global DTYPE to FP16 for inference. GGUF DIT + # dequantises to fp16, and many intermediate tensors inside the + # DIT forward pass are allocated via GET_DTYPE(). The T5 encoder + # is wrapped to temporarily restore BF16 during its forward. + if self._is_gguf: + os.environ["DTYPE"] = "FP16" + from lightx2v.utils.envs import GET_DTYPE # type: ignore[import-not-found] + GET_DTYPE.cache_clear() + log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE()) + self._patch_t5_dtype_for_gguf() + + # --- GGUF dtype compatibility patch ---------------------------------------- + + def _patch_t5_dtype_for_gguf(self) -> None: + """Wrap the T5 encoder so it temporarily restores DTYPE=BF16. + + The T5 encoder is hardcoded to bfloat16 weights (wan_runner.py). When + the global DTYPE is FP16 (required for GGUF DIT), the T5's CPU-offload + path breaks because intermediate tensor dtypes no longer match the bf16 + weights. We wrap ``run_text_encoder`` to temporarily flip GET_DTYPE() + back to bf16, then restore fp16 before the DIT runs. + """ + import os + import types + from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE # type: ignore[import-not-found] + + runner = self._runner + orig_run_text_encoder = runner.run_text_encoder.__func__ + + def bf16_text_encoder(self_runner, *args, **kwargs): + # Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights. + os.environ["DTYPE"] = "BF16" + GET_DTYPE.cache_clear() + GET_SENSITIVE_DTYPE.cache_clear() + try: + result = orig_run_text_encoder(self_runner, *args, **kwargs) + finally: + # Restore FP16 for the DIT / rest of the pipeline. + os.environ["DTYPE"] = "FP16" + GET_DTYPE.cache_clear() + GET_SENSITIVE_DTYPE.cache_clear() + return result + + runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner) + log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.") + # --- Weight provisioning ------------------------------------------------- @staticmethod @@ -132,7 +264,7 @@ class Wan22Pipeline: If ``base_repo`` is already a local directory, use it as-is. Otherwise snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT - shards (they're replaced by the fp8 files). + shards (they're replaced by the quantised files). """ if os.path.isdir(base_repo): return base_repo @@ -145,42 +277,75 @@ class Wan22Pipeline: ) @staticmethod - def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]: - """Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair. + def _ensure_dit_checkpoints( + dit_repo: str, + dit_quant_scheme: str, + ) -> tuple[str, str]: + """Return (high_noise_path, low_noise_path) for the DIT pair. - - If ``fp8_repo`` is a local directory, expect both files inside it. - - Otherwise treat it as a HF repo id and download only the two files - we need (not the ~150 GB of other variants in that repo). + Supports both fp8 safetensors and GGUF formats. """ - if not fp8_repo: - raise ValueError("fp8_repo must be a HF repo id or local directory.") - if os.path.isdir(fp8_repo): - high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE) - low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE) + if not dit_repo: + raise ValueError("dit_repo must be a HF repo id or local directory.") + + is_gguf = dit_quant_scheme.startswith("gguf-") + + if is_gguf: + # Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M" + quant = dit_quant_scheme.replace("gguf-", "") + high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant) + low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant) + else: + high_file = FP8_HIGH_NOISE_FILE + low_file = FP8_LOW_NOISE_FILE + + # Local directory? + if os.path.isdir(dit_repo): + high = os.path.join(dit_repo, high_file) + low = os.path.join(dit_repo, low_file) if not (os.path.isfile(high) and os.path.isfile(low)): raise FileNotFoundError( - f"fp8 checkpoints not found in {fp8_repo}: expected " - f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}" + f"DIT checkpoints not found in {dit_repo}: expected " + f"{high_file} and {low_file}" ) return high, low + + # HuggingFace download. from huggingface_hub import hf_hub_download - log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo) - high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE) - low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE) + log.info("Downloading %s DIT checkpoints from %s ...", + dit_quant_scheme, dit_repo) + high = hf_hub_download(repo_id=dit_repo, filename=high_file) + low = hf_hub_download(repo_id=dit_repo, filename=low_file) return high, low + @staticmethod + def _ensure_t5_fp8() -> str: + """Download the fp8 T5 encoder from lightx2v/Encoders (if not cached). + + Returns the local path to the safetensors file (~6 GB). + """ + from huggingface_hub import hf_hub_download + log.info("Downloading fp8 T5 encoder from %s ...", T5_FP8_REPO) + return hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE) + def _build_runtime_config(self) -> str: """Load the template JSON, inject absolute ckpt paths, persist to temp.""" with open(self.config_json_template, "r", encoding="utf-8") as f: cfg = json.load(f) # Drop editorial comments before passing to LightX2V. cfg.pop("_comment", None) - cfg["high_noise_quantized_ckpt"] = self._fp8_high - cfg["low_noise_quantized_ckpt"] = self._fp8_low + cfg["high_noise_quantized_ckpt"] = self._dit_high + cfg["low_noise_quantized_ckpt"] = self._dit_low cfg.setdefault("fps", self.fps) + # T5 fp8 quantization. + if self._t5_fp8_ckpt: + cfg["t5_quantized"] = True + cfg["t5_quant_scheme"] = "fp8-sgl" + cfg["t5_quantized_ckpt"] = self._t5_fp8_ckpt + tmp = tempfile.NamedTemporaryFile( - prefix="wan22_fp8_", suffix=".json", + prefix="wan22_dit_", suffix=".json", mode="w", delete=False, encoding="utf-8", ) json.dump(cfg, tmp, indent=2) diff --git a/tests/component/sample_avatar.png b/tests/component/sample_avatar.png new file mode 100644 index 0000000..d084468 Binary files /dev/null and b/tests/component/sample_avatar.png differ diff --git a/tests/component/sample_avatar.webp b/tests/component/sample_avatar.webp new file mode 100644 index 0000000..ae3e289 Binary files /dev/null and b/tests/component/sample_avatar.webp differ diff --git a/tests/component/test_02_wan22_loras.py b/tests/component/test_02_wan22_loras.py index f41e857..790b35c 100644 --- a/tests/component/test_02_wan22_loras.py +++ b/tests/component/test_02_wan22_loras.py @@ -1,15 +1,22 @@ -"""Phase 2 component test: Wan2.2-Lightning fp8 pipeline + LoRA stacking. +"""Phase 2 component test: Wan2.2 pipeline + LoRA stacking. Verifies: -- ``Wan22Pipeline`` loads successfully against the fp8 distill path - (exercises the real LightX2V set_config → init_runner flow). +- ``Wan22Pipeline`` loads successfully (exercises the real LightX2V + set_config -> init_runner flow). - ``load_loras`` / ``unload_loras`` survive with the two user LoRAs at ``/cache/loras/wan22-[HL]-e8.safetensors``. -Requires GPU and a first-run download of both HF repos (base support files -~12 GB, fp8 DIT ~30 GB). If LightX2V isn't installed the test is skipped. +Supports both fp8 and GGUF DIT quantisation. Set the ``DIT_QUANT`` +environment variable to switch (default: ``fp8-sgl``). -Run: + DIT_QUANT=gguf-Q4_K_M docker compose exec voice-chat \ + python -m tests.component.test_02_wan22_loras + +Requires GPU and a first-run download of both HF repos (base support files +~12 GB, DIT size depends on quant — fp8 ~30 GB, GGUF Q4_K_M ~19 GB). +If LightX2V isn't installed the test is skipped. + +Run (default fp8): docker compose exec voice-chat python -m tests.component.test_02_wan22_loras """ from __future__ import annotations @@ -21,7 +28,17 @@ from tests.component._common import get_logger log = get_logger("test_02") -CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json" +# --- Quant-dependent defaults ------------------------------------------------ + +DIT_QUANT = os.environ.get("DIT_QUANT", "fp8-sgl") + +if DIT_QUANT.startswith("gguf-"): + CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json" + DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF" +else: + CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json" + DIT_REPO = "lightx2v/Wan2.2-Distill-Models" + LORA_HIGH = "/cache/loras/wan22-H-e8.safetensors" LORA_LOW = "/cache/loras/wan22-L-e8.safetensors" @@ -37,15 +54,16 @@ def run(): from server.video import LoRASpec log.info("[case 1] Instantiate Wan22Pipeline " - "(first run downloads ~42 GB total)...") + "(quant=%s, dit_repo=%s)...", DIT_QUANT, DIT_REPO) try: pipe = Wan22Pipeline( base_repo="Wan-AI/Wan2.2-I2V-A14B", - fp8_repo="lightx2v/Wan2.2-Distill-Models", + dit_repo=DIT_REPO, config_json=CONFIG_JSON, model_cls="wan2.2_moe_distill", resolution=480, fps=16, + dit_quant_scheme=DIT_QUANT, ) except Exception as e: log.error("FAIL: Wan22Pipeline construction raised: %s", e) @@ -56,7 +74,7 @@ def run(): log.info(" PASS: pipeline constructed") # --- LoRAs --- - log.info("[case 2] load_loras with empty list → no-op") + log.info("[case 2] load_loras with empty list -> no-op") pipe.load_loras([]) log.info(" PASS") diff --git a/tests/component/test_09_gguf_generate.py b/tests/component/test_09_gguf_generate.py new file mode 100644 index 0000000..bb56404 --- /dev/null +++ b/tests/component/test_09_gguf_generate.py @@ -0,0 +1,83 @@ +"""Quick smoke test: generate a video clip with the GGUF pipeline. + +Calls Wan22Pipeline.generate_i2v directly (no MuseTalk, no VideoEngine) +and writes the result to tests/component/_out/phase9_gguf.mp4. + +Run: + docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \ + python -m tests.component.test_09_gguf_generate +""" +from __future__ import annotations + +import os +import sys + +from tests.component._common import ensure_sample_avatar, get_logger, write_bytes + +log = get_logger("test_09") + +DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M") + +if DIT_QUANT.startswith("gguf-"): + CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json" + DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF" +else: + CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json" + DIT_REPO = "lightx2v/Wan2.2-Distill-Models" + + +def run(): + try: + from server.video_models.wan22 import Wan22Pipeline + except ImportError as e: + log.error("Import failed: %s", e) + sys.exit(0) + + avatar = ensure_sample_avatar() + log.info("Avatar: %s", avatar) + + log.info("Building pipeline (quant=%s)...", DIT_QUANT) + pipe = Wan22Pipeline( + base_repo="Wan-AI/Wan2.2-I2V-A14B", + dit_repo=DIT_REPO, + config_json=CONFIG_JSON, + model_cls="wan2.2_moe_distill", + resolution=480, + fps=16, + dit_quant_scheme=DIT_QUANT, + t5_quantized=True, + ) + log.info("Pipeline ready.") + + # Debug: verify DTYPE is set correctly for GGUF + from lightx2v.utils.envs import GET_DTYPE + log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE")) + + log.info("Generating 3-second i2v clip...") + frames = pipe.generate_i2v( + image_path=avatar, + prompt="a person looking at the camera, natural lighting, soft focus background", + seconds=3, + seed=42, + ) + log.info("Got frames: shape=%s dtype=%s", frames.shape, frames.dtype) + + # Encode to MP4 + import imageio.v3 as iio + import tempfile + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf: + tmp = tf.name + try: + iio.imwrite(tmp, frames, fps=16, codec="libx264") + with open(tmp, "rb") as f: + mp4_bytes = f.read() + finally: + os.remove(tmp) + + out = write_bytes("phase9_gguf.mp4", mp4_bytes) + log.info("PASS: video written to %s (%d bytes, %d frames)", out, len(mp4_bytes), frames.shape[0]) + + +if __name__ == "__main__": + run() diff --git a/tests/component/test_10_t5_encode.py b/tests/component/test_10_t5_encode.py new file mode 100644 index 0000000..bf7084a --- /dev/null +++ b/tests/component/test_10_t5_encode.py @@ -0,0 +1,88 @@ +"""Smoke test: T5 text encoding under GGUF pipeline. + +Builds the Wan22Pipeline (loads all weights including DIT) but only +exercises the T5 encoder — no image-to-video generation. Validates that +the DTYPE=FP16 ↔ BF16 patching lets T5 encode a prompt successfully. + +Run: + docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \ + python -m tests.component.test_10_t5_encode +""" +from __future__ import annotations + +import os +import sys + +from tests.component._common import get_logger + +log = get_logger("test_10") + +DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M") + +if DIT_QUANT.startswith("gguf-"): + CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json" + DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF" +else: + CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json" + DIT_REPO = "lightx2v/Wan2.2-Distill-Models" + + +def run(): + try: + from server.video_models.wan22 import Wan22Pipeline + except ImportError as e: + log.error("Import failed: %s", e) + sys.exit(0) + + log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT) + pipe = Wan22Pipeline( + base_repo="Wan-AI/Wan2.2-I2V-A14B", + dit_repo=DIT_REPO, + config_json=CONFIG_JSON, + model_cls="wan2.2_moe_distill", + resolution=480, + fps=16, + dit_quant_scheme=DIT_QUANT, + t5_quantized=True, + ) + log.info("Pipeline ready.") + + # Check DTYPE state after init + from lightx2v.utils.envs import GET_DTYPE + log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE")) + + # Run only the T5 text encoder + runner = pipe._runner + prompt = "a person looking at the camera, natural lighting" + + log.info("Running T5 text encoder on prompt: %r", prompt) + import copy + input_info = copy.deepcopy(pipe._input_info_template) + input_info.prompt = prompt + + runner.run_text_encoder(input_info) + log.info("T5 encode complete.") + + # Inspect output — check all dataclass fields for tensor results + import torch + for attr in vars(input_info): + val = getattr(input_info, attr) + if isinstance(val, torch.Tensor): + log.info(" %s: shape=%s dtype=%s device=%s", attr, val.shape, val.dtype, val.device) + + # Verify DTYPE is back to FP16 after T5 runs (if GGUF) + if DIT_QUANT.startswith("gguf-"): + current = GET_DTYPE() + import torch + expected = torch.float16 + if current == expected: + log.info("PASS: DTYPE correctly restored to FP16 after T5 encode.") + else: + log.error("FAIL: DTYPE is %s after T5, expected %s", current, expected) + sys.exit(1) + + log.info("PASS: T5 encoding succeeded under %s pipeline.", DIT_QUANT) + + +if __name__ == "__main__": + run() diff --git a/tests/unit/test_video_config.py b/tests/unit/test_video_config.py index 54e5333..9eb7b56 100644 --- a/tests/unit/test_video_config.py +++ b/tests/unit/test_video_config.py @@ -105,7 +105,8 @@ def test_models_section_override(): { "models": { "wan22_base_repo": "/local/weights/wan22", - "wan22_fp8_repo": "/local/weights/wan22-fp8", + "wan22_dit_repo": "/local/weights/wan22-dit", + "wan22_dit_quant_scheme": "gguf-Q4_K_M", "wan22_config_json": "/local/cfg/fp8.json", "wan22_model_cls": "wan2.2_moe", "musetalk_path": "/local/weights/musetalk", @@ -113,7 +114,20 @@ def test_models_section_override(): } ) assert cfg.wan22_base_repo == "/local/weights/wan22" - assert cfg.wan22_fp8_repo == "/local/weights/wan22-fp8" + assert cfg.wan22_dit_repo == "/local/weights/wan22-dit" + assert cfg.wan22_dit_quant_scheme == "gguf-Q4_K_M" assert cfg.wan22_config_json == "/local/cfg/fp8.json" assert cfg.wan22_model_cls == "wan2.2_moe" assert cfg.musetalk_model_path == "/local/weights/musetalk" + + +def test_models_section_backwards_compat_fp8_repo(): + """Old config key wan22_fp8_repo still works via fallback.""" + cfg = VideoConfig.from_dict( + { + "models": { + "wan22_fp8_repo": "/local/weights/wan22-fp8", + } + } + ) + assert cfg.wan22_dit_repo == "/local/weights/wan22-fp8"