t5 encoder fp8 seems to be working
This commit is contained in:
Binary file not shown.
|
After Width: | Height: | Size: 62 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
@@ -1,15 +1,22 @@
|
||||
"""Phase 2 component test: Wan2.2-Lightning fp8 pipeline + LoRA stacking.
|
||||
"""Phase 2 component test: Wan2.2 pipeline + LoRA stacking.
|
||||
|
||||
Verifies:
|
||||
- ``Wan22Pipeline`` loads successfully against the fp8 distill path
|
||||
(exercises the real LightX2V set_config → init_runner flow).
|
||||
- ``Wan22Pipeline`` loads successfully (exercises the real LightX2V
|
||||
set_config -> init_runner flow).
|
||||
- ``load_loras`` / ``unload_loras`` survive with the two user LoRAs at
|
||||
``/cache/loras/wan22-[HL]-e8.safetensors``.
|
||||
|
||||
Requires GPU and a first-run download of both HF repos (base support files
|
||||
~12 GB, fp8 DIT ~30 GB). If LightX2V isn't installed the test is skipped.
|
||||
Supports both fp8 and GGUF DIT quantisation. Set the ``DIT_QUANT``
|
||||
environment variable to switch (default: ``fp8-sgl``).
|
||||
|
||||
Run:
|
||||
DIT_QUANT=gguf-Q4_K_M docker compose exec voice-chat \
|
||||
python -m tests.component.test_02_wan22_loras
|
||||
|
||||
Requires GPU and a first-run download of both HF repos (base support files
|
||||
~12 GB, DIT size depends on quant — fp8 ~30 GB, GGUF Q4_K_M ~19 GB).
|
||||
If LightX2V isn't installed the test is skipped.
|
||||
|
||||
Run (default fp8):
|
||||
docker compose exec voice-chat python -m tests.component.test_02_wan22_loras
|
||||
"""
|
||||
from __future__ import annotations
|
||||
@@ -21,7 +28,17 @@ from tests.component._common import get_logger
|
||||
|
||||
log = get_logger("test_02")
|
||||
|
||||
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
|
||||
# --- Quant-dependent defaults ------------------------------------------------
|
||||
|
||||
# Quantisation scheme for the DIT weights; overridable via the environment.
DIT_QUANT = os.environ.get("DIT_QUANT", "fp8-sgl")

# Any "gguf-*" scheme selects the GGUF config/repo pair; every other value
# uses the fp8 distill pair.
CONFIG_JSON, DIT_REPO = (
    (
        "/app/configs/lightx2v/wan22_i2v_gguf_distill.json",
        "QuantStack/Wan2.2-I2V-A14B-GGUF",
    )
    if DIT_QUANT.startswith("gguf-")
    else (
        "/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
        "lightx2v/Wan2.2-Distill-Models",
    )
)
|
||||
|
||||
LORA_HIGH = "/cache/loras/wan22-H-e8.safetensors"
|
||||
LORA_LOW = "/cache/loras/wan22-L-e8.safetensors"
|
||||
|
||||
@@ -37,15 +54,16 @@ def run():
|
||||
from server.video import LoRASpec
|
||||
|
||||
log.info("[case 1] Instantiate Wan22Pipeline "
|
||||
"(first run downloads ~42 GB total)...")
|
||||
"(quant=%s, dit_repo=%s)...", DIT_QUANT, DIT_REPO)
|
||||
try:
|
||||
pipe = Wan22Pipeline(
|
||||
base_repo="Wan-AI/Wan2.2-I2V-A14B",
|
||||
fp8_repo="lightx2v/Wan2.2-Distill-Models",
|
||||
dit_repo=DIT_REPO,
|
||||
config_json=CONFIG_JSON,
|
||||
model_cls="wan2.2_moe_distill",
|
||||
resolution=480,
|
||||
fps=16,
|
||||
dit_quant_scheme=DIT_QUANT,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error("FAIL: Wan22Pipeline construction raised: %s", e)
|
||||
@@ -56,7 +74,7 @@ def run():
|
||||
log.info(" PASS: pipeline constructed")
|
||||
|
||||
# --- LoRAs ---
|
||||
log.info("[case 2] load_loras with empty list → no-op")
|
||||
log.info("[case 2] load_loras with empty list -> no-op")
|
||||
pipe.load_loras([])
|
||||
log.info(" PASS")
|
||||
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Quick smoke test: generate a video clip with the GGUF pipeline.
|
||||
|
||||
Calls Wan22Pipeline.generate_i2v directly (no MuseTalk, no VideoEngine)
|
||||
and writes the result to tests/component/_out/phase9_gguf.mp4.
|
||||
|
||||
Run:
|
||||
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||
python -m tests.component.test_09_gguf_generate
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
|
||||
|
||||
log = get_logger("test_09")
|
||||
|
||||
# Quantisation scheme for the DIT weights; overridable via the environment.
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")

# Any "gguf-*" scheme selects the GGUF config/repo pair; every other value
# uses the fp8 distill pair.
CONFIG_JSON, DIT_REPO = (
    (
        "/app/configs/lightx2v/wan22_i2v_gguf_distill.json",
        "QuantStack/Wan2.2-I2V-A14B-GGUF",
    )
    if DIT_QUANT.startswith("gguf-")
    else (
        "/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
        "lightx2v/Wan2.2-Distill-Models",
    )
)
|
||||
|
||||
|
||||
def run():
    """Build the pipeline, generate one short i2v clip, and save it as MP4.

    If ``Wan22Pipeline`` cannot be imported, the error is logged and the
    process exits with status 0. Otherwise the generated frames are encoded
    with imageio and the resulting bytes are handed to ``write_bytes`` under
    the name ``phase9_gguf.mp4``.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as exc:
        log.error("Import failed: %s", exc)
        sys.exit(0)

    avatar_path = ensure_sample_avatar()
    log.info("Avatar: %s", avatar_path)

    log.info("Building pipeline (quant=%s)...", DIT_QUANT)
    pipeline_kwargs = {
        "base_repo": "Wan-AI/Wan2.2-I2V-A14B",
        "dit_repo": DIT_REPO,
        "config_json": CONFIG_JSON,
        "model_cls": "wan2.2_moe_distill",
        "resolution": 480,
        "fps": 16,
        "dit_quant_scheme": DIT_QUANT,
        "t5_quantized": True,
    }
    pipeline = Wan22Pipeline(**pipeline_kwargs)
    log.info("Pipeline ready.")

    # Debug aid: report what LightX2V's GET_DTYPE() returns next to the raw
    # DTYPE environment variable.
    from lightx2v.utils.envs import GET_DTYPE

    resolved_dtype = GET_DTYPE()
    dtype_env = os.environ.get("DTYPE")
    log.info("GET_DTYPE() = %s (DTYPE env = %s)", resolved_dtype, dtype_env)

    log.info("Generating 3-second i2v clip...")
    clip = pipeline.generate_i2v(
        image_path=avatar_path,
        prompt="a person looking at the camera, natural lighting, soft focus background",
        seconds=3,
        seed=42,
    )
    log.info("Got frames: shape=%s dtype=%s", clip.shape, clip.dtype)

    # Encode to MP4 through a scratch file: the clip is written to a path and
    # then read back as bytes.
    import imageio.v3 as iio
    import tempfile

    # delete=False keeps the file on disk after the handle closes, so only
    # its name is needed here; the finally-clause below removes it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as scratch:
        scratch_path = scratch.name
    try:
        iio.imwrite(scratch_path, clip, fps=16, codec="libx264")
        with open(scratch_path, "rb") as encoded:
            video_bytes = encoded.read()
    finally:
        os.remove(scratch_path)

    saved_to = write_bytes("phase9_gguf.mp4", video_bytes)
    log.info(
        "PASS: video written to %s (%d bytes, %d frames)",
        saved_to,
        len(video_bytes),
        clip.shape[0],
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Smoke test: T5 text encoding under GGUF pipeline.
|
||||
|
||||
Builds the Wan22Pipeline (loads all weights including DIT) but only
|
||||
exercises the T5 encoder — no image-to-video generation. Validates that
|
||||
the DTYPE=FP16 ↔ BF16 patching lets T5 encode a prompt successfully.
|
||||
|
||||
Run:
|
||||
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
||||
python -m tests.component.test_10_t5_encode
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tests.component._common import get_logger
|
||||
|
||||
log = get_logger("test_10")
|
||||
|
||||
# Quantisation scheme for the DIT weights; overridable via the environment.
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")

# Any "gguf-*" scheme selects the GGUF config/repo pair; every other value
# uses the fp8 distill pair.
CONFIG_JSON, DIT_REPO = (
    (
        "/app/configs/lightx2v/wan22_i2v_gguf_distill.json",
        "QuantStack/Wan2.2-I2V-A14B-GGUF",
    )
    if DIT_QUANT.startswith("gguf-")
    else (
        "/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
        "lightx2v/Wan2.2-Distill-Models",
    )
)
|
||||
|
||||
|
||||
def run():
    """Build the pipeline and exercise only its T5 text encoder.

    If ``Wan22Pipeline`` cannot be imported, the error is logged and the
    process exits with status 0. After encoding one prompt, every tensor
    attribute found on the input-info object is logged. When ``DIT_QUANT``
    is a ``gguf-*`` scheme, ``GET_DTYPE()`` must equal ``torch.float16``
    afterwards; otherwise the process exits with status 1.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Import failed: %s", e)
        sys.exit(0)

    log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT)
    pipe = Wan22Pipeline(
        base_repo="Wan-AI/Wan2.2-I2V-A14B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2_moe_distill",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")

    # Check DTYPE state after init
    from lightx2v.utils.envs import GET_DTYPE

    log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))

    # Run only the T5 text encoder
    runner = pipe._runner
    prompt = "a person looking at the camera, natural lighting"

    log.info("Running T5 text encoder on prompt: %r", prompt)
    import copy

    # Work on a copy so the pipeline's own template object is not mutated.
    input_info = copy.deepcopy(pipe._input_info_template)
    input_info.prompt = prompt

    runner.run_text_encoder(input_info)
    log.info("T5 encode complete.")

    # Inspect output — log every instance attribute that holds a tensor.
    # torch is imported once here and reused by the DTYPE check below.
    import torch

    for attr, val in vars(input_info).items():
        if isinstance(val, torch.Tensor):
            log.info("  %s: shape=%s dtype=%s device=%s", attr, val.shape, val.dtype, val.device)

    # Verify DTYPE is back to FP16 after T5 runs (if GGUF)
    if DIT_QUANT.startswith("gguf-"):
        current = GET_DTYPE()
        expected = torch.float16
        if current != expected:
            log.error("FAIL: DTYPE is %s after T5, expected %s", current, expected)
            sys.exit(1)
        log.info("PASS: DTYPE correctly restored to FP16 after T5 encode.")

    log.info("PASS: T5 encoding succeeded under %s pipeline.", DIT_QUANT)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
@@ -105,7 +105,8 @@ def test_models_section_override():
|
||||
{
|
||||
"models": {
|
||||
"wan22_base_repo": "/local/weights/wan22",
|
||||
"wan22_fp8_repo": "/local/weights/wan22-fp8",
|
||||
"wan22_dit_repo": "/local/weights/wan22-dit",
|
||||
"wan22_dit_quant_scheme": "gguf-Q4_K_M",
|
||||
"wan22_config_json": "/local/cfg/fp8.json",
|
||||
"wan22_model_cls": "wan2.2_moe",
|
||||
"musetalk_path": "/local/weights/musetalk",
|
||||
@@ -113,7 +114,20 @@ def test_models_section_override():
|
||||
}
|
||||
)
|
||||
assert cfg.wan22_base_repo == "/local/weights/wan22"
|
||||
assert cfg.wan22_fp8_repo == "/local/weights/wan22-fp8"
|
||||
assert cfg.wan22_dit_repo == "/local/weights/wan22-dit"
|
||||
assert cfg.wan22_dit_quant_scheme == "gguf-Q4_K_M"
|
||||
assert cfg.wan22_config_json == "/local/cfg/fp8.json"
|
||||
assert cfg.wan22_model_cls == "wan2.2_moe"
|
||||
assert cfg.musetalk_model_path == "/local/weights/musetalk"
|
||||
|
||||
|
||||
def test_models_section_backwards_compat_fp8_repo():
    """A config carrying only ``wan22_fp8_repo`` populates ``wan22_dit_repo``."""
    legacy_repo = "/local/weights/wan22-fp8"
    raw_config = {"models": {"wan22_fp8_repo": legacy_repo}}

    cfg = VideoConfig.from_dict(raw_config)

    assert cfg.wan22_dit_repo == legacy_repo
|
||||
|
||||
Reference in New Issue
Block a user