This commit is contained in:
aidenyzhang
2025-03-28 16:03:02 +08:00
parent 058f7ddc7f
commit db204311a5
46 changed files with 729 additions and 204 deletions
+5 -1
View File
@@ -31,12 +31,16 @@ class UNet():
unet_config,
model_path,
use_float16=False,
device=None
):
with open(unet_config, 'r') as f:
unet_config = json.load(f)
self.model = UNet2DConditionModel(**unet_config)
self.pe = PositionalEncoding(d_model=384)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device != None:
self.device = device
else:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(weights)
if use_float16:
Regular → Executable
View File
+99
View File
@@ -0,0 +1,99 @@
import os
import math
import librosa
import numpy as np
import torch
from einops import rearrange
from transformers import AutoFeatureExtractor
class AudioProcessor:
def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
def get_audio_feature(self, wav_path, start_index=0):
if not os.path.exists(wav_path):
return None
librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
assert sampling_rate == 16000
# Split audio into 30s segments
segment_length = 30 * sampling_rate
segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
features = []
for segment in segments:
audio_feature = self.feature_extractor(
segment,
return_tensors="pt",
sampling_rate=sampling_rate
).input_features
features.append(audio_feature)
return features, len(librosa_output)
def get_whisper_chunk(
self,
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=25,
audio_padding_length_left=2,
audio_padding_length_right=2,
):
audio_feature_length_per_frame = 2 * (audio_padding_length_left + audio_padding_length_right + 1)
whisper_feature = []
# Process multiple 30s mel input features
for input_feature in whisper_input_features:
audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
whisper_feature.append(audio_feats)
whisper_feature = torch.cat(whisper_feature, dim=1)
# Trim the last segment to remove padding
sr = 16000
audio_fps = 50
fps = int(fps)
whisper_idx_multiplier = audio_fps / fps
num_frames = math.floor((librosa_length / sr)) * fps
actual_length = math.floor((librosa_length / sr)) * audio_fps
whisper_feature = whisper_feature[:,:actual_length,...]
# Calculate padding amount
padding_nums = math.floor(whisper_idx_multiplier)
# Add padding at start and end
whisper_feature = torch.cat([
torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
whisper_feature,
# Add extra padding to prevent out of bounds
torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
], 1)
audio_prompts = []
for frame_index in range(num_frames):
try:
audio_index = math.floor(frame_index * whisper_idx_multiplier)
audio_clip = whisper_feature[:, audio_index: audio_index + audio_feature_length_per_frame]
assert audio_clip.shape[1] == audio_feature_length_per_frame
audio_prompts.append(audio_clip)
except Exception as e:
print(f"Error occurred: {e}")
print(f"whisper_feature.shape: {whisper_feature.shape}")
print(f"audio_clip.shape: {audio_clip.shape}")
print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
exit()
audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
return audio_prompts
if __name__ == "__main__":
audio_processor = AudioProcessor()
wav_path = "/cfs-workspace/users/gozhong/codes/musetalk_opensource2/data/audio/2.wav"
audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
print("Audio Feature shape:", audio_feature.shape)
print("librosa_feature_length:", librosa_feature_length)
Regular → Executable
+79 -44
View File
@@ -2,9 +2,6 @@ from PIL import Image
import numpy as np
import cv2
import copy
from face_parsing import FaceParsing
fp = FaceParsing()
def get_crop_box(box, expand):
x, y, x1, y1 = box
@@ -14,46 +11,98 @@ def get_crop_box(box, expand):
crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
return crop_box, s
def face_seg(image):
seg_image = fp(image)
def face_seg(image, mode="jaw", fp=None):
"""
对图像进行面部解析,生成面部区域的掩码。
Args:
image (PIL.Image): 输入图像。
Returns:
PIL.Image: 面部区域的掩码图像。
"""
seg_image = fp(image, mode=mode) # 使用 FaceParsing 模型解析面部
if seg_image is None:
print("error, no person_segment")
print("error, no person_segment") # 如果没有检测到面部,返回错误
return None
seg_image = seg_image.resize(image.size)
seg_image = seg_image.resize(image.size) # 将掩码图像调整为输入图像的大小
return seg_image
def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
#print(image.shape)
#print(face.shape)
def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode="raw", fp=None):
"""
将裁剪的面部图像粘贴回原始图像,并进行一些处理。
Args:
image (numpy.ndarray): 原始图像(身体部分)。
face (numpy.ndarray): 裁剪的面部图像。
face_box (tuple): 面部边界框的坐标 (x, y, x1, y1)。
upper_boundary_ratio (float): 用于控制面部区域的保留比例。
expand (float): 扩展因子,用于放大裁剪框。
mode: 融合mask构建方式
Returns:
numpy.ndarray: 处理后的图像。
"""
# 将 numpy 数组转换为 PIL 图像
body = Image.fromarray(image[:, :, ::-1]) # 身体部分图像(整张图)
face = Image.fromarray(face[:, :, ::-1]) # 面部图像
x, y, x1, y1 = face_box # 获取面部边界框的坐标
crop_box, s = get_crop_box(face_box, expand) # 计算扩展后的裁剪框
x_s, y_s, x_e, y_e = crop_box # 裁剪框的坐标
face_position = (x, y) # 面部在原始图像中的位置
# 从身体图像中裁剪出扩展后的面部区域(下巴到边界有距离)
face_large = body.crop(crop_box)
ori_shape = face_large.size # 裁剪后图像的原始尺寸
# 对裁剪后的面部区域进行面部解析,生成掩码
mask_image = face_seg(face_large, mode=mode, fp=fp)
mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 裁剪出面部区域的掩码
mask_image = Image.new('L', ori_shape, 0) # 创建一个全黑的掩码图像
mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 将面部掩码粘贴到全黑图像上
# 保留面部区域的上半部分(用于控制说话区域)
width, height = mask_image.size
top_boundary = int(height * upper_boundary_ratio) # 计算上半部分的边界
modified_mask_image = Image.new('L', ori_shape, 0) # 创建一个新的全黑掩码图像
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) # 粘贴上半部分掩码
# 对掩码进行高斯模糊,使边缘更平滑
blur_kernel_size = int(0.05 * ori_shape[0] // 2 * 2) + 1 # 计算模糊核大小
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) # 高斯模糊
#mask_array = np.array(modified_mask_image)
mask_image = Image.fromarray(mask_array) # 将模糊后的掩码转换回 PIL 图像
# 将裁剪的面部图像粘贴回扩展后的面部区域
face_large.paste(face, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
body.paste(face_large, crop_box[:2], mask_image)
# 不用掩码,完全用infer
#face_large.save("debug/checkpoint_6_face_large.png")
body = np.array(body) # 将 PIL 图像转换回 numpy 数组
return body[:, :, ::-1] # 返回处理后的图像(BGR 转 RGB)
def get_image_blending(image,face,face_box,mask_array,crop_box):
body = Image.fromarray(image[:,:,::-1])
face = Image.fromarray(face[:,:,::-1])
x, y, x1, y1 = face_box
#print(x1-x,y1-y)
crop_box, s = get_crop_box(face_box, expand)
x, y, x1, y1 = face_box
x_s, y_s, x_e, y_e = crop_box
face_position = (x, y)
face_large = body.crop(crop_box)
ori_shape = face_large.size
mask_image = face_seg(face_large)
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
mask_image = Image.new('L', ori_shape, 0)
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
# keep upper_boundary_ratio of talking area
width, height = mask_image.size
top_boundary = int(height * upper_boundary_ratio)
modified_mask_image = Image.new('L', ori_shape, 0)
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
mask_image = Image.fromarray(mask_array)
mask_image = mask_image.convert("L")
face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
body.paste(face_large, crop_box[:2], mask_image)
body = np.array(body)
@@ -84,17 +133,3 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
return mask_array,crop_box
def get_image_blending(image,face,face_box,mask_array,crop_box):
body = image
x, y, x1, y1 = face_box
x_s, y_s, x_e, y_e = crop_box
face_large = copy.deepcopy(body[y_s:y_e, x_s:x_e])
face_large[y-y_s:y1-y_s, x-x_s:x1-x_s]=face
mask_image = cv2.cvtColor(mask_array,cv2.COLOR_BGR2GRAY)
mask_image = (mask_image/255).astype(np.float32)
body[y_s:y_e, x_s:x_e] = cv2.blendLinear(face_large,body[y_s:y_e, x_s:x_e],mask_image,1-mask_image)
return body
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
+65 -4
View File
@@ -8,9 +8,53 @@ from .model import BiSeNet
import torchvision.transforms as transforms
class FaceParsing():
def __init__(self):
def __init__(self, left_cheek_width=80, right_cheek_width=80):
self.net = self.model_init()
self.preprocess = self.image_preprocess()
# Ensure all size parameters are integers
cone_height = 21
tail_height = 12
total_size = cone_height + tail_height
# Create kernel with explicit integer dimensions
kernel = np.zeros((total_size, total_size), dtype=np.uint8)
center_x = total_size // 2 # Ensure center coordinates are integers
# Cone part
for row in range(cone_height):
if row < cone_height//2:
continue
width = int(2 * (row - cone_height//2) + 1)
start = int(center_x - (width // 2))
end = int(center_x + (width // 2) + 1)
kernel[row, start:end] = 1
# Vertical extension part
if cone_height > 0:
base_width = int(kernel[cone_height-1].sum())
else:
base_width = 1
for row in range(cone_height, total_size):
start = max(0, int(center_x - (base_width//2)))
end = min(total_size, int(center_x + (base_width//2) + 1))
kernel[row, start:end] = 1
self.kernel = kernel
# Modify cheek erosion kernel to be flatter ellipse
self.cheek_kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, (35, 3))
# Add cheek area mask (protect chin area)
self.cheek_mask = self._create_cheek_mask(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80):
"""Create cheek area mask (1/4 area on both sides)"""
mask = np.zeros((512, 512), dtype=np.uint8)
center = 512 // 2
cv2.rectangle(mask, (0, 0), (center - left_cheek_width, 512), 255, -1) # Left cheek
cv2.rectangle(mask, (center + right_cheek_width, 0), (512, 512), 255, -1) # Right cheek
return mask
def model_init(self,
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
@@ -30,7 +74,7 @@ class FaceParsing():
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def __call__(self, image, size=(512, 512)):
def __call__(self, image, size=(512, 512), mode="jaw"):
if isinstance(image, str):
image = Image.open(image)
@@ -44,8 +88,25 @@ class FaceParsing():
img = torch.unsqueeze(img, 0)
out = self.net(img)[0]
parsing = out.squeeze(0).cpu().numpy().argmax(0)
parsing[np.where(parsing>13)] = 0
parsing[np.where(parsing>=1)] = 255
# Add 14:neck, remove 10:nose and 7:8:9
if mode == "neck":
parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
parsing[np.where(parsing!=255)] = 0
elif mode == "jaw":
face_region = np.isin(parsing, [1])*255
face_region = face_region.astype(np.uint8)
original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
face_region = cv2.bitwise_and(eroded, self.cheek_mask)
face_region = cv2.bitwise_or(face_region, cv2.bitwise_and(original_dilated, ~self.cheek_mask))
parsing[(face_region==255) & (~np.isin(parsing, [10]))] = 255
parsing[np.isin(parsing, [11, 12, 13])] = 255
parsing[np.where(parsing!=255)] = 0
else:
parsing[np.isin(parsing, [1, 11, 12, 13])] = 255
parsing[np.where(parsing!=255)] = 0
parsing = Image.fromarray(parsing.astype(np.uint8))
return parsing
Regular → Executable
View File
Regular → Executable
+28 -14
View File
@@ -15,13 +15,24 @@ from musetalk.whisper.audio2feature import Audio2Feature
from musetalk.models.vae import VAE
from musetalk.models.unet import UNet,PositionalEncoding
def load_all_model():
audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
vae = VAE(model_path = "./models/sd-vae-ft-mse/")
unet = UNet(unet_config="./models/musetalk/musetalk.json",
model_path ="./models/musetalk/pytorch_model.bin")
def load_all_model(
unet_model_path="./models/musetalk/pytorch_model.bin",
vae_type="sd-vae-ft-mse",
unet_config="./models/musetalk/musetalk.json",
device=None,
):
vae = VAE(
model_path = f"./models/{vae_type}/",
)
print(f"load unet model from {unet_model_path}")
unet = UNet(
unet_config=unet_config,
model_path=unet_model_path,
device=device
)
pe = PositionalEncoding(d_model=384)
return audio_processor,vae,unet,pe
return vae, unet, pe
def get_file_type(video_path):
_, ext = os.path.splitext(video_path)
@@ -39,10 +50,13 @@ def get_video_fps(video_path):
video.release()
return fps
def datagen(whisper_chunks,
vae_encode_latents,
batch_size=8,
delay_frame=0):
def datagen(
whisper_chunks,
vae_encode_latents,
batch_size=8,
delay_frame=0,
device="cuda:0",
):
whisper_batch, latent_batch = [], []
for i, w in enumerate(whisper_chunks):
idx = (i+delay_frame)%len(vae_encode_latents)
@@ -51,14 +65,14 @@ def datagen(whisper_chunks,
latent_batch.append(latent)
if len(latent_batch) >= batch_size:
whisper_batch = np.stack(whisper_batch)
whisper_batch = torch.stack(whisper_batch)
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, latent_batch
whisper_batch, latent_batch = [], []
whisper_batch, latent_batch = [], []
# the last batch may smaller than batch size
if len(latent_batch) > 0:
whisper_batch = np.stack(whisper_batch)
whisper_batch = torch.stack(whisper_batch)
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, latent_batch
yield whisper_batch.to(device), latent_batch.to(device)
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File
View File