Files
live-voice-chat/tests/component/test_10_t5_encode.py
2026-04-16 10:00:37 -04:00

84 lines
2.6 KiB
Python

"""Smoke test: T5 text encoding under GGUF pipeline.
Builds the Wan22Pipeline (loads all weights including DIT) but only
exercises the T5 encoder — no image-to-video generation. Validates that
the DTYPE=FP16 ↔ BF16 patching lets T5 encode a prompt successfully.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_10_t5_encode
"""
from __future__ import annotations
import os
import sys
from tests.component._common import get_logger
# Logger used by every step of this smoke test.
log = get_logger("test_10")
# DiT quantization scheme. Overridable through the DIT_QUANT environment
# variable; falls back to "gguf-Q8_0". run() forwards it to Wan22Pipeline as
# dit_quant_scheme and uses the "gguf-" prefix to decide whether the final
# dtype check applies.
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
# Config file path handed to Wan22Pipeline as config_json.
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
# Repository identifier handed to Wan22Pipeline as dit_repo.
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
    """Build the Wan2.2 pipeline and smoke-test only its T5 text encoder.

    Steps:
      1. Import and construct ``Wan22Pipeline`` from the module-level
         ``DIT_REPO`` / ``CONFIG_JSON`` / ``DIT_QUANT`` settings.
      2. Log ``GET_DTYPE()`` alongside the ``DTYPE`` environment variable.
      3. Deep-copy the pipeline's input-info template, set the prompt on the
         copy, and pass it to ``runner.run_text_encoder``.
      4. Log shape/dtype/device of every tensor-valued attribute found on
         that input-info object.
      5. When the quant scheme starts with ``gguf-``, require ``GET_DTYPE()``
         to equal ``torch.float16`` after the encode.

    Returns:
        None on success.

    Raises:
        SystemExit: with status 1 when the pipeline module cannot be
            imported, or when the post-encode dtype check fails.
    """
    import copy

    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Import failed: %s", e)
        # A smoke test that cannot even import its subject has not passed;
        # exit non-zero so the failure is visible to whoever invoked it.
        sys.exit(1)

    log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT)
    pipe = Wan22Pipeline(
        base_repo="Wan-AI/Wan2.2-TI2V-5B",
        dit_repo=DIT_REPO,
        config_json=CONFIG_JSON,
        model_cls="wan2.2",
        resolution=480,
        fps=16,
        dit_quant_scheme=DIT_QUANT,
        t5_quantized=True,
    )
    log.info("Pipeline ready.")

    # Check DTYPE state after init. Imported only after the pipeline is built,
    # matching the original ordering.
    from lightx2v.utils.envs import GET_DTYPE
    log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))

    # Run only the T5 text encoder.
    runner = pipe._runner
    prompt = "a person looking at the camera, natural lighting"
    log.info("Running T5 text encoder on prompt: %r", prompt)
    # Work on a copy so the pipeline's own template is left untouched.
    input_info = copy.deepcopy(pipe._input_info_template)
    input_info.prompt = prompt
    runner.run_text_encoder(input_info)
    log.info("T5 encode complete.")

    # Inspect output — report every attribute that holds a tensor.
    import torch
    for attr, val in vars(input_info).items():
        if isinstance(val, torch.Tensor):
            log.info(" %s: shape=%s dtype=%s device=%s", attr, val.shape, val.dtype, val.device)

    # Verify DTYPE is back to FP16 after T5 runs (GGUF schemes only).
    if DIT_QUANT.startswith("gguf-"):
        current = GET_DTYPE()
        expected = torch.float16
        if current == expected:
            log.info("PASS: DTYPE correctly restored to FP16 after T5 encode.")
        else:
            log.error("FAIL: DTYPE is %s after T5, expected %s", current, expected)
            sys.exit(1)

    log.info("PASS: T5 encoding succeeded under %s pipeline.", DIT_QUANT)
# Script entry point: execute the smoke test when the module is run directly
# (e.g. via `python -m tests.component.test_10_t5_encode`).
if __name__ == "__main__":
    run()