first stab at adding video
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
"""Unit tests for the frame-length fitting helper in server.video_models.musetalk.
|
||||
|
||||
Pure-python: does not import MuseTalk itself.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from server.video_models.musetalk import _fit_frames_to_length, _ensure_uint8_rgb
|
||||
|
||||
|
||||
def _make_frames(t, h=2, w=2):
|
||||
return np.arange(t * h * w * 3, dtype=np.uint8).reshape(t, h, w, 3)
|
||||
|
||||
|
||||
def test_fit_frames_trim():
|
||||
frames = _make_frames(10)
|
||||
out = _fit_frames_to_length(frames, 4)
|
||||
assert out.shape == (4, 2, 2, 3)
|
||||
np.testing.assert_array_equal(out, frames[:4])
|
||||
|
||||
|
||||
def test_fit_frames_passthrough_when_equal():
|
||||
frames = _make_frames(5)
|
||||
out = _fit_frames_to_length(frames, 5)
|
||||
assert out is frames or np.array_equal(out, frames)
|
||||
|
||||
|
||||
def test_fit_frames_extends_with_pingpong():
|
||||
frames = _make_frames(3)
|
||||
out = _fit_frames_to_length(frames, 8)
|
||||
assert out.shape == (8, 2, 2, 3)
|
||||
# First 3 frames match the original
|
||||
np.testing.assert_array_equal(out[:3], frames)
|
||||
# Next 3 are the reverse (ping-pong)
|
||||
np.testing.assert_array_equal(out[3:6], frames[::-1])
|
||||
# Then forward again
|
||||
np.testing.assert_array_equal(out[6:8], frames[:2])
|
||||
|
||||
|
||||
def test_fit_frames_zero_target_returns_original():
|
||||
frames = _make_frames(3)
|
||||
out = _fit_frames_to_length(frames, 0)
|
||||
np.testing.assert_array_equal(out, frames)
|
||||
|
||||
|
||||
def test_ensure_uint8_rgb_from_float():
|
||||
arr = np.ones((5, 2, 2, 3), dtype=np.float32) * 0.5
|
||||
out = _ensure_uint8_rgb(arr)
|
||||
assert out.dtype == np.uint8
|
||||
assert out.shape == (5, 2, 2, 3)
|
||||
assert out[0, 0, 0, 0] == 127
|
||||
|
||||
|
||||
def test_ensure_uint8_rgb_promotes_3d_to_4d():
|
||||
arr = np.zeros((2, 2, 3), dtype=np.uint8)
|
||||
out = _ensure_uint8_rgb(arr)
|
||||
assert out.shape == (1, 2, 2, 3)
|
||||
|
||||
|
||||
def test_ensure_uint8_rgb_clips_float_out_of_range():
|
||||
arr = np.ones((1, 1, 1, 3), dtype=np.float32) * 2.0 # 2.0 → clipped to 255
|
||||
out = _ensure_uint8_rgb(arr)
|
||||
assert out[0, 0, 0, 0] == 255
|
||||
arr2 = np.ones((1, 1, 1, 3), dtype=np.float32) * -1.0
|
||||
out2 = _ensure_uint8_rgb(arr2)
|
||||
assert out2[0, 0, 0, 0] == 0
|
||||
Reference in New Issue
Block a user