t5 encoder fp8 seems to be working
This commit is contained in:
@@ -0,0 +1,88 @@
|
||||
"""Smoke test: T5 text encoding under GGUF pipeline.

Builds the Wan22Pipeline (loads all weights including DIT) but only
exercises the T5 encoder — no image-to-video generation. Validates that
the DTYPE=FP16 ↔ BF16 patching lets T5 encode a prompt successfully.

Run:
    docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
        python -m tests.component.test_10_t5_encode
"""
|
||||
from __future__ import annotations

import copy
import os
import sys

from tests.component._common import get_logger
|
||||
|
||||
# Module-level logger for this test, created by the helper imported from
# tests.component._common.
log = get_logger("test_10")
|
||||
|
||||
# Quantization scheme under test; overridable through the DIT_QUANT
# environment variable, defaulting to the GGUF Q4_K_M scheme.
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")

# Select the config file and DIT repo that match the scheme: any "gguf-*"
# value uses the GGUF distill settings, everything else the FP8 distill ones.
if DIT_QUANT.startswith("gguf-"):
    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
    DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
    CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
    DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
|
||||
|
||||
|
||||
def run() -> None:
    """Build the Wan22 pipeline and exercise only its T5 text encoder.

    Constructs ``Wan22Pipeline`` with the module-level ``DIT_QUANT``,
    ``DIT_REPO`` and ``CONFIG_JSON`` settings, runs the runner's text
    encoder on a fixed prompt, logs every tensor attribute found on the
    input-info object, and, for ``gguf-*`` schemes, checks that
    ``GET_DTYPE()`` equals ``torch.float16`` afterwards.

    Exits the process with status 0 when the pipeline module cannot be
    imported, and with status 1 when the GGUF dtype check fails.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Import failed: %s", e)
        # NOTE(review): status 0 reports success even though nothing was
        # tested — presumably a deliberate "skip"; confirm with the harness.
        sys.exit(0)

    log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT)
    pipe = Wan22Pipeline(
        base_repo="Wan-AI/Wan2.2-I2V-A14B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2_moe_distill",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")

    # Check DTYPE state after init. Imported here, after the pipeline is
    # built, to keep the original import order.
    from lightx2v.utils.envs import GET_DTYPE

    log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))

    # Run only the T5 text encoder
    runner = pipe._runner
    prompt = "a person looking at the camera, natural lighting"

    log.info("Running T5 text encoder on prompt: %r", prompt)
    # Work on a deep copy so the pipeline's template object is not mutated.
    input_info = copy.deepcopy(pipe._input_info_template)
    input_info.prompt = prompt

    runner.run_text_encoder(input_info)
    log.info("T5 encode complete.")

    # Inspect output — check all instance attributes for tensor results.
    # torch is imported once here and reused by the dtype check below.
    import torch

    for attr in vars(input_info):
        val = getattr(input_info, attr)
        if isinstance(val, torch.Tensor):
            log.info("  %s: shape=%s dtype=%s device=%s", attr, val.shape, val.dtype, val.device)

    # Verify DTYPE is back to FP16 after T5 runs (if GGUF)
    if DIT_QUANT.startswith("gguf-"):
        current = GET_DTYPE()
        expected = torch.float16
        if current == expected:
            log.info("PASS: DTYPE correctly restored to FP16 after T5 encode.")
        else:
            log.error("FAIL: DTYPE is %s after T5, expected %s", current, expected)
            sys.exit(1)

    log.info("PASS: T5 encoding succeeded under %s pipeline.", DIT_QUANT)
|
||||
|
||||
|
||||
# Script entry point: run the smoke test only when this module is executed
# directly (e.g. via ``python -m``), not when it is imported.
if __name__ == "__main__":
    run()
|
||||
Reference in New Issue
Block a user