Files
live-voice-chat/server/video_models/wan22.py
T
2026-04-12 16:38:44 -04:00

686 lines
27 KiB
Python

"""Wan2.2-Lightning image-to-video wrapper via LightX2V.
This wrapper targets LightX2V's actual Python entry points (verified against
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
from lightx2v.utils.set_config import set_config
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.infer import init_runner
args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...)
config = set_config(args)
input_info = init_empty_input_info(args.task, args.support_tasks)
runner = init_runner(config) # loads all weights — done ONCE
update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...})
runner.run_pipeline(input_info) # per-turn; MP4 written to save_result_path
# LoRA hot-swap:
runner.switch_lora(lora_path, strength) # swap in
runner.switch_lora("", 0.0) # remove
Model weights are loaded once at construction and held resident across turns
so reflective mode doesn't re-pay the load cost each reply.
Two HuggingFace repos are always consumed on first run (cached under HF_HOME);
a third, lightx2v/Encoders (fp8 T5 encoder), is fetched when t5_quantized is set:
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
The bf16 DIT shards under high_noise_model/
and low_noise_model/ are SKIPPED via
ignore_patterns — we replace them with
quantised checkpoints from dit_repo.
- dit_repo (configurable) — quantised DIT checkpoints. Supported
formats:
* fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
* GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import random
import tempfile
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from server.video import LoRASpec
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Quantised DIT checkpoint names
# ---------------------------------------------------------------------------
# fp8 distill pair (lightx2v/Wan2.2-Distill-Models layout).
FP8_HIGH_NOISE_FILE = (
    "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
)
FP8_LOW_NOISE_FILE = (
    "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
)

# GGUF pair (QuantStack layout). Fill ``{quant}`` with the quant level,
# e.g. "Q4_K_M".
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"

# ---------------------------------------------------------------------------
# Optional fp8 T5 text encoder
# ---------------------------------------------------------------------------
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"

# ---------------------------------------------------------------------------
# Base-repo download filter
# ---------------------------------------------------------------------------
# The Wan-AI base repo bundles bf16 DIT shards (~28 GB) together with the
# T5/VAE/tokenizer support files (~12 GB). Only the support files are needed
# because the quantised dit_repo checkpoints stand in for the DIT weights.
# The config.json / index.json metadata under high_noise_model/ and
# low_noise_model/ must survive (LightX2V's set_config reads architecture
# params such as ``dim`` from them), as must the tokenizer under google/.
BASE_REPO_IGNORE_PATTERNS = [
    # bf16 DIT weight shards, superseded by the dit_repo checkpoints.
    "high_noise_model/*.safetensors",
    "low_noise_model/*.safetensors",
    # Other files this wrapper does not need.
    "assets/*",
    "examples/*",
    "nohup.out",
    "*.md",
]
def _patch_fp8_scaled_mm_for_blackwell() -> None:
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
sgl_kernel's CUTLASS-based fp8 GEMM doesn't ship SM120 kernels yet.
PyTorch 2.8+'s native ``_scaled_mm`` works on all architectures
including Blackwell. This patch is idempotent.
"""
try:
import sgl_kernel # type: ignore[import-not-found]
except ImportError:
return # no sgl_kernel → fp8 T5 not in use
if getattr(sgl_kernel, "_fp8_patched_for_blackwell", False):
return
import torch
if not torch.cuda.is_available():
return
cap = torch.cuda.get_device_capability()
if cap[0] < 12:
return # only patch on Blackwell+
_orig = sgl_kernel.fp8_scaled_mm
def _torch_fp8_scaled_mm(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# torch._scaled_mm expects (M,K) @ (N,K).t() with:
# scale_a: scalar or (M,1)
# scale_b: scalar or (1,N)
# sgl_kernel provides scale_b as (N,1) — transpose it.
if b_scale.dim() == 2 and b_scale.shape[1] == 1:
b_scale = b_scale.t()
# _scaled_mm requires B to be column-major (stride(0)==1).
bt = b.t().contiguous().t()
out = torch._scaled_mm(
a, bt,
scale_a=a_scale, scale_b=b_scale,
out_dtype=out_dtype, bias=bias,
)
return out
sgl_kernel.fp8_scaled_mm = _torch_fp8_scaled_mm
sgl_kernel._fp8_patched_for_blackwell = True
log.info("Patched sgl_kernel.fp8_scaled_mm → torch._scaled_mm for Blackwell (SM%d%d).", *cap)
class Wan22Pipeline:
"""Wrapper around LightX2V's Wan2.2 MoE distill runner.
Supports two DIT quantisation formats:
* **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
``lightx2v/Wan2.2-Distill-Models``)
* **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)
Constructor downloads (if needed) both HF repos, writes a runtime JSON
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
``generate_i2v`` runs one inference turn against the already-loaded runner.
"""
def __init__(
self,
base_repo: str,
dit_repo: str,
config_json: str,
model_cls: str = "wan2.2_moe_distill",
resolution: int = 480,
fps: int = 16,
dit_quant_scheme: str = "fp8-sgl",
t5_quantized: bool = False,
):
self.base_repo = base_repo
self.dit_repo = dit_repo
self.config_json_template = config_json
self.model_cls = model_cls
self.resolution = resolution
self.fps = fps
self.dit_quant_scheme = dit_quant_scheme
self.t5_quantized = t5_quantized
self._applied_loras: list[LoRASpec] = []
self._is_gguf = dit_quant_scheme.startswith("gguf-")
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
self._model_root = self._ensure_base_repo(base_repo)
self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
dit_repo, dit_quant_scheme,
)
self._t5_fp8_ckpt = (
self._ensure_t5_fp8() if t5_quantized else None
)
# 2. Materialize a runtime JSON config with absolute ckpt paths.
self._runtime_json_path = self._build_runtime_config()
# 3. Build the argparse-like namespace LightX2V.set_config() expects.
args = self._build_args(
model_cls=model_cls,
model_path=self._model_root,
config_json=self._runtime_json_path,
)
# 4. Import LightX2V (scoped here so ``import server.video_models.wan22``
# never pulls in lightx2v — tests can import this module on CPU).
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
from lightx2v.infer import init_runner # type: ignore[import-not-found]
_patch_fp8_scaled_mm_for_blackwell()
# 5. Load all models under default DTYPE=BF16 so T5 (which is
# hardcoded to bf16 weights) initialises its offload buffers
# correctly. We flip to FP16 *after* init_runner completes.
log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
model_cls, self._model_root)
self._config = set_config(args)
self._input_info_template = init_empty_input_info(
args.task, args.support_tasks
)
log.info("LightX2V init_runner — loading weights (this takes a while)...")
self._runner = init_runner(self._config)
log.info("LightX2V runner loaded; weights resident.")
# 6. GGUF: switch global DTYPE to FP16 for inference. GGUF DIT
# dequantises to fp16, and many intermediate tensors inside the
# DIT forward pass are allocated via GET_DTYPE(). The T5 encoder
# is wrapped to temporarily restore BF16 during its forward.
if self._is_gguf:
os.environ["DTYPE"] = "FP16"
from lightx2v.utils.envs import GET_DTYPE # type: ignore[import-not-found]
GET_DTYPE.cache_clear()
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
self._patch_t5_dtype_for_gguf()
self._patch_vae_dtype_for_gguf()
self._patch_dit_fp32_weights_for_gguf()
# --- GGUF dtype compatibility patch ----------------------------------------
def _patch_t5_dtype_for_gguf(self) -> None:
"""Wrap the T5 encoder so it temporarily restores DTYPE=BF16.
The T5 encoder is hardcoded to bfloat16 weights (wan_runner.py). When
the global DTYPE is FP16 (required for GGUF DIT), the T5's CPU-offload
path breaks because intermediate tensor dtypes no longer match the bf16
weights. We wrap ``run_text_encoder`` to temporarily flip GET_DTYPE()
back to bf16, then restore fp16 before the DIT runs.
"""
import os
import types
from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE # type: ignore[import-not-found]
runner = self._runner
orig_run_text_encoder = runner.run_text_encoder.__func__
def bf16_text_encoder(self_runner, *args, **kwargs):
import torch
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
os.environ["DTYPE"] = "BF16"
GET_DTYPE.cache_clear()
GET_SENSITIVE_DTYPE.cache_clear()
try:
result = orig_run_text_encoder(self_runner, *args, **kwargs)
finally:
# Restore FP16 for the DIT / rest of the pipeline.
os.environ["DTYPE"] = "FP16"
GET_DTYPE.cache_clear()
GET_SENSITIVE_DTYPE.cache_clear()
# Cast bf16 T5 outputs to fp16 so they match the GGUF DIT dtype.
def _to_fp16(x):
if isinstance(x, torch.Tensor) and x.dtype == torch.bfloat16:
return x.to(torch.float16)
if isinstance(x, list):
return [_to_fp16(v) for v in x]
if isinstance(x, tuple):
return tuple(_to_fp16(v) for v in x)
if isinstance(x, dict):
return {k: _to_fp16(v) for k, v in x.items()}
return x
return _to_fp16(result)
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
def _patch_vae_dtype_for_gguf(self) -> None:
"""Cast VAE encoder/decoder weights to fp16 to match GGUF DIT dtype.
The VAE weights load as bf16 (the default). Under GGUF the DIT runs in
fp16 and the runner casts VAE inputs via ``.to(GET_DTYPE())`` — which
under DTYPE=FP16 collides with bf16 VAE weights in Conv3d. Since the
VAE is a plain float model (not quantized), simply converting its
weights to fp16 avoids both input-vs-weight mismatches and the need
for any runtime dtype juggling.
"""
import torch
runner = self._runner
for name in ("vae_encoder", "vae_decoder"):
mod = getattr(runner, name, None)
if mod is None:
continue
inner = getattr(mod, "model", mod)
if hasattr(inner, "to"):
inner.to(dtype=torch.float16)
# The outer WanVAE wrapper also holds mean/inv_std/scale tensors
# used by encode/decode (z = z/inv_std + mean). Cast them too, or
# the first op upcasts fp16 latents back to fp32/bf16.
for attr in ("mean", "inv_std"):
t = getattr(mod, attr, None)
if isinstance(t, torch.Tensor):
setattr(mod, attr, t.to(torch.float16))
scale = getattr(mod, "scale", None)
if isinstance(scale, list):
mod.scale = [
t.to(torch.float16) if isinstance(t, torch.Tensor) else t
for t in scale
]
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
def _patch_dit_fp32_weights_for_gguf(self) -> None:
"""Cast leftover fp32 DIT pre/post weights to fp16.
GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
of non-quantised weights (notably ``patch_embedding.pin_weight``) end
up loaded as fp32. That breaks the first conv in the DIT forward pass
(fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
runs uniformly in fp16.
"""
import torch
runner = self._runner
models = getattr(runner.model, "model", None)
if models is None:
return
if not isinstance(models, (list, tuple)):
models = [models]
n_cast = 0
for m in models:
for weights_attr in ("pre_weight", "post_weight"):
w = getattr(m, weights_attr, None)
if w is None:
continue
for sub_name in dir(w):
if sub_name.startswith("_"):
continue
try:
sub = getattr(w, sub_name)
except Exception:
continue
if sub is None:
continue
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
t = getattr(sub, t_name, None)
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
casted = t.to(torch.float16)
# Preserve pinned-memory status on pin_* tensors so
# move_attr_to_cuda's non-blocking H2D copy is safe.
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
try:
casted = casted.pin_memory()
except RuntimeError:
pass
setattr(sub, t_name, casted)
n_cast += 1
log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
# --- Weight provisioning -------------------------------------------------
@staticmethod
def _ensure_base_repo(base_repo: str) -> str:
"""Return a local directory containing the Wan2.2 base support files.
If ``base_repo`` is already a local directory, use it as-is. Otherwise
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
shards (they're replaced by the quantised files).
"""
if os.path.isdir(base_repo):
return base_repo
from huggingface_hub import snapshot_download
log.info("Downloading Wan2.2 base support files from %s "
"(skipping bf16 DIT shards)...", base_repo)
return snapshot_download(
repo_id=base_repo,
ignore_patterns=BASE_REPO_IGNORE_PATTERNS,
)
@staticmethod
def _ensure_dit_checkpoints(
dit_repo: str,
dit_quant_scheme: str,
) -> tuple[str, str]:
"""Return (high_noise_path, low_noise_path) for the DIT pair.
Supports both fp8 safetensors and GGUF formats.
"""
if not dit_repo:
raise ValueError("dit_repo must be a HF repo id or local directory.")
is_gguf = dit_quant_scheme.startswith("gguf-")
if is_gguf:
# Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M"
quant = dit_quant_scheme.replace("gguf-", "")
high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
else:
high_file = FP8_HIGH_NOISE_FILE
low_file = FP8_LOW_NOISE_FILE
# Local directory?
if os.path.isdir(dit_repo):
high = os.path.join(dit_repo, high_file)
low = os.path.join(dit_repo, low_file)
if not (os.path.isfile(high) and os.path.isfile(low)):
raise FileNotFoundError(
f"DIT checkpoints not found in {dit_repo}: expected "
f"{high_file} and {low_file}"
)
return high, low
# HuggingFace download.
from huggingface_hub import hf_hub_download
log.info("Downloading %s DIT checkpoints from %s ...",
dit_quant_scheme, dit_repo)
high = hf_hub_download(repo_id=dit_repo, filename=high_file)
low = hf_hub_download(repo_id=dit_repo, filename=low_file)
return high, low
@staticmethod
def _ensure_t5_fp8() -> str:
"""Download the fp8 T5 encoder from lightx2v/Encoders (if not cached).
Returns the local path to the safetensors file (~6 GB).
"""
from huggingface_hub import hf_hub_download
log.info("Downloading fp8 T5 encoder from %s ...", T5_FP8_REPO)
return hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
def _build_runtime_config(self) -> str:
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
with open(self.config_json_template, "r", encoding="utf-8") as f:
cfg = json.load(f)
# Drop editorial comments before passing to LightX2V.
cfg.pop("_comment", None)
cfg["high_noise_quantized_ckpt"] = self._dit_high
cfg["low_noise_quantized_ckpt"] = self._dit_low
cfg.setdefault("fps", self.fps)
# T5 fp8 quantization.
if self._t5_fp8_ckpt:
cfg["t5_quantized"] = True
cfg["t5_quant_scheme"] = "fp8-sgl"
cfg["t5_quantized_ckpt"] = self._t5_fp8_ckpt
tmp = tempfile.NamedTemporaryFile(
prefix="wan22_dit_", suffix=".json",
mode="w", delete=False, encoding="utf-8",
)
json.dump(cfg, tmp, indent=2)
tmp.close()
log.info("Runtime LightX2V config: %s", tmp.name)
return tmp.name
@staticmethod
def _build_args(
*, model_cls: str, model_path: str, config_json: str
) -> argparse.Namespace:
"""Mirror every field from ``lightx2v.infer.main``'s argparse so
``set_config`` finds the attributes it expects. We only customize the
model/task/path fields; everything else stays at the CLI defaults.
"""
return argparse.Namespace(
seed=42,
model_cls=model_cls,
task="i2v",
support_tasks=[],
model_path=model_path,
sf_model_path=None,
config_json=config_json,
use_prompt_enhancer=False,
prompt="",
negative_prompt="",
image_path="",
last_frame_path="",
audio_path="",
image_strength="1.0",
image_frame_idx="",
src_ref_images=None,
src_video=None,
src_mask=None,
src_pose_path=None,
src_face_path=None,
src_bg_path=None,
src_mask_path=None,
pose=None,
action_path=None,
action_ckpt=None,
save_result_path=None,
return_result_tensor=False,
target_shape=[],
target_video_length=81,
aspect_ratio="",
video_path=None,
sr_ratio=2.0,
)
# --- LoRA --------------------------------------------------------------
def load_loras(self, specs: list["LoRASpec"]) -> None:
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
to route the LoRA to the correct expert.
With ``lazy_load`` the DIT models are ``None`` at this point, so
runtime ``switch_lora`` is impossible. Instead we inject
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
the LoRAs are applied when the models materialise on first inference.
Without ``lazy_load`` (models already resident) we call
``switch_lora`` with explicit high/low keyword args.
"""
if not specs:
return
# Resolve every path up-front (may trigger HF download).
resolved: list[tuple["LoRASpec", str]] = []
for spec in specs:
local_path = self._resolve_lora_path(spec.path)
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
spec.name or spec.path, spec.weight, spec.target,
local_path)
resolved.append((spec, local_path))
lazy = self._config.get("lazy_load", False)
if lazy:
# Build the lora_configs list that LightX2V's lazy-load path
# reads inside MultiDistillModelStruct.infer().
lora_cfgs = []
for spec, local_path in resolved:
# LightX2V expects name "high_noise_model" / "low_noise_model"
cfg_name = {
"high_noise": "high_noise_model",
"low_noise": "low_noise_model",
}.get(spec.target)
if cfg_name is None:
raise ValueError(
f"LoRA target must be 'high_noise' or 'low_noise', "
f"got {spec.target!r}")
lora_cfgs.append({
"name": cfg_name,
"path": local_path,
"strength": spec.weight,
})
self._runner.set_config({
"lora_configs": lora_cfgs,
"lora_dynamic_apply": True,
})
else:
# Models are loaded — use runtime hot-swap.
high_path = high_strength = None
low_path = low_strength = None
for spec, local_path in resolved:
if spec.target == "high_noise":
high_path, high_strength = local_path, spec.weight
elif spec.target == "low_noise":
low_path, low_strength = local_path, spec.weight
else:
raise ValueError(
f"LoRA target must be 'high_noise' or 'low_noise', "
f"got {spec.target!r}")
kwargs: dict = {}
if high_path is not None:
kwargs["high_lora_path"] = high_path
kwargs["high_lora_strength"] = high_strength
if low_path is not None:
kwargs["low_lora_path"] = low_path
kwargs["low_lora_strength"] = low_strength
ok = self._runner.switch_lora(**kwargs)
if not ok:
raise RuntimeError(
"runner.switch_lora returned False. Check that your "
"LightX2V build supports runtime LoRA updates for "
f"{self.model_cls}.")
self._applied_loras = list(specs)
def unload_loras(self) -> None:
"""Remove all currently applied LoRAs."""
if not self._applied_loras:
return
lazy = self._config.get("lazy_load", False)
if lazy:
self._runner.set_config({
"lora_configs": None,
"lora_dynamic_apply": False,
})
# If models were materialised, drop them so the next inference
# recreates them without LoRAs.
model_struct = getattr(self._runner, "model", None)
if model_struct is not None and hasattr(model_struct, "model"):
for i in range(len(model_struct.model)):
model_struct.model[i] = None
else:
self._runner.switch_lora("", 0.0)
self._applied_loras = []
@staticmethod
def _resolve_lora_path(path: str) -> str:
"""Resolve a LoRA path. Supports:
- Absolute/relative local paths (returned as-is if the file exists)
- ``repo_id:filename`` HuggingFace references
"""
if os.path.isfile(path):
return path
if ":" in path and not path.startswith(("/", "./")):
repo_id, filename = path.split(":", 1)
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=repo_id, filename=filename)
return path
# --- Inference ---------------------------------------------------------
def generate_i2v(
self,
image_path: str,
prompt: str,
seconds: int,
seed: int | None = None,
negative_prompt: str = "",
) -> np.ndarray:
"""Run image-to-video inference and return decoded frames.
Returns ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB.
"""
if seed is None:
seed = random.randint(0, 2**31 - 1)
# Wan2.2 target_video_length is "frames including the conditioning
# frame", so N seconds → N*fps + 1.
target_frames = seconds * self.fps + 1
from lightx2v.utils.input_info import update_input_info_from_dict # type: ignore[import-not-found]
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
out_path = tf.name
try:
log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d%s",
prompt[:80], seconds, seed, out_path)
update_input_info_from_dict(
self._input_info_template,
{
"seed": seed,
"prompt": prompt,
"negative_prompt": negative_prompt,
"image_path": image_path,
"save_result_path": out_path,
"target_video_length": target_frames,
"return_result_tensor": False,
},
)
self._runner.run_pipeline(self._input_info_template)
return _read_mp4_to_frames(out_path)
finally:
try:
os.remove(out_path)
except OSError:
pass
# --- MP4 decoding helper ------------------------------------------------------
def _read_mp4_to_frames(path: str) -> np.ndarray:
"""Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``."""
try:
import imageio.v3 as iio # type: ignore[import-not-found]
frames = iio.imread(path, plugin="pyav")
arr = np.asarray(frames)
if arr.ndim == 3:
arr = arr[None, ...]
return arr.astype(np.uint8)
except Exception as e: # pragma: no cover - fallback path
log.warning("imageio decode failed (%s); falling back to cv2", e)
import cv2 # type: ignore[import-not-found]
cap = cv2.VideoCapture(path)
frames: list[np.ndarray] = []
while True:
ok, frame = cap.read()
if not ok:
break
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()
if not frames:
raise RuntimeError(f"Failed to decode any frames from {path}")
return np.stack(frames, axis=0).astype(np.uint8)