14 Commits

Author SHA1 Message Date
czk32611 4724ed8a78 Merge pull request #198 from ShowLo/patch-2
Fix the bug that causes an infinite loop when the total number of fra…
2024-09-23 13:28:35 +08:00
czk32611 b3c30c3561 Merge pull request #197 from ShowLo/patch-1
fix bug in preprocessing data
2024-09-23 13:27:15 +08:00
ShowLo 4f69a9cfdd Fix the bug that causes an infinite loop when the total number of frames in the video does not exceed 11.
eg, the video has 11 frames, when select the NO.6 frame, `while abs(random_element - img_idx) <= 5:` will result in an infinite loop
2024-09-19 17:09:35 +08:00
ShowLo 5bd772d7da fix bug in preprocessing data 2024-09-19 09:56:18 +08:00
czk32611 98f0e6f2b1 Fixed bug in train.py where pe was missing 2024-08-08 14:56:25 +08:00
czk32611 1de8261491 Merge pull request #85 from shounakb1/train_codes
initial data script
2024-08-06 18:49:07 +08:00
Shounak Banerjee b968548131 fixed mltiple video data preperation 2024-06-17 18:39:15 +00:00
Shounak Banerjee af82f3b00f temporary commit to save changes 2024-06-13 14:14:52 +00:00
Shounak Banerjee d74c4c098b clean code and sepaarate finetuned_inference.py 2024-06-07 18:39:24 +00:00
Shounak Banerjee b4a592d7f3 modified dataloader.py and inference.py for training and inference 2024-06-03 11:09:12 +00:00
czk32611 6d19f3c0c8 Remove crop_audio_window from DataLoader.py 2024-06-01 22:23:47 +08:00
shounak 7254ca6306 initial data script 2024-05-16 18:24:44 +00:00
czk32611 30dcd5237f Update train_codes/README.md 2024-04-30 15:10:03 +08:00
czk32611 d73daf1808 Update draft training codes 2024-04-28 18:04:22 +08:00
14 changed files with 2096 additions and 2 deletions
+2 -1
View File
@@ -29,6 +29,7 @@ We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+
- [04/02/2024] Release MuseTalk project and pretrained models. - [04/02/2024] Release MuseTalk project and pretrained models.
- [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant) - [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant)
- [04/17/2024] :mega: We release a pipeline that utilizes MuseTalk for real-time inference. - [04/17/2024] :mega: We release a pipeline that utilizes MuseTalk for real-time inference.
- [04/30/2024] We release an initial version of training codes in `train_codes`.
## Model ## Model
![Model Structure](assets/figs/musetalk_arc.jpg) ![Model Structure](assets/figs/musetalk_arc.jpg)
@@ -165,7 +166,7 @@ Note that although we use a very similar architecture as Stable Diffusion, MuseT
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk). - [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
- [x] codes for real-time inference. - [x] codes for real-time inference.
- [ ] technical report. - [ ] technical report.
- [ ] training codes. - [x] training codes.
- [ ] a better model (may take longer). - [ ] a better model (may take longer).
Executable
+77
View File
@@ -0,0 +1,77 @@
#!/bin/bash
# Split one input video into 3-minute sections and preprocess each section.
#
# Arguments:
#   $1 - path to the input .mp4 video
#   $2 - scratch directory for the section files (removed again when done)
#   $3 - dataset split name (accepted from the caller; not used in here)
#
# For every section this writes a one-task config.yaml into the current
# directory and runs `python -m scripts.data` on it.
extract_sections() {
    # Keep every working variable local so the caller's globals are untouched.
    local input_video=$1
    local output_dir=$2
    local split=$3
    local base_name duration hours minutes seconds total_seconds
    local start_time section_video section_audio
    local chunk_size=180 # 3 minutes in seconds
    local index=0

    base_name=$(basename "$input_video" .mp4)

    # ffmpeg reports "Duration: HH:MM:SS.xx," on stderr when probing a file.
    duration=$(ffmpeg -i "$input_video" 2>&1 | grep Duration | awk '{print $2}' | tr -d ,)
    if [[ ! $duration =~ ^[0-9]+:[0-9]+:[0-9]+ ]]; then
        # Without a parsable duration the arithmetic below would fail.
        echo "Could not determine the duration of $input_video" >&2
        return 1
    fi
    IFS=: read -r hours minutes seconds <<< "$duration"
    # 10# forces base 10 so zero-padded fields such as "08" are not read as octal.
    total_seconds=$((10#${hours}*3600 + 10#${minutes}*60 + 10#${seconds%.*}))

    mkdir -p "$output_dir"
    while [ $((index * chunk_size)) -lt "$total_seconds" ]; do
        start_time=$((index * chunk_size))
        section_video="${output_dir}/${base_name}_part${index}.mp4"
        section_audio="${output_dir}/${base_name}_part${index}.mp3"
        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -c copy "$section_video"
        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -q:a 0 -map a "$section_audio"

        # Create and update the config.yaml file
        echo "task_0:" > config.yaml
        echo " video_path: \"$section_video\"" >> config.yaml
        echo " audio_path: \"$section_audio\"" >> config.yaml

        # Run the Python script with the current config.yaml
        python -m scripts.data --inference_config config.yaml --folder_name "$base_name"
        index=$((index + 1))
    done

    # Clean up save folder (quoted so a path with spaces cannot split into several targets)
    rm -rf -- "$output_dir"
}
# Main script
if [ $# -lt 3 ]; then
    echo "Usage: $0 <train/test> <output_directory> <input_videos...>"
    exit 1
fi

split=$1
output_dir=$2
shift 2
input_videos=("$@")

# Validate the split up front so a typo is reported before any (slow)
# video processing starts rather than after all of it has finished.
if [ "$split" != "train" ] && [ "$split" != "test" ]; then
    echo "Invalid split: $split. Must be 'train' or 'test'."
    exit 1
fi

# Initialize JSON array
json_array="["
for input_video in "${input_videos[@]}"; do
    base_name=$(basename "$input_video" .mp4)

    # Extract sections and run the Python script for each section
    extract_sections "$input_video" "$output_dir" "$split"

    # Add entry to JSON array
    json_array+="\"../data/images/$base_name\","
done

# Remove trailing comma and close JSON array
json_array="${json_array%,}]"

# Write JSON array to train.json or test.json, matching the validated split
echo "$json_array" > "${split}.json"

echo "Processing complete."
+257
View File
@@ -0,0 +1,257 @@
import cv2
import os
# import dlib
import argparse
import os
from omegaconf import OmegaConf
import numpy as np
import cv2
import torch
import glob
import pickle
from tqdm import tqdm
import copy
import uuid
from musetalk.utils.utils import get_file_type,get_video_fps
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil
import gc
# load model weights
# NOTE: this runs at import time; the four objects are used as module globals below.
audio_processor, vae, unet, pe = load_all_model()
# Prefer CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# One-element tensor holding timestep 0, created on `device`.
timesteps = torch.tensor([0], device=device)
def get_largest_integer_filename(folder_path):
    """Return the largest integer used as a file stem inside ``folder_path``.

    Entries whose name (without its extension) does not parse as an integer
    are ignored.  ``-1`` is returned when the folder does not exist, is empty,
    or holds no integer-named entry.
    """
    if not os.path.isdir(folder_path):
        return -1

    highest = -1
    for entry in os.listdir(folder_path):
        stem, _extension = os.path.splitext(entry)
        try:
            highest = max(highest, int(stem))
        except ValueError:
            # Not an integer-named entry; it does not take part in the search.
            continue
    return highest
def datagen(whisper_chunks,
            crop_images,
            batch_size=8,
            delay_frame=0):
    """Yield ``(whisper_batch, crop_batch)`` pairs of at most ``batch_size`` items.

    The i-th whisper chunk is paired with
    ``crop_images[(i + delay_frame) % len(crop_images)]``, so the image list is
    cycled when there are more chunks than images.  ``whisper_batch`` is the
    ``np.stack`` of the collected chunks; ``crop_batch`` is a plain list.
    """
    audio_buffer, image_buffer = [], []
    for position, chunk in enumerate(whisper_chunks):
        picked = crop_images[(position + delay_frame) % len(crop_images)]
        audio_buffer.append(chunk)
        image_buffer.append(picked)
        if len(image_buffer) < batch_size:
            continue
        yield np.stack(audio_buffer), image_buffer
        audio_buffer, image_buffer = [], []

    # the last batch may be smaller than batch_size
    if image_buffer:
        yield np.stack(audio_buffer), image_buffer
@torch.no_grad()
def main(args):
    """Convert every (video, audio) task of the inference config into training data.

    For each task the face crops are written to
    ``data/images/<folder_name>/<n>.png`` and pairs of consecutive whisper
    features to ``data/audios/<folder_name>/<n>.npy``.  Numbering continues
    after the largest integer file name already present in each folder, so
    repeated runs with the same ``folder_name`` append to the existing data.

    Args:
        args: parsed command-line namespace.  Reads ``use_float16``,
            ``inference_config``, ``folder_name``, ``bbox_shift``,
            ``result_dir``, ``fps``, ``use_saved_coord`` and ``batch_size``.

    Raises:
        ValueError: if a task's ``video_path`` is neither a video file, an
            image file nor a directory.
    """
    import subprocess  # local import: only this function launches ffmpeg

    global pe
    if args.use_float16 is True:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    inference_config = OmegaConf.load(args.inference_config)
    folder_name = args.folder_name
    # Continue numbering after whatever earlier runs already wrote (-1 when empty).
    total_audio_index = get_largest_integer_filename(f"data/audios/{folder_name}")
    total_image_index = get_largest_integer_filename(f"data/images/{folder_name}")

    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)

        os.makedirs(f"data/images/{folder_name}", exist_ok=True)
        os.makedirs(f"data/audios/{folder_name}", exist_ok=True)

        input_basename = os.path.basename(video_path).split('.')[0]
        audio_basename = os.path.basename(audio_path).split('.')[0]
        output_basename = f"{input_basename}_{audio_basename}"
        result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
        crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")  # only related to video input
        os.makedirs(result_img_save_path, exist_ok=True)

        ############################################## extract frames from source video ##############################################
        if get_file_type(video_path) == "video":
            save_dir_full = os.path.join(args.result_dir, input_basename)
            os.makedirs(save_dir_full, exist_ok=True)
            # An argument list (no shell) keeps paths with spaces or shell
            # metacharacters intact.  A failing ffmpeg is not fatal here, as before.
            subprocess.run(
                ["ffmpeg", "-v", "fatal", "-i", video_path,
                 "-start_number", "0", f"{save_dir_full}/%08d.png"],
                check=False,
            )
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            fps = get_video_fps(video_path)
        elif get_file_type(video_path) == "image":
            input_img_list = [video_path, ]
            fps = args.fps
        elif os.path.isdir(video_path):  # input img folder
            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")

        ############################################## extract audio feature ##############################################
        whisper_feature = audio_processor.audio2feat(audio_path)
        saved_pairs = 0
        # Stop one short of the end so an odd trailing element is never paired.
        for first in range(0, len(whisper_feature) - 1, 2):
            saved_pairs += 1
            # Combine two consecutive chunks and save the pair to a .npy file
            concatenated_chunks = np.concatenate([whisper_feature[first], whisper_feature[first + 1]], axis=0)
            np.save(f'data/audios/{folder_name}/{total_audio_index + saved_pairs}.npy', concatenated_chunks)
        total_audio_index += saved_pairs
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)

        ############################################## preprocess input image ##############################################
        gc.collect()
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using extracted coordinates")
            # NOTE: pickle is only safe on files this script wrote itself;
            # do not point result_dir at untrusted data.
            with open(crop_coord_save_path, 'rb') as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("extracting landmarks...time consuming")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)

        crop_data = []
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            # Clamp the box to non-negative coordinates before slicing.
            x1, y1, x2, y2 = (max(0, value) for value in bbox)
            if ((y2 - y1) <= 0) or ((x2 - x1) <= 0):
                continue
            crop_frame = frame[y1:y2, x1:x2]
            crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            # NOTE(review): the latents are not used anywhere in this script; the
            # call is kept only in case it matters to the VAE wrapper -- confirm and drop.
            vae.get_latents_for_unet(crop_frame)
            crop_data.append(crop_frame)

        # to smooth the first and the last frame
        crop_data = crop_data + crop_data[::-1]

        ############################################## inference batch by batch ##############################################
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks, crop_data, batch_size)
        for whisper_batch, crop_batch in tqdm(gen, total=int(np.ceil(float(video_num) / batch_size))):
            for image, _audio in zip(crop_batch, whisper_batch):
                total_image_index += 1
                cv2.imwrite(f"data/images/{folder_name}/{total_image_index}.png", image)
        gc.collect()
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
    parser.add_argument("--bbox_shift", type=int, default=0)
    parser.add_argument("--result_dir", default='./results', help="path to output")
    # The help text used to repeat the --result_dir wording; main() uses this
    # value as the sub-folder name under data/images and data/audios.
    parser.add_argument("--folder_name", default=f'{uuid.uuid4()}',
                        help="sub-folder under data/images and data/audios that receives the extracted data")
    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--output_vid_name", type=str, default=None)
    parser.add_argument("--use_saved_coord",
                        action="store_true",
                        help='use saved coordinate to save time')
    parser.add_argument("--use_float16",
                        action="store_true",
                        help="Whether use float16 to speed up inference",
                        )
    args = parser.parse_args()
    main(args)
def process_audio(audio_path, output_path='audio/your_filename.npy'):
    """Extract the whisper feature of ``audio_path`` and save it as a ``.npy`` file.

    Args:
        audio_path: audio file handed to ``audio_processor.audio2feat``.
        output_path: destination passed to ``np.save``.  The default keeps the
            previously hard-coded location.
    """
    whisper_feature = audio_processor.audio2feat(audio_path)
    np.save(output_path, whisper_feature)
def mask_face(image):
    """Black out everything below the nose tip of each detected face.

    The input array is modified in place and also returned.  For every face
    found by dlib's frontal detector, all rows from the y coordinate of
    landmark 33 down to the bottom of the image are set to black across the
    full image width.

    Raises:
        ValueError: if ``image`` is ``None``.
    """
    # NOTE(review): `import dlib` is commented out at the top of this file, so
    # `dlib` has to be brought into scope before this function can be called.
    face_detector = dlib.get_frontal_face_detector()
    # Set this path to your downloaded predictor file
    landmark_model = dlib.shape_predictor("/content/shape_predictor_68_face_landmarks.dat")

    if image is None:
        raise ValueError("Image not found or unable to load.")

    # Detection and landmark prediction both run on the grayscale version.
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    for detected_face in face_detector(grayscale):
        shape = landmark_model(grayscale, detected_face)
        # The indices of nose landmarks are 27 to 35; 33 is used as the nose tip.
        nose_row = shape.part(33).y
        image[nose_row:, :] = (0, 0, 0)
    return image
+182
View File
@@ -0,0 +1,182 @@
import argparse
import os
from omegaconf import OmegaConf
import numpy as np
import cv2
import torch
import glob
import pickle
from tqdm import tqdm
import copy
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil
from accelerate import Accelerator
# load model weights
# NOTE: this runs at import time; the four objects are used as module globals below.
audio_processor, vae, unet, pe = load_all_model()
# Accelerator with fp16 mixed precision; main() also uses it to load a checkpoint.
accelerator = Accelerator(
    mixed_precision="fp16",
)
# Pass the UNet through accelerate before any checkpoint state is loaded.
unet = accelerator.prepare(
    unet,
)
# Prefer CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# One-element tensor holding timestep 0, passed to the UNet on every call.
timesteps = torch.tensor([0], device=device)
@torch.no_grad()
def main(args):
    """Run lip-sync inference for every task of the inference config.

    For each task the source frames are cropped to the detected face, the UNet
    predicts new latents from the audio features, the decoded crops are pasted
    back into the original frames, and the frames are encoded together with
    the audio into the output video.

    Args:
        args: parsed command-line namespace.  Reads ``unet_checkpoint``,
            ``use_float16``, ``inference_config``, ``bbox_shift``,
            ``result_dir``, ``output_vid_name``, ``fps``, ``use_saved_coord``
            and ``batch_size``.

    Raises:
        ValueError: if a task's ``video_path`` is neither a video file, an
            image file nor a directory.
    """
    import subprocess  # local import: only this function launches ffmpeg

    global pe
    if args.unet_checkpoint is not None:
        print("unet ckpt loaded")
        accelerator.load_state(args.unet_checkpoint)

    if args.use_float16 is True:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    inference_config = OmegaConf.load(args.inference_config)
    print(inference_config)
    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)

        input_basename = os.path.basename(video_path).split('.')[0]
        audio_basename = os.path.basename(audio_path).split('.')[0]
        output_basename = f"{input_basename}_{audio_basename}"
        result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
        crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")  # only related to video input
        os.makedirs(result_img_save_path, exist_ok=True)

        if args.output_vid_name is None:
            output_vid_name = os.path.join(args.result_dir, output_basename + ".mp4")
        else:
            output_vid_name = os.path.join(args.result_dir, args.output_vid_name)

        ############################################## extract frames from source video ##############################################
        if get_file_type(video_path) == "video":
            save_dir_full = os.path.join(args.result_dir, input_basename)
            os.makedirs(save_dir_full, exist_ok=True)
            # Argument lists (no shell) keep paths with spaces or shell
            # metacharacters intact.  A failing ffmpeg is not fatal here, as before.
            subprocess.run(
                ["ffmpeg", "-v", "fatal", "-i", video_path,
                 "-start_number", "0", f"{save_dir_full}/%08d.png"],
                check=False,
            )
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            fps = get_video_fps(video_path)
        elif get_file_type(video_path) == "image":
            input_img_list = [video_path, ]
            fps = args.fps
        elif os.path.isdir(video_path):  # input img folder
            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")

        ############################################## extract audio feature ##############################################
        whisper_feature = audio_processor.audio2feat(audio_path)
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)

        ############################################## preprocess input image ##############################################
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using extracted coordinates")
            # NOTE: pickle is only safe on files this script wrote itself;
            # do not point result_dir at untrusted data.
            with open(crop_coord_save_path, 'rb') as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("extracting landmarks...time consuming")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)

        input_latent_list = []
        crop_i = 0
        # NOTE(review): frames whose bbox is the placeholder are skipped here but
        # stay in coord_list/frame_list, so latents and coordinates can drift out
        # of step when a face is missing -- confirm this is acceptable.
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox
            crop_frame = frame[y1:y2, x1:x2]
            crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png", crop_frame)
            latents = vae.get_latents_for_unet(crop_frame)
            input_latent_list.append(latents)
            crop_i += 1

        # to smooth the first and the last frame
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        ############################################## inference batch by batch ##############################################
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size)
        res_frame_list = []
        for whisper_batch, latent_batch in tqdm(gen, total=int(np.ceil(float(video_num) / batch_size))):
            audio_feature_batch = torch.from_numpy(whisper_batch)
            audio_feature_batch = audio_feature_batch.to(device=unet.device,
                                                         dtype=unet.model.dtype)  # torch, B, 5*N,384
            audio_feature_batch = pe(audio_feature_batch)
            latent_batch = latent_batch.to(dtype=unet.model.dtype)
            pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
            recon = vae.decode_latents(pred_latents)
            res_frame_list.extend(recon)

        ############################################## pad to full image ##############################################
        print("pad talking image to original video")
        for i, res_frame in enumerate(tqdm(res_frame_list)):
            bbox = coord_list_cycle[i % (len(coord_list_cycle))]
            ori_frame = copy.deepcopy(frame_list_cycle[i % (len(frame_list_cycle))])
            x1, y1, x2, y2 = bbox
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except Exception:
                # A box the frame cannot be resized to is skipped; unlike a bare
                # `except:` this still lets Ctrl-C / SystemExit through.
                continue
            combine_frame = get_image(ori_frame, res_frame, bbox)
            cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png", res_frame)
            cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png", ori_frame)
            cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

        # Options shared by both image-sequence encodes below.
        encode_options = ["-vcodec", "libx264",
                          "-vf", "format=rgb24,scale=out_color_matrix=bt709,format=yuv420p",
                          "-crf", "18", "temp.mp4"]

        subprocess.run(
            ["ffmpeg", "-y", "-v", "warning", "-r", str(fps), "-f", "image2",
             "-i", f"{result_img_save_path}/%08d.png"] + encode_options,
            check=False,
        )
        subprocess.run(
            ["ffmpeg", "-y", "-v", "warning", "-i", audio_path, "-i", "temp.mp4", output_vid_name],
            check=False,
        )
        os.remove("temp.mp4")

        # NOTE(review): this second encode of the original frames leaves a
        # temp.mp4 behind that nothing below consumes -- confirm it is still wanted.
        subprocess.run(
            ["ffmpeg", "-y", "-v", "warning", "-r", str(fps), "-f", "image2",
             "-i", f"{result_img_save_path}/ori_frame_%08d.png"] + encode_options,
            check=False,
        )
        print(f"result is save to {output_vid_name}")
if __name__ == "__main__":
    # Command-line entry point: declare the options, then hand them to main().
    parser = argparse.ArgumentParser()

    # Input selection and output location.
    parser.add_argument("--inference_config", type=str,
                        default="configs/inference/test_img.yaml")
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--output_vid_name", type=str, default=None)
    parser.add_argument("--unet_checkpoint", type=str, default=None)

    # Numeric settings.
    parser.add_argument("--bbox_shift", type=int, default=0)
    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--batch_size", type=int, default=8)

    # Boolean switches.
    parser.add_argument("--use_saved_coord", action="store_true",
                        help='use saved coordinate to save time')
    parser.add_argument("--use_float16", action="store_true",
                        help="Whether use float16 to speed up inference")

    args = parser.parse_args()
    main(args)
+1 -1
View File
@@ -158,4 +158,4 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
+235
View File
@@ -0,0 +1,235 @@
import os, random, cv2, argparse
import torch
from torch.utils import data as data_utils
from os.path import dirname, join, basename, isfile
import numpy as np
from glob import glob
from utils.utils import prepare_mask_and_masked_image
import torchvision.utils as vutils
import torchvision.transforms as transforms
import shutil
from tqdm import tqdm
import ast
import json
import re
import heapq
# Number of consecutive frames collected per window by Dataset.get_window.
syncnet_T = 1
# Edge length in pixels that Dataset.read_window resizes every frame to.
RESIZED_IMG = 256
# Pairs of landmark indices grouped by facial part (labels on each row).
# NOTE(review): presumably edges of the 68-point facial landmark layout -- confirm.
connections = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7),(7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13),(13,14),(14,15),(15,16), # jawline
               (17, 18), (18, 19), (19, 20), (20, 21), # left eyebrow
               (22, 23), (23, 24), (24, 25), (25, 26), # right eyebrow
               (27, 28),(28,29),(29,30),# nose bridge
               (31,32),(32,33),(33,34),(34,35), # nose
               (36,37),(37,38),(38, 39), (39, 40), (40, 41), (41, 36), # left eye
               (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 42), # right eye
               (48, 49),(49, 50), (50, 51),(51, 52),(52, 53), (53, 54), # upper lip, outer contour
               (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 48), # lower lip, outer contour
               (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 60) # inner lip contour
               ]
def get_image_list(data_root, split):
    """Read ``filelists/<split>.txt`` into parallel path and count lists.

    Each usable line holds a name and an integer separated by whitespace.
    Lines without a space (after stripping) are skipped.

    Returns:
        A ``(paths, counts)`` tuple: ``paths`` are the names joined onto
        ``data_root`` and ``counts`` the matching integers, in file order.
    """
    video_paths, frame_counts = [], []
    with open('filelists/{}.txt'.format(split)) as listing:
        for raw_line in listing:
            entry = raw_line.strip()
            if ' ' not in entry:
                continue
            fields = entry.split()
            # Parse the count first so a malformed line leaves both lists untouched.
            count = int(fields[1])
            video_paths.append(os.path.join(data_root, fields[0]))
            frame_counts.append(count)
    return video_paths, frame_counts
class Dataset(object):
    """Randomly sampled training tuples built from frame folders and whisper features.

    Every item is ``(ref_image, image, masked_image, mask, audio_feature)``.
    The index passed to ``__getitem__`` is ignored: each call draws a random
    video and a random frame, and keeps drawing until a complete sample
    (target frame, reference frame and all neighbouring audio features) is found.
    """

    def __init__(self,
                 data_root,
                 json_path,
                 use_audio_length_left=1,
                 use_audio_length_right=1,
                 whisper_model_type = "tiny"
                 ):
        """Load the video list and the per-video frame listings.

        Args:
            data_root: accepted for interface compatibility; not used.
            json_path: JSON file holding a list of frame-folder paths.
            use_audio_length_left: number of audio feature files taken before
                the target frame index.
            use_audio_length_right: number of audio feature files taken after
                the target frame index.
            whisper_model_type: ``"tiny"`` or ``"largeV2"``.

        Raises:
            ValueError: if ``whisper_model_type`` is not a supported value.
        """
        self.audio_feature = [use_audio_length_left,use_audio_length_right]
        self.all_img_names = []
        self.img_names_path = '../data'
        self.whisper_model_type = whisper_model_type
        self.use_audio_length_left = use_audio_length_left
        self.use_audio_length_right = use_audio_length_right

        if self.whisper_model_type =="tiny":
            self.whisper_path = '../data/audios'
            self.whisper_feature_W = 5
            self.whisper_feature_H = 384
        elif self.whisper_model_type =="largeV2":
            self.whisper_path = '...'
            self.whisper_feature_W = 33
            self.whisper_feature_H = 1280
        else:
            # Previously this fell through and failed just below with an
            # AttributeError on whisper_feature_W.
            raise ValueError(
                f"Unsupported whisper_model_type {whisper_model_type!r}; expected 'tiny' or 'largeV2'"
            )
        # Rows expected per sample: W * 2 * (left + right + 1),
        # e.g. 5 * 2 * (2 + 2 + 1) = 50 for "tiny" with two files on each side.
        self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1)

        with open(json_path, 'r') as file:
            self.all_videos = json.load(file)

        for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
            # The sorted frame listing of each video is cached as JSON next to the data.
            json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
            if not os.path.exists(json_path_names):
                img_names = glob(join(vidname, '*.png'))
                # Sort numerically by frame id; basename() copes with any path separator.
                img_names.sort(key=self.get_frame_id)
                with open(json_path_names, "w") as f:
                    json.dump(img_names,f)
            else:
                with open(json_path_names, "r") as f:
                    img_names = json.load(f)
            self.all_img_names.append(img_names)

    def get_frame_id(self, frame):
        """Return the integer frame number encoded in the file name of ``frame``."""
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        """Return the ``syncnet_T`` consecutive frame paths starting at ``start_frame``.

        Returns ``None`` as soon as one of the frames does not exist on disk.
        """
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.png'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def read_window(self, window_fnames):
        """Load the frames as RGB arrays resized to ``RESIZED_IMG`` x ``RESIZED_IMG``.

        Returns ``None`` when ``window_fnames`` is ``None`` or any frame cannot
        be read or converted.
        """
        if window_fnames is None: return None
        window = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, (RESIZED_IMG, RESIZED_IMG))
            except Exception as e:
                print("read_window has error fname not exist:",fname)
                return None

            window.append(img)

        return window

    def prepare_window(self, window):
        """Scale a list of H x W x 3 frames to [0, 1] and move channels first.

        The result has shape ``(3, len(window), H, W)``.
        """
        # 1 x H x W x 3
        x = np.asarray(window) / 255.
        x = np.transpose(x, (3, 0, 1, 2))

        return x

    def __len__(self):
        """Return the number of videos (not the number of frames)."""
        return len(self.all_videos)

    def __getitem__(self, idx):
        """Draw one random training sample; ``idx`` is ignored.

        Returns:
            ``(ref_image, image, masked_image, mask, audio_feature)``.

        NOTE(review): the loop only ends once a usable sample is found, so it
        never returns if no video has more than 11 frames.
        """
        while 1:
            idx = random.randint(0, len(self.all_videos) - 1)
            # Pick a random video.
            vidname = self.all_videos[idx].split('/')[-1]
            video_imgs = self.all_img_names[idx]
            if len(video_imgs) <= 11:
                # Too short to find a reference frame more than 5 frames away.
                continue
            img_name = random.choice(video_imgs)
            img_idx = int(basename(img_name).split(".")[0])
            # The reference frame must be more than 5 frames from the target.
            random_element = random.randint(0,len(video_imgs)-1)
            while abs(random_element - img_idx) <= 5:
                random_element = random.randint(0,len(video_imgs)-1)
            img_dir = os.path.dirname(img_name)
            ref_image = os.path.join(img_dir, f"{str(random_element)}.png")
            target_window_fnames = self.get_window(img_name)
            ref_window_fnames = self.get_window(ref_image)
            if target_window_fnames is None or ref_window_fnames is None:
                print("No such img",img_name, ref_image)
                continue

            try:
                # Load the target frame(s).
                target_window = self.read_window(target_window_fnames)
                if target_window is None :
                    print("No such target window,",target_window_fnames)
                    continue
                # Load the reference frame(s).
                ref_window = self.read_window(ref_window_fnames)
                if ref_window is None:
                    # Report the file names; the window itself is always None here.
                    print("No such target ref window,",ref_window_fnames)
                    continue
            except Exception as e:
                print(f"发生未知错误:{e}")
                continue

            # Build the target input.
            target_window = self.prepare_window(target_window)
            image = target_window.copy().squeeze()
            ref_image = self.prepare_window(ref_window).squeeze()

            # Mask with ones in the upper half and zeros in the lower half.
            mask = torch.zeros((ref_image.shape[1], ref_image.shape[2]))
            mask[:ref_image.shape[2]//2,:] = 1
            image = torch.FloatTensor(image)
            mask, masked_image = prepare_mask_and_masked_image(image,mask)

            # Audio features.
            window_index = self.get_frame_id(img_name)
            sub_folder_name = vidname.split('/')[-1]

            feature_dir = os.path.join(self.whisper_path, sub_folder_name)
            if not os.path.isdir(feature_dir):
                continue

            # Load the audio features adjacent to window_index.
            audio_feature_all = []
            is_index_out_of_range = False
            is_load_failed = False
            for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
                # Check that the neighbouring feature file exists.
                audio_feat_path = os.path.join(feature_dir, str(feat_idx) + ".npy")
                if not os.path.exists(audio_feat_path):
                    is_index_out_of_range = True
                    break
                try:
                    audio_feature_all.append(np.load(audio_feat_path))
                except Exception as e:
                    print(f"发生未知错误:{e}")
                    print(f"npy load error {audio_feat_path}")
                    # Resample instead of concatenating an incomplete (or empty)
                    # list, which would raise and crash the loader.
                    is_load_failed = True
                    break
            if is_index_out_of_range or is_load_failed:
                continue
            audio_feature = np.concatenate(audio_feature_all, axis=0)

            audio_feature = audio_feature.reshape(1, -1, self.whisper_feature_H) #1 -1 384
            if audio_feature.shape != (1,self.whisper_feature_concateW, self.whisper_feature_H): #1 50 384
                print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
                continue
            audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
            return ref_image, image, masked_image, mask, audio_feature
if __name__ == "__main__":
    # Smoke test: build the dataset and pull every batch through a DataLoader once.
    # NOTE(review): 'val' lands in the json_path parameter, which Dataset opens
    # as a JSON file -- confirm this is the intended path.
    data_root = '...'
    val_data = Dataset(
        data_root,
        'val',
        use_audio_length_left=2,
        use_audio_length_right=2,
        whisper_model_type="tiny",
    )
    val_data_loader = data_utils.DataLoader(
        val_data,
        batch_size=4,
        shuffle=True,
        num_workers=1,
    )
    for i, data in enumerate(val_data_loader):
        # Unpacking doubles as a check that each batch has the five expected parts.
        ref_image, image, masked_image, mask, audio_feature = data
+52
View File
@@ -0,0 +1,52 @@
# Data preprocessing
Create two config YAML files, one for training and the other for testing (both in the same format as configs/inference/test.yaml).
The train yaml file should contain the training video paths and corresponding audio paths
The test yaml file should contain the validation video paths and corresponding audio paths
Run:
```
./data_new.sh train output train_video1.mp4 train_video2.mp4
./data_new.sh test output test_video1.mp4 test_video2.mp4
```
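This creates folders which contain the image frames and npy files. It also writes train.json (for the `train` split) and test.json (for the `test` split), which list the image folders to use during training; note that the training command below refers to the validation list as val.json, so rename test.json or adjust the path accordingly.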
## Data organization
```
./data/
├── images
│ └──RD_Radio10_000
│ └── 0.png
│ └── 1.png
│ └── xxx.png
│ └──RD_Radio11_000
│ └── 0.png
│ └── 1.png
│ └── xxx.png
├── audios
│ └──RD_Radio10_000
│ └── 0.npy
│ └── 1.npy
│ └── xxx.npy
│ └──RD_Radio11_000
│ └── 0.npy
│ └── 1.npy
│ └── xxx.npy
```
## Training
Simply run after preparing the preprocessed data
```
cd train_codes
sh train.sh #--train_json="../train.json" \(Generated in Data preprocessing step.)
#--val_json="../val.json" \
```
## Inference with trained checkpoint
After training the model, simply run the command below. The model checkpoints are usually saved in train_codes/output.
```
python -m scripts.finetuned_inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
```
## TODO
- [x] release data preprocessing codes
- [ ] release some novel designs in training (after technical report)
+322
View File
@@ -0,0 +1,322 @@
RD_Radio10_000 3501
RD_Radio11_000 752
RD_Radio11_001 1502
RD_Radio12_000 876
RD_Radio13_000 877
RD_Radio14_000 1251
RD_Radio16_000 2377
RD_Radio17_000 2176
RD_Radio18_000 1827
RD_Radio19_000 1876
RD_Radio1_000 1876
RD_Radio20_000 276
RD_Radio21_000 1201
RD_Radio22_000 752
RD_Radio23_000 1602
RD_Radio25_000 2126
RD_Radio26_000 1052
RD_Radio27_000 1376
RD_Radio28_000 2252
RD_Radio29_000 2252
RD_Radio2_000 2076
RD_Radio30_000 2877
RD_Radio31_000 1377
RD_Radio32_000 1502
RD_Radio33_000 2252
RD_Radio34_000 702
RD_Radio34_001 252
RD_Radio34_002 501
RD_Radio34_003 326
RD_Radio34_004 502
RD_Radio34_005 502
RD_Radio34_006 502
RD_Radio34_007 252
RD_Radio34_008 876
RD_Radio34_009 752
RD_Radio35_000 2127
RD_Radio36_000 2377
RD_Radio37_000 3127
RD_Radio38_000 3752
RD_Radio39_000 1502
RD_Radio3_000 2377
RD_Radio40_000 1252
RD_Radio41_000 2127
RD_Radio42_000 1752
RD_Radio43_000 1502
RD_Radio44_000 1127
RD_Radio45_000 1628
RD_Radio46_000 1877
RD_Radio47_000 1502
RD_Radio48_000 2127
RD_Radio49_000 1502
RD_Radio4_000 1002
RD_Radio50_000 2252
RD_Radio51_000 3253
RD_Radio52_000 2752
RD_Radio53_000 2627
RD_Radio54_000 577
RD_Radio56_000 1952
RD_Radio57_000 3077
RD_Radio59_000 1952
RD_Radio5_000 1377
RD_Radio7_000 2252
RD_Radio8_000 2126
RD_Radio9_000 1502
WDA_AdamSchiff_000 6877
WDA_AdamSmith_000 7327
WDA_AlexandriaOcasioCortez_000 2252
WDA_AmyKlobuchar0_000 6501
WDA_AmyKlobuchar1_000 3253
WDA_AmyKlobuchar1_001 877
WDA_AmyKlobuchar1_002 1502
WDA_AmyKlobuchar1_003 1377
WDA_AndyKim_000 6952
WDA_AndyLevin_000 4876
WDA_AnnieKuster_000 4876
WDA_BarackObama_000 3575
WDA_BarackObama_001 5625
WDA_BarbaraLee0_000 6502
WDA_BarbaraLee1_000 4702
WDA_BenCardin0_000 8127
WDA_BenCardin1_000 7102
WDA_BenRayLujn_000 7127
WDA_BennieThompson1_000 6375
WDA_BennieThompson_000 5375
WDA_BernieSanders_000 8451
WDA_BettyMcCollum_000 7002
WDA_BillPascrell_000 7252
WDA_BillRichardson_000 3127
WDA_BobCasey0_000 10627
WDA_BobCasey1_000 252
WDA_BobMenendez_000 3002
WDA_BobbyScott_000 2002
WDA_BradSchneider_000 6627
WDA_BrendaLawrence_000 5627
WDA_BrianSchatz0_000 502
WDA_BrianSchatz1_000 2627
WDA_BrianSchatz2_000 1952
WDA_ByronDorgan1_000 4277
WDA_CarolynMaloney1_000 8377
WDA_CatherineCortezMasto_000 2827
WDA_CedricRichmond_000 7250
WDA_ChrisCoons1_000 4452
WDA_ChrisCoons_000 6827
WDA_ChrisMurphy0_000 6877
WDA_ChrisMurphy1_000 6377
WDA_ChrisVanHollen0_000 8202
WDA_ChrisVanHollen1_000 7327
WDA_ChuckSchumer0_000 5577
WDA_ChuckSchumer1_000 4827
WDA_ColinAllred_000 4751
WDA_DanKildee1_000 6252
WDA_DanKildee_000 2502
WDA_DavidCicilline_000 5875
WDA_DebHaaland_000 5876
WDA_DebbieDingell0_000 7752
WDA_DebbieDingell1_000 7251
WDA_DebbieStabenow0_000 4577
WDA_DebbieStabenow1_000 5201
WDA_DebbieWassermanSchultz_000 7625
WDA_DianaDeGette0_000 6002
WDA_DianaDeGette1_000 2202
WDA_DianneFeinstein_000 7453
WDA_DickDurbin_000 6326
WDA_DonaldMcEachin_000 7627
WDA_DonnaShalala1_000 7500
WDA_DougJones_000 6077
WDA_EdMarkey0_000 4752
WDA_EdMarkey1_000 6501
WDA_ElijahCummings_000 6377
WDA_EliotEngel_000 6876
WDA_EmanuelCleaver_000 7001
WDA_EricSwalwell_000 6377
WDA_FrankPallone0_000 5625
WDA_FrankPallone1_000 6502
WDA_GerryConnolly_000 6752
WDA_HakeemJeffries_000 6002
WDA_HaleyStevens_000 5001
WDA_HenryWaxman_000 1125
WDA_HillaryClinton_000 2500
WDA_JackReed0_000 2377
WDA_JackReed1_000 4877
WDA_JackieSpeier_000 4625
WDA_JackyRosen1_000 10077
WDA_JackyRosen_000 5502
WDA_JamesClyburn1_000 6876
WDA_JamesClyburn_000 6875
WDA_JanSchakowsky0_000 4128
WDA_JanSchakowsky1_000 6251
WDA_JeanneShaheen0_000 6702
WDA_JeanneShaheen1_000 6577
WDA_JeffMerkley1_000 6952
WDA_JerryNadler_000 8377
WDA_JimHimes_000 7000
WDA_JimmyGomez_000 4627
WDA_JoaquinCastro_000 5126
WDA_JoeCrowley0_000 4877
WDA_JoeCrowley1_000 1127
WDA_JoeCrowley1_001 627
WDA_JoeCrowley1_002 502
WDA_JoeCrowley1_003 752
WDA_JoeDonnelly_000 1377
WDA_JoeKennedy_000 1327
WDA_JoeManchin_000 3377
WDA_JoeNeguse_000 1628
WDA_JoeNeguse_001 1126
WDA_JoeNeguse_002 1251
WDA_JohnLewis0_000 6252
WDA_JohnLewis1_000 7252
WDA_JohnSarbanes0_000 6377
WDA_JohnSarbanes1_000 1877
WDA_JohnYarmuth1_000 5377
WDA_JonTester0_000 3252
WDA_JonTester1_000 5327
WDA_KarenBass_000 7126
WDA_KatherineClark_000 1552
WDA_KathyCastor1_000 6001
WDA_KathyCastor_000 2252
WDA_KimSchrier_000 6877
WDA_KirstenGillibrand_000 9627
WDA_LaurenUnderwood_000 9877
WDA_LisaBluntRochester_000 4500
WDA_LloydDoggett0_000 7377
WDA_LloydDoggett1_000 2252
WDA_LucilleRoybal-Allard_000 2877
WDA_LucyMcBath_000 4000
WDA_MarciaFudge_000 7377
WDA_MarkWarner1_000 377
WDA_MarkWarner1_001 1327
WDA_MarkWarner2_000 377
WDA_MarkWarner_000 3127
WDA_MartinHeinrich_000 5202
WDA_MattCartwright_000 5377
WDA_MazieHirono0_000 3752
WDA_MichelleLujanGrisham_000 6875
WDA_MichelleObama_000 2000
WDA_MikeDoyle_000 8750
WDA_MikeThompson0_000 4625
WDA_MikeThompson1_000 1827
WDA_NancyPelosi0_000 10251
WDA_NancyPelosi1_000 1377
WDA_NancyPelosi3_000 1127
WDA_NitaLowey_000 5876
WDA_NydiaVelzquez_000 5500
WDA_PatrickLeahy0_000 6951
WDA_PatrickLeahy1_000 9953
WDA_PattyMurray0_000 3627
WDA_PattyMurray1_000 5702
WDA_PeterDeFazio_000 6628
WDA_RaulRuiz_000 5752
WDA_RichardBlumenthal_000 7827
WDA_RichardNeal0_000 6252
WDA_RichardNeal1_000 6877
WDA_RobinKelly_000 4377
WDA_RonWyden0_000 5152
WDA_RonWyden1_000 8327
WDA_ScottPeters0_000 7002
WDA_ScottPeters1_000 3952
WDA_SeanCasten_000 6126
WDA_SeanPatrickMaloney_000 6502
WDA_SheldonWhitehouse0_000 6327
WDA_SheldonWhitehouse1_000 5702
WDA_SherrodBrown0_000 7452
WDA_SherrodBrown1_000 7327
WDA_StenyHoyer_000 3502
WDA_StephanieMurphy_000 4000
WDA_SuzanDelBene_000 6875
WDA_TammyBaldwin0_000 5327
WDA_TammyBaldwin1_000 2503
WDA_TammyDuckworth_000 5702
WDA_TedLieu_000 6877
WDA_TerriSewell0_000 2127
WDA_TerriSewell1_000 10952
WDA_TerriSewell_000 6752
WDA_TimWalz_000 6628
WDA_TinaSmith_000 5576
WDA_TomCarper_000 4701
WDA_TomPerez_000 5500
WDA_TomUdall_000 4077
WDA_VeronicaEscobar0_000 4377
WDA_VeronicaEscobar1_000 2453
WDA_WhipJimClyburn_000 6876
WDA_XavierBecerra_000 877
WDA_XavierBecerra_001 627
WDA_XavierBecerra_002 1377
WDA_ZoeLofgren_000 4625
WRA_AdamKinzinger0_000 4251
WRA_AdamKinzinger1_000 2751
WRA_AdamKinzinger2_000 1002
WRA_AdamKinzinger2_001 1001
WRA_AdamKinzinger2_002 502
WRA_AllenWest_000 6876
WRA_AnnMarieBuerkle_000 452
WRA_AnnWagner_000 3127
WRA_AustinScott0_000 1177
WRA_BillCassidy0_000 427
WRA_BobCorker_000 1127
WRA_BobGoodlatte0_000 1001
WRA_BobGoodlatte0_001 752
WRA_BobGoodlatte0_002 876
WRA_BobGoodlatte0_003 1002
WRA_BobbySchilling_001 1000
WRA_BobbySchilling_002 1001
WRA_CandiceMiller0_000 3001
WRA_CarlyFiorina0_000 876
WRA_CarlyFiorina_000 1902
WRA_CathyMcMorrisRodgers0_000 6077
WRA_CathyMcMorrisRodgers1_000 7375
WRA_CathyMcMorrisRodgers1_001 7500
WRA_CathyMcMorrisRodgers1_002 7500
WRA_CathyMcMorrisRodgers2_000 2751
WRA_ChuckGrassley_000 4502
WRA_CoryGardner0_000 8577
WRA_CoryGardner1_000 2002
WRA_CoryGardner_000 1252
WRA_DanNewhouse_000 3627
WRA_DanSullivan_000 6877
WRA_DaveCamp_000 752
WRA_DaveCamp_001 876
WRA_DaveCamp_002 627
WRA_DavidVitter_000 2877
WRA_DeanHeller_000 377
WRA_DebFischer0_000 7577
WRA_DebFischer1_000 1577
WRA_DebFischer2_000 5202
WRA_DianeBlack0_000 451
WRA_DianeBlack0_001 301
WRA_DianeBlack1_000 250
WRA_DuncanHunter_000 727
WRA_EricCantor_000 2952
WRA_ErikPaulsen_000 876
WRA_ErikPaulsen_001 377
WRA_ErikPaulsen_002 627
WRA_ErikPaulsen_003 752
WRA_FredUpton_000 1701
WRA_GeoffDavis_000 250
WRA_GeorgeLeMieux_000 752
WRA_GeorgeLeMieux_001 1001
WRA_GregWalden1_000 752
WRA_GregWalden1_001 1127
WRA_GregWalden1_002 1327
WRA_GregWalden_000 377
WRA_JaimeHerreraBeutler0_000 952
WRA_JebHensarling0_001 1176
WRA_JebHensarling1_000 2327
WRA_JebHensarling1_001 727
WRA_JebHensarling2_000 2302
WRA_JebHensarling2_001 877
WRA_JebHensarling2_002 2127
WRA_JebHensarling2_003 1877
WRA_JeffFlake_000 7202
WRA_JeffFlake_001 7502
WRA_JeffFlake_002 7377
WRA_JimInhofe_000 3127
WRA_JimRisch_000 1377
WRA_JoeHeck1_000 250
WRA_JoePitts_000 1077
WRA_JohnBarrasso0_000 5077
WRA_JohnBarrasso1_000 3452
WRA_JohnBoehner0_000 1626
WRA_JohnBoehner1_000 2951
WRA_JohnBoozman_000 1877
WRA_JohnHoeven_000 2577
+30
View File
@@ -0,0 +1,30 @@
WRA_JohnKasich0_000 1302
WRA_JohnKasich1_000 1301
WRA_JohnKasich1_001 1127
WRA_JohnKasich3_000 1052
WRA_JohnThune_000 1452
WRA_JohnnyIsakson_000 5951
WRA_JohnnyIsakson_001 5000
WRA_JonKyl_000 626
WRA_JoniErnst0_000 2077
WRA_JoniErnst1_000 1452
WRA_JuddGregg_000 1252
WRA_JuddGregg_001 952
WRA_JuddGregg_002 753
WRA_KayBaileyHutchison_000 6325
WRA_KellyAyotte_000 8077
WRA_KevinBrady2_000 1276
WRA_KevinBrady3_000 752
WRA_KevinBrady_000 2503
WRA_KevinMcCarthy0_000 1002
WRA_KevinMcCarthy0_001 1127
WRA_KristiNoem0_000 727
WRA_KristiNoem1_000 452
WRA_KristiNoem2_000 5952
WRA_KristiNoem2_001 7277
WRA_LamarAlexander0_000 1527
WRA_LamarAlexander_000 1552
WRA_LisaMurkowski0_000 1752
WRA_LynnJenkins_000 877
WRA_LynnJenkins_001 1002
WRA_MarcoRubio_000 526
+36
View File
@@ -0,0 +1,36 @@
{
"_class_name": "UNet2DConditionModel",
"_diffusers_version": "0.6.0.dev0",
"act_fn": "silu",
"attention_head_dim": 8,
"block_out_channels": [
320,
640,
1280,
1280
],
"center_input_sample": false,
"cross_attention_dim": 384,
"down_block_types": [
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"flip_sin_to_cos": true,
"freq_shift": 0,
"in_channels": 8,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"norm_eps": 1e-05,
"norm_num_groups": 32,
"out_channels": 4,
"sample_size": 64,
"up_block_types": [
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D"
]
}
+670
View File
@@ -0,0 +1,670 @@
import argparse
import itertools
import math
import os
import random
from pathlib import Path
import json
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from PIL import Image, ImageDraw
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
import sys
sys.path.append("./")
from DataLoader import Dataset
from utils.utils import preprocess_img_tensor
from torch.utils import data as data_utils
from utils.model_utils import validation,PositionalEncoding
import time
import pandas as pd
from PIL import Image
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.13.0.dev0")
logger = get_logger(__name__)
def parse_args():
    """Parse the command-line arguments of the training script.

    Reads ``sys.argv`` through ``argparse``. After parsing, the ``LOCAL_RANK``
    environment variable (when set to something other than -1) overrides
    ``--local_rank``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--unet_config_file",
        type=str,
        default=None,
        required=True,
        help="the configuration of unet file.",
    )
    parser.add_argument(
        "--reconstruction",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--testing_speed", action="store_true", help="Whether to caculate the running time")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders")
    parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            # Adjacent literals are concatenated verbatim, so each piece
            # carries its own separating space.
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=1000,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
            " using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=1000,
        help=(
            "Conduct validation every X updates."
        ),
    )
    parser.add_argument(
        "--val_out_dir",
        type=str,
        default='',
        help=(
            "Sub-directory (created under `val/`) where validation outputs are written."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--use_audio_length_left",
        type=int,
        default=1,
        help="number of audio length (left).",
    )
    parser.add_argument(
        "--use_audio_length_right",
        type=int,
        default=1,
        help="number of audio length (right)",
    )
    parser.add_argument(
        "--whisper_model_type",
        type=str,
        # argparse does not validate defaults against `choices`, so the default
        # itself must be one of the supported values.
        default="tiny",
        choices=["tiny", "largeV2"],
        help="Determine whisper feature type",
    )

    args = parser.parse_args()

    # A launcher-provided LOCAL_RANK takes precedence over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
def print_model_dtypes(model, model_name):
    """Print ``name: dtype`` for every parameter of ``model`` that is not float32.

    ``model_name`` is accepted for call-site symmetry but is not used.
    """
    for param_name, tensor in model.named_parameters():
        if tensor.dtype == torch.float32:
            continue
        print(f"{param_name}: {tensor.dtype}")
def main():
    """Train the UNet described by ``--unet_config_file`` on the prepared dataset.

    Steps, in order: parse arguments, create the output/validation directories,
    set up ``accelerate``, load two copies of the VAE (one used as encoder in
    the training dtype, one kept in float32 as decoder for the reconstruction
    loss), build the UNet, optimizer, data loaders and LR scheduler, optionally
    resume from a checkpoint, then run the training loop with periodic
    checkpointing and validation.

    Raises:
        ValueError: if gradient accumulation is combined with multiple
            processes, or if ``--whisper_model_type`` is not supported.
        ImportError: if ``--use_8bit_adam`` is set but bitsandbytes is missing.
    """
    args = parse_args()
    args.output_dir = f"output/{args.output_dir}"
    args.val_out_dir = f"val/{args.val_out_dir}"
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.val_out_dir, exist_ok=True)

    logging_dir = Path(args.output_dir, args.logging_dir)

    project_config = ProjectConfiguration(
        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        project_config=project_config,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        # Offset the seed by the process index so each process gets its own
        # random stream. (The original referenced an undefined name `seed`.)
        set_seed(args.seed + accelerator.process_index)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
        if args.push_to_hub:
            # NOTE(review): `create_repo` is not imported anywhere in the visible
            # part of this file; --push_to_hub will raise NameError until it is
            # imported (it normally comes from `huggingface_hub`) — confirm.
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load the UNet configuration.
    with open(args.unet_config_file, 'r') as f:
        unet_config = json.load(f)

    # `vae` is used as the encoder (its decoder is dropped below); `vae_fp32`
    # is kept in float32 and used as the decoder (its encoder is dropped below).
    print("Loading AutoencoderKL")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
    vae_fp32 = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")

    print("Loading UNet2DConditionModel")
    unet = UNet2DConditionModel(**unet_config)

    if args.whisper_model_type == "tiny":
        pe = PositionalEncoding(d_model=384)
    elif args.whisper_model_type == "largeV2":
        pe = PositionalEncoding(d_model=1280)
    else:
        # Fail here with a clear message instead of a NameError on `pe` later.
        raise ValueError(f"not support whisper_model_type {args.whisper_model_type}")

    print("Loading models done...")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Only the UNet is optimized; both VAEs are frozen below.
    params_to_optimize = (
        itertools.chain(unet.parameters()))
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    print("loading train_dataset ...")
    train_dataset = Dataset(args.data_root,
                            args.train_json,
                            use_audio_length_left=args.use_audio_length_left,
                            use_audio_length_right=args.use_audio_length_right,
                            whisper_model_type=args.whisper_model_type
                            )
    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True,
        num_workers=8)
    print("loading val_dataset ...")
    val_dataset = Dataset(args.data_root,
                          args.val_json,
                          use_audio_length_left=args.use_audio_length_left,
                          use_audio_length_right=args.use_audio_length_right,
                          whisper_model_type=args.whisper_model_type
                          )
    val_data_loader = data_utils.DataLoader(
        val_dataset, batch_size=1, shuffle=False,
        num_workers=8)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    unet, optimizer, train_data_loader, val_data_loader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_data_loader, val_data_loader, lr_scheduler
    )

    vae.requires_grad_(False)
    vae_fp32.requires_grad_(False)

    # The decoder copy always stays in float32.
    weight_dtype = torch.float32
    vae_fp32.to(accelerator.device, dtype=weight_dtype)
    vae_fp32.encoder = None

    # The encoder copy and the positional encoding follow the mixed-precision mode.
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(accelerator.device, dtype=weight_dtype)
    vae.decoder = None
    pe.to(accelerator.device, dtype=weight_dtype)

    # The loader length may have changed after `accelerator.prepare`, so recompute.
    num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print(f"  Num batches each epoch = {len(train_data_loader)}")

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_data_loader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            # Clearing the flag also disables the step-skipping logic in the loop.
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # Checkpoint directories are named "checkpoint-<global_step>".
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # Per-step timing records, collected only with --testing_speed.
    elapsed_time = []
    start = time.time()

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue
            dataloader_time = time.time() - start
            start = time.time()

            masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
            ref_image = preprocess_img_tensor(ref_image).to(vae.device)
            image = preprocess_img_tensor(image).to(vae.device)
            masked_image = preprocess_img_tensor(masked_image).to(vae.device)

            img_process_time = time.time() - start
            start = time.time()

            with accelerator.accumulate(unet):
                # Validation casts `vae` to float32 in place (nn.Module.float()),
                # so restore the training dtype on every step. Using
                # `weight_dtype` rather than an unconditional `.half()` keeps the
                # weights consistent with the inputs for "no" and "bf16" as well.
                vae = vae.to(dtype=weight_dtype)

                # Convert images to latent space
                latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample()  # init image
                latents = latents * vae.config.scaling_factor

                # Convert masked images to latent space
                masked_latents = vae.encode(
                    masked_image.reshape(image.shape).to(dtype=weight_dtype)  # masked image
                ).latent_dist.sample()
                masked_latents = masked_latents * vae.config.scaling_factor

                # Convert ref images to latent space
                ref_latents = vae.encode(
                    ref_image.reshape(image.shape).to(dtype=weight_dtype)  # ref image
                ).latent_dist.sample()
                ref_latents = ref_latents * vae.config.scaling_factor

                vae_time = time.time() - start
                start = time.time()

                # Downsample each mask by a factor of 8 to match the latent resolution.
                mask = torch.stack(
                    [
                        torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
                        for mask in masks
                    ]
                )
                mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])

                bsz = latents.shape[0]
                # fix timestep for each image
                timesteps = torch.tensor([0], device=latents.device)

                # Concatenate the latents (and the mask, for a 9-channel UNet)
                # along the channel dimension.
                if unet_config['in_channels'] == 9:
                    latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
                else:
                    latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)

                audio_feature = audio_feature.to(dtype=weight_dtype)
                audio_feature = pe(audio_feature)

                # Predict the target latents conditioned on the audio features.
                image_pred = unet(latent_model_input,
                                  timesteps,
                                  encoder_hidden_states=audio_feature).sample

                if args.reconstruction:  # decode the image from the predicted latents
                    image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
                    image_pred_img = vae_fp32.decode(image_pred_img).sample

                    # Mask the top half of the image and calculate the loss only for the lower half of the image.
                    image_pred_img = image_pred_img[:, :, image_pred_img.shape[2]//2:, :]
                    image = image[:, :, image.shape[2]//2:, :]
                    loss_lip = F.l1_loss(image_pred_img.float(), image.float(), reduction="mean")  # the loss of the decoded images
                    loss_latents = F.l1_loss(image_pred.float(), latents.float(), reduction="mean")  # the loss of the latents
                    loss = 2.0*loss_lip + loss_latents  # add some weight to balance the loss
                else:
                    loss = F.mse_loss(image_pred.float(), latents.float(), reduction="mean")

                unet_elapsed_time = time.time() - start
                start = time.time()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters())
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            backward_elapsed_time = time.time() - start
            start = time.time()

            if args.testing_speed is True and accelerator.is_main_process:
                elapsed_time.append(
                    [dataloader_time, unet_elapsed_time, vae_time, backward_elapsed_time, img_process_time]
                )

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                if global_step % args.validation_steps == 0:
                    if accelerator.is_main_process:
                        logger.info(
                            f"Running validation... epoch={epoch}, global_step={global_step}"
                        )
                        print("===========start validation==========")
                        # Validation runs in float32. `.float()` casts `vae` in
                        # place, which is why the training dtype is restored at
                        # the top of each training step.
                        vae_new = vae.float()
                        print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
                        print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
                        print_model_dtypes(accelerator.unwrap_model(unet), "UNET")

                        print(f"weight_dtype: {weight_dtype}")
                        print(f"epoch type: {type(epoch)}")
                        print(f"global_step type: {type(global_step)}")
                        validation(
                            vae=accelerator.unwrap_model(vae_new),
                            vae_fp32=accelerator.unwrap_model(vae_fp32),
                            unet=accelerator.unwrap_model(unet),
                            unet_config=unet_config,
                            weight_dtype=torch.float32,
                            epoch=epoch,
                            global_step=global_step,
                            val_data_loader=val_data_loader,
                            output_dir=args.val_out_dir,
                            whisper_model_type=args.whisper_model_type
                        )
                        logger.info(f"Saved samples to images/val")
                # Restart the timer so checkpoint/validation time is not counted
                # as data-loading time of the next step.
                start = time.time()

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0],
                    "unet": unet_elapsed_time,
                    "backward": backward_elapsed_time,
                    "data": dataloader_time,
                    "img_process": img_process_time,
                    "vae": vae_time
                    }
            progress_bar.set_postfix(**logs)
            accelerator.log(
                {
                    "loss/step_loss": logs["loss"],
                    "parameter/lr": logs["lr"],
                    "time/unet_forward_time": unet_elapsed_time,
                    "time/unet_backward_time": backward_elapsed_time,
                    "time/data_time": dataloader_time,
                    "time/img_process_time": img_process_time,
                    "time/vae_time": vae_time
                },
                step=global_step,
            )

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()
    accelerator.end_training()
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
+29
View File
@@ -0,0 +1,29 @@
# Launch UNet training through `accelerate`.
# All relative paths below are resolved from the directory this script is run in.
export VAE_MODEL="../models/sd-vae-ft-mse/"
export DATASET="../data"
export UNET_CONFIG="../models/musetalk/musetalk.json"

# Variable expansions are quoted so paths containing spaces survive word
# splitting. The final option carries no trailing backslash, so a line appended
# after the command is not silently absorbed into it.
accelerate launch train.py \
    --mixed_precision="fp16" \
    --unet_config_file="$UNET_CONFIG" \
    --pretrained_model_name_or_path="$VAE_MODEL" \
    --data_root="$DATASET" \
    --train_batch_size=256 \
    --gradient_accumulation_steps=16 \
    --gradient_checkpointing \
    --max_train_steps=100000 \
    --learning_rate=5e-05 \
    --max_grad_norm=1 \
    --lr_warmup_steps=0 \
    --output_dir="output" \
    --val_out_dir='val' \
    --testing_speed \
    --checkpointing_steps=2000 \
    --validation_steps=2000 \
    --reconstruction \
    --resume_from_checkpoint="latest" \
    --use_audio_length_left=2 \
    --use_audio_length_right=2 \
    --whisper_model_type="tiny" \
    --train_json="../train.json" \
    --val_json="../val.json" \
    --lr_scheduler="cosine"
+129
View File
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import torch
import torch.nn as nn
import time
import math
from utils.utils import decode_latents, preprocess_img_tensor
import os
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
)
from torch import Tensor, nn
import logging
import json
RESIZED_IMG = 256
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding as used in the Transformer.

    A table of shape ``(1, max_len, d_model)`` is precomputed and registered as
    the non-trainable buffer ``pe``; ``forward`` adds its first ``seq_len``
    rows to the input.

    Args:
        d_model: size of the last (feature) dimension of the inputs.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model=384, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        # Even feature indices get the sine, odd indices the cosine.
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer odd index than even index, so
        # trim the cosine columns; for an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        # A buffer follows the module across .to(device/dtype) without being trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the positional encoding of its first ``seq_len`` positions.

        Args:
            x: a 3-D tensor; it is unpacked as ``(b, seq_len, d_model)``.

        Returns:
            The sum of ``x`` and the matching slice of the encoding table.
        """
        # Unpacking into three names also rejects inputs that are not 3-D.
        _, seq_len, _ = x.size()
        pe = self.pe[:, :seq_len, :]
        x = x + pe.to(x.device)
        return x
def validation(vae: torch.nn.Module,
               vae_fp32: torch.nn.Module,
               unet: torch.nn.Module,
               unet_config,
               weight_dtype: torch.dtype,
               epoch: int,
               global_step: int,
               val_data_loader,
               output_dir,
               whisper_model_type,
               UNet2DConditionModel=UNet2DConditionModel
               ):
    """Run the UNet over the validation loader and save preview images.

    For every batch a four-tile strip is written to
    ``images/{output_dir}/{global_step}/``. From left to right the tiles are
    the decoded masked latents, reference latents, ground-truth latents and
    the UNet prediction. ``decode_latents`` returns only the first image of a
    batch, so one strip is saved per batch.

    Args:
        vae: VAE used to encode the images; its device is used for all inputs.
        vae_fp32: VAE handed to ``decode_latents`` for decoding the tiles.
        unet: Model whose ``state_dict`` is copied into a fresh evaluation copy.
        unet_config: Mapping expanded as keyword arguments of
            ``UNet2DConditionModel``; its ``'in_channels'`` entry decides
            whether the mask is concatenated to the model input.
        weight_dtype: dtype used for the UNet copy and the VAE inputs.
        epoch: Only used in the final log line.
        global_step: Used in the output directory and file names.
        val_data_loader: Iterable yielding
            ``(ref_image, image, masked_image, masks, audio_feature)``.
        output_dir: Sub-directory name under ``images/``.
        whisper_model_type: One of ``"tiny"``, ``"largeV2"``, ``"tiny-conv"``.
        UNet2DConditionModel: Class used to build the evaluation copy.

    Raises:
        ValueError: If ``whisper_model_type`` is not one of the supported
            values. Previously this only printed a message and then failed
            on an unbound local variable.
    """
    # Reject an unknown model type up front, before any model is copied.
    pe_dims = {"tiny": 384, "largeV2": 1280, "tiny-conv": 384}
    if whisper_model_type not in pe_dims:
        raise ValueError(f"not support whisper_model_type {whisper_model_type}")

    # Get the validation pipeline
    unet_copy = UNet2DConditionModel(**unet_config)
    unet_copy.load_state_dict(unet.state_dict())
    unet_copy.to(vae.device).to(dtype=weight_dtype)
    unet_copy.eval()

    pe = PositionalEncoding(d_model=pe_dims[whisper_model_type])
    if whisper_model_type == "tiny-conv":
        print(f" whisper_model_type: {whisper_model_type} Validation does not need PE")
    # NOTE(review): pe is moved to the device but is never applied to
    # audio_feature in the loop below -- confirm that this is intended.
    pe.to(vae.device, dtype=weight_dtype)

    val_img_dir = f"images/{output_dir}/{global_step}"
    start = time.time()
    with torch.no_grad():
        for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(val_data_loader):
            masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
            ref_image = preprocess_img_tensor(ref_image).to(vae.device)
            image = preprocess_img_tensor(image).to(vae.device)
            masked_image = preprocess_img_tensor(masked_image).to(vae.device)

            # Convert images to latent space
            latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample()  # init image
            latents = latents * vae.config.scaling_factor

            # Convert masked images to latent space
            masked_latents = vae.encode(
                masked_image.reshape(image.shape).to(dtype=weight_dtype)  # masked image
            ).latent_dist.sample()
            masked_latents = masked_latents * vae.config.scaling_factor

            # Convert ref images to latent space
            ref_latents = vae.encode(
                ref_image.reshape(image.shape).to(dtype=weight_dtype)  # ref image
            ).latent_dist.sample()
            ref_latents = ref_latents * vae.config.scaling_factor

            # Shrink each mask by a factor of 8; the last dimension is used
            # for both output sides, so masks are treated as square.
            mask = torch.stack(
                [
                    torch.nn.functional.interpolate(m, size=(m.shape[-1] // 8, m.shape[-1] // 8))
                    for m in masks
                ]
            )
            mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])

            timesteps = torch.tensor([0], device=latents.device)

            if unet_config['in_channels'] == 9:
                latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
            else:
                latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)

            image_pred = unet_copy(latent_model_input, timesteps, encoder_hidden_states=audio_feature).sample

            # Side-by-side strip: masked | reference | ground truth | prediction.
            canvas = Image.new('RGB', (RESIZED_IMG * 4, RESIZED_IMG))
            canvas.paste(decode_latents(vae_fp32, masked_latents), (0, 0))
            canvas.paste(decode_latents(vae_fp32, ref_latents), (RESIZED_IMG, 0))
            canvas.paste(decode_latents(vae_fp32, latents), (RESIZED_IMG * 2, 0))
            canvas.paste(decode_latents(vae_fp32, image_pred), (RESIZED_IMG * 3, 0))

            # exist_ok avoids a FileExistsError when several processes reach
            # this point at once (the exists()/makedirs() pair could race).
            os.makedirs(val_img_dir, exist_ok=True)
            canvas.save('{0}/val_epoch_{1}_{2}_image.png'.format(val_img_dir, global_step, step))

            print("valtion in step:{0}, time:{1}".format(step, time.time() - start))

        print("valtion_done in epoch:{0}, time:{1}".format(epoch, time.time() - start))
+74
View File
@@ -0,0 +1,74 @@
import matplotlib.pyplot as plt
import PIL
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange
import torch
import torchvision.transforms as transforms
from diffusers import AutoencoderKL
import matplotlib.pyplot as plt
import PIL
import os
import cv2
from glob import glob
def preprocess_img_tensor(image_tensor):
    """Resize a batch to multiples of 32 and normalize it with mean/std 0.5.

    Height and width are rounded down to the nearest multiple of 32, the
    batch is resampled bilinearly to that size, and every value ``v`` is
    mapped to ``(v - 0.5) / 0.5``.

    Args:
        image_tensor: Floating-point tensor of shape ``(N, C, H, W)``.

    Returns:
        A new float32 CPU tensor of shape ``(N, C, H - H % 32, W - W % 32)``,
        regardless of the input's device or dtype.
    """
    N, C, H, W = image_tensor.shape
    # Round the spatial size down to a multiple of 32.
    new_w = W - W % 32
    new_h = H - H % 32

    if N == 0:
        # Nothing to resample; keep the "empty batch in, empty batch out" contract.
        return torch.empty((0, C, new_h, new_w), dtype=torch.float32)

    # One batched call replaces the per-sample loop; bilinear interpolation
    # handles each sample independently, so the values are unchanged.
    resized = F.interpolate(
        image_tensor, size=(new_h, new_w), mode='bilinear', align_corners=False
    )
    # Same arithmetic as Normalize(mean=0.5, std=0.5), written out so it is
    # not tied to exactly three channels.
    normalized = (resized - 0.5) / 0.5
    return normalized.to(device="cpu", dtype=torch.float32)
def prepare_mask_and_masked_image(image_tensor, mask_tensor):
    """Binarize a mask in place and apply it to an image.

    Mask values below 0.5 become 0 and values of 0.5 or more become 1. The
    image is multiplied by the resulting mask, so it is kept where the mask
    is 1 and zeroed elsewhere. The image is not rescaled or normalized.

    Args:
        image_tensor: Image tensor; assumed to be ``[C, H, W]`` -- TODO confirm
            against callers.
        mask_tensor: Mask tensor broadcastable against ``image_tensor``;
            assumed to be ``[H, W]`` -- TODO confirm. It is modified in place.

    Returns:
        ``(mask_tensor, masked_image_tensor)``, where ``mask_tensor`` is the
        same object that was passed in, now binarized.
    """
    # The original also computed a normalized copy of the image that was
    # never used; that dead computation has been removed.
    mask_tensor[mask_tensor < 0.5] = 0
    mask_tensor[mask_tensor >= 0.5] = 1

    # Build the masked image from the binarized mask.
    masked_image_tensor = image_tensor * (mask_tensor > 0.5)
    return mask_tensor, masked_image_tensor
def encode_latents(vae, image):
    """Encode ``image`` with ``vae`` and return a scaled latent sample.

    The image is cast to ``vae.dtype``, one sample is drawn from the
    encoder's latent distribution, and the sample is multiplied by the fixed
    factor 0.18215.
    """
    posterior = vae.encode(image.to(vae.dtype)).latent_dist
    sampled = posterior.sample()
    return 0.18215 * sampled
def decode_latents(vae, latents, ref_images=None):
    """Decode latents with ``vae`` and return the first image of the batch.

    The latents are divided by 0.18215 and decoded; the result is mapped
    from [-1, 1] to [0, 1], clamped, and converted to a uint8 array with
    channels last.

    Args:
        vae: Model providing ``decode(...).sample`` and a ``dtype`` attribute.
        latents: Batch of latents to decode.
        ref_images: Optional batch of images laid out like the decoder
            output. When given, the top half (by height) of every decoded
            image is replaced by the top half of the matching reference
            image, scaled by 255 -- values are presumably expected in
            [0, 1]; confirm against callers.

    Returns:
        A ``PIL.Image`` built from the first batch element only.
    """
    latents = (1 / 0.18215) * latents
    decoded = vae.decode(latents.to(vae.dtype)).sample
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    frames = decoded.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    frames = (frames * 255).round().astype("uint8")

    if ref_images is not None:
        ref = ref_images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        ref = (ref * 255).round().astype("uint8")
        half = frames.shape[1] // 2
        frames[:, :half] = ref[:, :half]

    # Only the first image is returned, so convert just that one rather than
    # building a PIL image for every batch element.
    return Image.fromarray(frames[0])