89 lines
2.7 KiB
Python
89 lines
2.7 KiB
Python
"""Smoke test: T5 text encoding under GGUF pipeline.
|
|
|
|
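Builds the Wan22Pipeline (loading all weights, including the DiT) but only
exercises the T5 text encoder — no image-to-video generation is run. Validates
that the DTYPE=FP16 ↔ BF16 patching lets T5 encode a prompt successfully and,
for GGUF DiT quantization (the default), that GET_DTYPE() reports FP16 again
after encoding.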
|
|
|
|
Run:
|
|
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
|
|
python -m tests.component.test_10_t5_encode
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import sys
|
|
|
|
from tests.component._common import get_logger
|
|
|
|
# Module-level logger for this smoke test, named "test_10" after the test file.
log = get_logger("test_10")
|
|
|
|
# DiT quantization scheme under test, overridable via the DIT_QUANT env var.
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")

# Select the lightx2v config file and DiT weight repo matching the scheme:
# any "gguf-*" scheme uses the GGUF build, every other value the FP8 distill.
CONFIG_JSON, DIT_REPO = (
    (
        "/app/configs/lightx2v/wan22_i2v_gguf_distill.json",
        "QuantStack/Wan2.2-I2V-A14B-GGUF",
    )
    if DIT_QUANT.startswith("gguf-")
    else (
        "/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
        "lightx2v/Wan2.2-Distill-Models",
    )
)
|
|
|
|
|
|
def _log_tensors(label, obj):
    """Log shape/dtype/device of every tensor found one level inside *obj*.

    *obj* may be a tensor itself, a ``dict`` (its values are checked), an
    object with a ``__dict__`` (its attributes are checked), or a slotted
    dataclass (its fields are checked -- ``vars()`` raises ``TypeError`` on
    objects without ``__dict__``). Anything else, including ``None``, is
    ignored. Nested containers are not descended into.

    Returns the number of tensors logged.
    """
    import dataclasses

    import torch

    if isinstance(obj, torch.Tensor):
        entries = {label: obj}
    elif isinstance(obj, dict):
        entries = {f"{label}[{key!r}]": val for key, val in obj.items()}
    elif hasattr(obj, "__dict__"):
        # Prefer vars(): it also sees attributes set dynamically after init.
        entries = {f"{label}.{name}": val for name, val in vars(obj).items()}
    elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        entries = {
            f"{label}.{fld.name}": getattr(obj, fld.name)
            for fld in dataclasses.fields(obj)
        }
    else:
        return 0

    count = 0
    for name, val in entries.items():
        if isinstance(val, torch.Tensor):
            log.info("  %s: shape=%s dtype=%s device=%s", name, val.shape, val.dtype, val.device)
            count += 1
    return count


def run():
    """Build the Wan2.2 pipeline and smoke-test only its T5 text encoder.

    Steps: construct ``Wan22Pipeline`` for ``DIT_QUANT`` (loads T5 and DiT
    weights), log the ``GET_DTYPE()`` / ``DTYPE`` env state, encode a fixed
    prompt via ``runner.run_text_encoder``, log any tensors produced, and --
    for ``gguf-*`` schemes -- require ``GET_DTYPE()`` to be ``torch.float16``
    after encoding.

    Exits with status 1 if the pipeline module cannot be imported or the GGUF
    dtype check fails; any other error propagates as an exception.
    """
    import copy

    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Import failed: %s", e)
        # Non-zero on purpose: a smoke test that cannot import its subject
        # has failed and must not report success to the caller/CI.
        sys.exit(1)

    log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT)
    pipe = Wan22Pipeline(
        base_repo="Wan-AI/Wan2.2-I2V-A14B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2_moe_distill",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")

    # Deferred imports, consistent with the pipeline import above.
    import torch
    from lightx2v.utils.envs import GET_DTYPE

    # Check DTYPE state after init
    log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))

    # Run only the T5 text encoder, on a private copy of the input template
    # so the pipeline's own template is left untouched.
    runner = pipe._runner
    prompt = "a person looking at the camera, natural lighting"
    log.info("Running T5 text encoder on prompt: %r", prompt)
    input_info = copy.deepcopy(pipe._input_info_template)
    input_info.prompt = prompt

    # NOTE(review): depending on the lightx2v version, run_text_encoder may
    # return the encoder output rather than storing it on input_info, so both
    # the return value and input_info are inspected below.
    output = runner.run_text_encoder(input_info)
    log.info("T5 encode complete.")

    found = _log_tensors("output", output) + _log_tensors("input_info", input_info)
    if not found:
        log.warning("No tensors found in the T5 output or on input_info.")

    # Verify DTYPE is back to FP16 after T5 runs (if GGUF)
    if DIT_QUANT.startswith("gguf-"):
        current = GET_DTYPE()
        expected = torch.float16
        if current != expected:
            log.error("FAIL: DTYPE is %s after T5, expected %s", current, expected)
            sys.exit(1)
        log.info("PASS: DTYPE correctly restored to FP16 after T5 encode.")

    log.info("PASS: T5 encoding succeeded under %s pipeline.", DIT_QUANT)
|
|
|
|
|
|
# Entry point when executed directly (e.g. `python -m tests.component.test_10_t5_encode`);
# importing this module does not run the smoke test.
if __name__ == "__main__":
    run()
|