t5 encoder fp8 seems to be working

This commit is contained in:
2026-04-12 13:50:34 -04:00
parent 2818b41004
commit fcf0be38bc
13 changed files with 505 additions and 67 deletions
+29 -13
View File
@@ -58,15 +58,18 @@ class VideoConfig:
loras: list[LoRASpec] = field(default_factory=list)
# Model paths — can be overridden via config.yml.video.models.
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
# The bf16 DIT shards in this repo are skipped — we
# replace them with the fp8 files from wan22_fp8_repo.
# wan22_fp8_repo : HF repo id (or local dir) providing the two fp8 e4m3
# 4-step distilled DIT checkpoints (~15 GB each).
# wan22_config_json: path to the LightX2V inference config template the
# Wan22Pipeline will fill in with absolute ckpt paths.
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
# The bf16 DIT shards in this repo are skipped — we
# replace them with quantised files from wan22_dit_repo.
# wan22_dit_repo : HF repo id (or local dir) providing the quantised
# DIT checkpoints (fp8 or GGUF).
# wan22_dit_quant_scheme : quantisation format, e.g. "fp8-sgl" or "gguf-Q4_K_M".
# wan22_config_json : path to the LightX2V inference config template the
# Wan22Pipeline will fill in with absolute ckpt paths.
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models"
wan22_dit_repo: str = "lightx2v/Wan2.2-Distill-Models"
wan22_dit_quant_scheme: str = "fp8-sgl"
wan22_t5_quantized: bool = False
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
wan22_model_cls: str = "wan2.2_moe_distill"
musetalk_model_path: str = "TMElyralab/MuseTalk"
@@ -121,8 +124,18 @@ class VideoConfig:
wan22_base_repo=str(
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
),
wan22_fp8_repo=str(
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models")
wan22_dit_repo=str(
models_raw.get(
"wan22_dit_repo",
# Backwards compat: fall back to old key name.
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models"),
)
),
wan22_dit_quant_scheme=str(
models_raw.get("wan22_dit_quant_scheme", "fp8-sgl")
),
wan22_t5_quantized=bool(
models_raw.get("wan22_t5_quantized", False)
),
wan22_config_json=str(
models_raw.get(
@@ -204,16 +217,19 @@ class VideoEngine:
from server.video_models.musetalk import MuseTalkEngine
log.info(
"Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...",
self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo,
"Loading Wan2.2 pipeline (base=%s, dit=%s, quant=%s)...",
self.cfg.wan22_base_repo, self.cfg.wan22_dit_repo,
self.cfg.wan22_dit_quant_scheme,
)
self._wan22 = Wan22Pipeline(
base_repo=self.cfg.wan22_base_repo,
fp8_repo=self.cfg.wan22_fp8_repo,
dit_repo=self.cfg.wan22_dit_repo,
config_json=self.cfg.wan22_config_json,
model_cls=self.cfg.wan22_model_cls,
resolution=self.cfg.resolution,
fps=self.cfg.fps,
dit_quant_scheme=self.cfg.wan22_dit_quant_scheme,
t5_quantized=self.cfg.wan22_t5_quantized,
)
if self.cfg.loras:
self._wan22.load_loras(self.cfg.loras)