Compare commits
14 Commits
windows
...
train_codes
| Author | SHA1 | Date | |
|---|---|---|---|
| 4724ed8a78 | |||
| b3c30c3561 | |||
| 4f69a9cfdd | |||
| 5bd772d7da | |||
| 98f0e6f2b1 | |||
| 1de8261491 | |||
| b968548131 | |||
| af82f3b00f | |||
| d74c4c098b | |||
| b4a592d7f3 | |||
| 6d19f3c0c8 | |||
| 7254ca6306 | |||
| 30dcd5237f | |||
| d73daf1808 |
@@ -29,6 +29,7 @@ We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+
|
||||
- [04/02/2024] Release MuseTalk project and pretrained models.
|
||||
- [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant)
|
||||
- [04/17/2024] :mega: We release a pipeline that utilizes MuseTalk for real-time inference.
|
||||
- [04/30/2024] We release an initial version of training codes in `train_codes`.
|
||||
|
||||
## Model
|
||||

|
||||
@@ -165,7 +166,7 @@ Note that although we use a very similar architecture as Stable Diffusion, MuseT
|
||||
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
|
||||
- [x] codes for real-time inference.
|
||||
- [ ] technical report.
|
||||
- [ ] training codes.
|
||||
- [x] training codes.
|
||||
- [ ] a better model (may take longer).
|
||||
|
||||
|
||||
|
||||
Executable
+77
@@ -0,0 +1,77 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Function to extract video and audio sections
|
||||
extract_sections() {
|
||||
input_video=$1
|
||||
base_name=$(basename "$input_video" .mp4)
|
||||
output_dir=$2
|
||||
split=$3
|
||||
duration=$(ffmpeg -i "$input_video" 2>&1 | grep Duration | awk '{print $2}' | tr -d ,)
|
||||
IFS=: read -r hours minutes seconds <<< "$duration"
|
||||
total_seconds=$((10#${hours}*3600 + 10#${minutes}*60 + 10#${seconds%.*}))
|
||||
chunk_size=180 # 3 minutes in seconds
|
||||
index=0
|
||||
|
||||
mkdir -p "$output_dir"
|
||||
|
||||
while [ $((index * chunk_size)) -lt $total_seconds ]; do
|
||||
start_time=$((index * chunk_size))
|
||||
section_video="${output_dir}/${base_name}_part${index}.mp4"
|
||||
section_audio="${output_dir}/${base_name}_part${index}.mp3"
|
||||
|
||||
ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -c copy "$section_video"
|
||||
ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -q:a 0 -map a "$section_audio"
|
||||
|
||||
# Create and update the config.yaml file
|
||||
echo "task_0:" > config.yaml
|
||||
echo " video_path: \"$section_video\"" >> config.yaml
|
||||
echo " audio_path: \"$section_audio\"" >> config.yaml
|
||||
|
||||
# Run the Python script with the current config.yaml
|
||||
python -m scripts.data --inference_config config.yaml --folder_name "$base_name"
|
||||
|
||||
index=$((index + 1))
|
||||
done
|
||||
|
||||
# Clean up save folder
|
||||
rm -rf $output_dir
|
||||
}
|
||||
|
||||
# Main script
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: $0 <train/test> <output_directory> <input_videos...>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
split=$1
|
||||
output_dir=$2
|
||||
shift 2
|
||||
input_videos=("$@")
|
||||
|
||||
# Initialize JSON array
|
||||
json_array="["
|
||||
|
||||
for input_video in "${input_videos[@]}"; do
|
||||
base_name=$(basename "$input_video" .mp4)
|
||||
|
||||
# Extract sections and run the Python script for each section
|
||||
extract_sections "$input_video" "$output_dir" "$split"
|
||||
|
||||
# Add entry to JSON array
|
||||
json_array+="\"../data/images/$base_name\","
|
||||
done
|
||||
|
||||
# Remove trailing comma and close JSON array
|
||||
json_array="${json_array%,}]"
|
||||
|
||||
# Write JSON array to the correct file
|
||||
if [ "$split" == "train" ]; then
|
||||
echo "$json_array" > train.json
|
||||
elif [ "$split" == "test" ]; then
|
||||
echo "$json_array" > test.json
|
||||
else
|
||||
echo "Invalid split: $split. Must be 'train' or 'test'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Processing complete."
|
||||
+257
@@ -0,0 +1,257 @@
|
||||
import cv2
|
||||
import os
|
||||
# import dlib
|
||||
import argparse
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import glob
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
import uuid
|
||||
|
||||
from musetalk.utils.utils import get_file_type,get_video_fps
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
import gc
|
||||
|
||||
# load model weights
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
def get_largest_integer_filename(folder_path):
|
||||
# Check if the folder exists
|
||||
if not os.path.isdir(folder_path):
|
||||
return -1
|
||||
|
||||
# Get the list of files in the folder
|
||||
files = os.listdir(folder_path)
|
||||
|
||||
# Check if the folder is empty
|
||||
if not files:
|
||||
return -1
|
||||
|
||||
# Extract the integer part of filenames and find the largest
|
||||
largest_integer = -1
|
||||
for file in files:
|
||||
try:
|
||||
# Get the integer part of the filename
|
||||
file_int = int(os.path.splitext(file)[0])
|
||||
if file_int > largest_integer:
|
||||
largest_integer = file_int
|
||||
except ValueError:
|
||||
# Skip files that don't have an integer filename
|
||||
continue
|
||||
|
||||
return largest_integer
|
||||
|
||||
def datagen(whisper_chunks,
|
||||
crop_images,
|
||||
batch_size=8,
|
||||
delay_frame=0):
|
||||
whisper_batch, crop_batch = [], []
|
||||
for i, w in enumerate(whisper_chunks):
|
||||
idx = (i+delay_frame)%len(crop_images)
|
||||
crop_image = crop_images[idx]
|
||||
whisper_batch.append(w)
|
||||
crop_batch.append(crop_image)
|
||||
|
||||
if len(crop_batch) >= batch_size:
|
||||
whisper_batch = np.stack(whisper_batch)
|
||||
# latent_batch = torch.cat(latent_batch, dim=0)
|
||||
yield whisper_batch, crop_batch
|
||||
whisper_batch, crop_batch = [], []
|
||||
|
||||
# the last batch may smaller than batch size
|
||||
if len(crop_batch) > 0:
|
||||
whisper_batch = np.stack(whisper_batch)
|
||||
# latent_batch = torch.cat(latent_batch, dim=0)
|
||||
|
||||
yield whisper_batch, crop_batch
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
global pe
|
||||
if args.use_float16 is True:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
total_audio_index=get_largest_integer_filename(f"data/audios/{args.folder_name}")
|
||||
total_image_index=get_largest_integer_filename(f"data/images/{args.folder_name}")
|
||||
temp_audio_index=total_audio_index
|
||||
temp_image_index=total_image_index
|
||||
for task_id in inference_config:
|
||||
video_path = inference_config[task_id]["video_path"]
|
||||
audio_path = inference_config[task_id]["audio_path"]
|
||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
|
||||
folder_name = args.folder_name
|
||||
if not os.path.exists(f"data/images/{folder_name}/"):
|
||||
os.makedirs(f"data/images/{folder_name}")
|
||||
if not os.path.exists(f"data/audios/{folder_name}/"):
|
||||
os.makedirs(f"data/audios/{folder_name}")
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
|
||||
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
|
||||
os.makedirs(result_img_save_path,exist_ok =True)
|
||||
|
||||
if args.output_vid_name is None:
|
||||
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
|
||||
############################################## extract frames from source video ##############################################
|
||||
if get_file_type(video_path)=="video":
|
||||
save_dir_full = os.path.join(args.result_dir, input_basename)
|
||||
os.makedirs(save_dir_full,exist_ok = True)
|
||||
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
||||
os.system(cmd)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
elif get_file_type(video_path)=="image":
|
||||
input_img_list = [video_path, ]
|
||||
fps = args.fps
|
||||
elif os.path.isdir(video_path): # input img folder
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
############################################## extract audio feature ##############################################
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
for __ in range(0, len(whisper_feature) - 1, 2): # -1 to avoid index error if the list has an odd number of elements
|
||||
# Combine two consecutive chunks
|
||||
# pair_of_chunks = np.array([whisper_feature[__], whisper_feature[__+1]])
|
||||
concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__+1]], axis=0)
|
||||
# Save the pair to a .npy file
|
||||
np.save(f'data/audios/{folder_name}/{total_audio_index+(__//2)+1}.npy', concatenated_chunks)
|
||||
temp_audio_index=(__//2)+total_audio_index+1
|
||||
total_audio_index=temp_audio_index
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
|
||||
############################################## preprocess input image ##############################################
|
||||
gc.collect()
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("using extracted coordinates")
|
||||
with open(crop_coord_save_path,'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("extracting landmarks...time consuming")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
crop_i=0
|
||||
crop_data=[]
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
|
||||
x1=max(0,x1)
|
||||
y1=max(0,y1)
|
||||
x2=max(0,x2)
|
||||
y2=max(0,y2)
|
||||
|
||||
if ((y2-y1)<=0) or ((x2-x1)<=0):
|
||||
continue
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
crop_data.append(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
crop_i+=1
|
||||
|
||||
# to smooth the first and the last frame
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
crop_data = crop_data + crop_data[::-1]
|
||||
############################################## inference batch by batch ##############################################
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(whisper_chunks,crop_data,batch_size)
|
||||
crop_index = 0
|
||||
for i, (whisper_batch,crop_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
for image,audio in zip(crop_batch,whisper_batch):
|
||||
cv2.imwrite(f"data/images/{folder_name}/{str(crop_index+total_image_index+1)}.png",image)
|
||||
temp_image_index = crop_index + total_image_index + 1
|
||||
crop_index += 1
|
||||
|
||||
# np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio)
|
||||
total_image_index=temp_image_index
|
||||
gc.collect()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0)
|
||||
parser.add_argument("--result_dir", default='./results', help="path to output")
|
||||
parser.add_argument("--folder_name", default=f'{uuid.uuid4()}', help="path to output")
|
||||
|
||||
parser.add_argument("--fps", type=int, default=25)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--output_vid_name", type=str, default=None)
|
||||
parser.add_argument("--use_saved_coord",
|
||||
action="store_true",
|
||||
help='use saved coordinate to save time')
|
||||
parser.add_argument("--use_float16",
|
||||
action="store_true",
|
||||
help="Whether use float16 to speed up inference",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
|
||||
def process_audio(audio_path):
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
np.save('audio/your_filename.npy', whisper_feature)
|
||||
|
||||
def mask_face(image):
|
||||
# Load dlib's face detector and the landmark predictor
|
||||
detector = dlib.get_frontal_face_detector()
|
||||
predictor_path = "/content/shape_predictor_68_face_landmarks.dat" # Set path to your downloaded predictor file
|
||||
predictor = dlib.shape_predictor(predictor_path)
|
||||
|
||||
# Load your input image
|
||||
# image_path = "/content/ori_frame_00000077.png" # Replace with the path to your input image
|
||||
# image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise ValueError("Image not found or unable to load.")
|
||||
|
||||
# Convert to grayscale for detection
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Detect faces in the image
|
||||
faces = detector(gray)
|
||||
|
||||
# Process each detected face
|
||||
for face in faces:
|
||||
# Predict landmarks
|
||||
landmarks = predictor(gray, face)
|
||||
|
||||
# The indices of nose landmarks are 27 to 35
|
||||
nose_tip = landmarks.part(33).y
|
||||
|
||||
# Blacken the region below the nose tip
|
||||
blacken_area = image[nose_tip:, :]
|
||||
blacken_area[:] = (0, 0, 0)
|
||||
|
||||
# Save the final image or display it
|
||||
# cv2.imwrite("output_image.jpg", image)
|
||||
return image
|
||||
@@ -0,0 +1,182 @@
|
||||
import argparse
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import glob
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
|
||||
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
# load model weights
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
accelerator = Accelerator(
|
||||
mixed_precision="fp16",
|
||||
)
|
||||
unet = accelerator.prepare(
|
||||
unet,
|
||||
|
||||
)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
global pe
|
||||
if not (args.unet_checkpoint == None):
|
||||
print("unet ckpt loaded")
|
||||
accelerator.load_state(args.unet_checkpoint)
|
||||
|
||||
if args.use_float16 is True:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print(inference_config)
|
||||
for task_id in inference_config:
|
||||
video_path = inference_config[task_id]["video_path"]
|
||||
audio_path = inference_config[task_id]["audio_path"]
|
||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
|
||||
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
|
||||
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
|
||||
os.makedirs(result_img_save_path,exist_ok =True)
|
||||
|
||||
if args.output_vid_name is None:
|
||||
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
|
||||
############################################## extract frames from source video ##############################################
|
||||
if get_file_type(video_path)=="video":
|
||||
save_dir_full = os.path.join(args.result_dir, input_basename)
|
||||
os.makedirs(save_dir_full,exist_ok = True)
|
||||
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
||||
os.system(cmd)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
elif get_file_type(video_path)=="image":
|
||||
input_img_list = [video_path, ]
|
||||
fps = args.fps
|
||||
elif os.path.isdir(video_path): # input img folder
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
############################################## extract audio feature ##############################################
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
############################################## preprocess input image ##############################################
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("using extracted coordinates")
|
||||
with open(crop_coord_save_path,'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("extracting landmarks...time consuming")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
crop_i=0
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
crop_i+=1
|
||||
|
||||
# to smooth the first and the last frame
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
############################################## inference batch by batch ##############################################
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
||||
res_frame_list = []
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
||||
dtype=unet.model.dtype) # torch, B, 5*N,384
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
############################################## pad to full image ##############################################
|
||||
print("pad talking image to original video")
|
||||
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
except:
|
||||
continue
|
||||
|
||||
combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||
cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
|
||||
cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
|
||||
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
|
||||
os.system(cmd_img2video)
|
||||
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
os.system(cmd_combine_audio)
|
||||
|
||||
os.remove("temp.mp4")
|
||||
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
|
||||
os.system(cmd_img2video)
|
||||
|
||||
# cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
# print(cmd_combine_audio)
|
||||
# os.system(cmd_combine_audio)
|
||||
|
||||
# shutil.rmtree(result_img_save_path)
|
||||
print(f"result is save to {output_vid_name}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0)
|
||||
parser.add_argument("--result_dir", default='./results', help="path to output")
|
||||
|
||||
parser.add_argument("--fps", type=int, default=25)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--output_vid_name", type=str, default=None)
|
||||
parser.add_argument("--use_saved_coord",
|
||||
action="store_true",
|
||||
help='use saved coordinate to save time')
|
||||
parser.add_argument("--use_float16",
|
||||
action="store_true",
|
||||
help="Whether use float16 to speed up inference",
|
||||
)
|
||||
parser.add_argument("--unet_checkpoint", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,235 @@
|
||||
import os, random, cv2, argparse
|
||||
import torch
|
||||
from torch.utils import data as data_utils
|
||||
from os.path import dirname, join, basename, isfile
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from utils.utils import prepare_mask_and_masked_image
|
||||
import torchvision.utils as vutils
|
||||
import torchvision.transforms as transforms
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import heapq
|
||||
|
||||
syncnet_T = 1
|
||||
RESIZED_IMG = 256
|
||||
|
||||
connections = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7),(7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13),(13,14),(14,15),(15,16), # 下颌线
|
||||
(17, 18), (18, 19), (19, 20), (20, 21), #左眉毛
|
||||
(22, 23), (23, 24), (24, 25), (25, 26), #右眉毛
|
||||
(27, 28),(28,29),(29,30),# 鼻梁
|
||||
(31,32),(32,33),(33,34),(34,35), #鼻子
|
||||
(36,37),(37,38),(38, 39), (39, 40), (40, 41), (41, 36), # 左眼
|
||||
(42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 42), # 右眼
|
||||
(48, 49),(49, 50), (50, 51),(51, 52),(52, 53), (53, 54), # 上嘴唇 外延
|
||||
(54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 48), # 下嘴唇 外延
|
||||
(60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 60) #嘴唇内圈
|
||||
]
|
||||
|
||||
|
||||
def get_image_list(data_root, split):
|
||||
filelist = []
|
||||
imgNumList = []
|
||||
with open('filelists/{}.txt'.format(split)) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if ' ' in line:
|
||||
filename = line.split()[0]
|
||||
imgNum = int(line.split()[1])
|
||||
filelist.append(os.path.join(data_root, filename))
|
||||
imgNumList.append(imgNum)
|
||||
return filelist, imgNumList
|
||||
|
||||
|
||||
|
||||
class Dataset(object):
|
||||
def __init__(self,
|
||||
data_root,
|
||||
json_path,
|
||||
use_audio_length_left=1,
|
||||
use_audio_length_right=1,
|
||||
whisper_model_type = "tiny"
|
||||
):
|
||||
# self.all_videos, self.all_imgNum = get_image_list(data_root, split)
|
||||
self.audio_feature = [use_audio_length_left,use_audio_length_right]
|
||||
self.all_img_names = []
|
||||
# self.split = split
|
||||
self.img_names_path = '../data'
|
||||
self.whisper_model_type = whisper_model_type
|
||||
self.use_audio_length_left = use_audio_length_left
|
||||
self.use_audio_length_right = use_audio_length_right
|
||||
|
||||
if self.whisper_model_type =="tiny":
|
||||
self.whisper_path = '../data/audios'
|
||||
self.whisper_feature_W = 5
|
||||
self.whisper_feature_H = 384
|
||||
elif self.whisper_model_type =="largeV2":
|
||||
self.whisper_path = '...'
|
||||
self.whisper_feature_W = 33
|
||||
self.whisper_feature_H = 1280
|
||||
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*(2+2+1)= 50
|
||||
with open(json_path, 'r') as file:
|
||||
self.all_videos = json.load(file)
|
||||
|
||||
for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
|
||||
json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
|
||||
if not os.path.exists(json_path_names):
|
||||
img_names = glob(join(vidname, '*.png'))
|
||||
img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
|
||||
with open(json_path_names, "w") as f:
|
||||
json.dump(img_names,f)
|
||||
else:
|
||||
with open(json_path_names, "r") as f:
|
||||
img_names = json.load(f)
|
||||
self.all_img_names.append(img_names)
|
||||
|
||||
def get_frame_id(self, frame):
|
||||
return int(basename(frame).split('.')[0])
|
||||
|
||||
def get_window(self, start_frame):
|
||||
start_id = self.get_frame_id(start_frame)
|
||||
vidname = dirname(start_frame)
|
||||
|
||||
window_fnames = []
|
||||
for frame_id in range(start_id, start_id + syncnet_T):
|
||||
frame = join(vidname, '{}.png'.format(frame_id))
|
||||
if not isfile(frame):
|
||||
return None
|
||||
window_fnames.append(frame)
|
||||
return window_fnames
|
||||
|
||||
def read_window(self, window_fnames):
|
||||
if window_fnames is None: return None
|
||||
window = []
|
||||
for fname in window_fnames:
|
||||
img = cv2.imread(fname)
|
||||
if img is None:
|
||||
return None
|
||||
try:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(img, (RESIZED_IMG, RESIZED_IMG))
|
||||
except Exception as e:
|
||||
print("read_window has error fname not exist:",fname)
|
||||
return None
|
||||
|
||||
window.append(img)
|
||||
|
||||
return window
|
||||
|
||||
def prepare_window(self, window):
|
||||
# 1 x H x W x 3
|
||||
x = np.asarray(window) / 255.
|
||||
x = np.transpose(x, (3, 0, 1, 2))
|
||||
|
||||
return x
|
||||
|
||||
def __len__(self):
|
||||
return len(self.all_videos)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
while 1:
|
||||
idx = random.randint(0, len(self.all_videos) - 1)
|
||||
#随机选择某个video里
|
||||
vidname = self.all_videos[idx].split('/')[-1]
|
||||
video_imgs = self.all_img_names[idx]
|
||||
if len(video_imgs) <= 11:
|
||||
continue
|
||||
img_name = random.choice(video_imgs)
|
||||
img_idx = int(basename(img_name).split(".")[0])
|
||||
random_element = random.randint(0,len(video_imgs)-1)
|
||||
while abs(random_element - img_idx) <= 5:
|
||||
random_element = random.randint(0,len(video_imgs)-1)
|
||||
img_dir = os.path.dirname(img_name)
|
||||
ref_image = os.path.join(img_dir, f"{str(random_element)}.png")
|
||||
target_window_fnames = self.get_window(img_name)
|
||||
ref_window_fnames = self.get_window(ref_image)
|
||||
|
||||
if target_window_fnames is None or ref_window_fnames is None:
|
||||
print("No such img",img_name, ref_image)
|
||||
continue
|
||||
|
||||
try:
|
||||
#构建目标img数据
|
||||
target_window = self.read_window(target_window_fnames)
|
||||
if target_window is None :
|
||||
print("No such target window,",target_window_fnames)
|
||||
continue
|
||||
#构建参考img数据
|
||||
ref_window = self.read_window(ref_window_fnames)
|
||||
|
||||
if ref_window is None:
|
||||
print("No such target ref window,",ref_window)
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"发生未知错误:{e}")
|
||||
continue
|
||||
|
||||
#构建target输入
|
||||
target_window = self.prepare_window(target_window)
|
||||
image = gt = target_window.copy().squeeze()
|
||||
target_window[:, :, target_window.shape[2]//2:] = 0. # upper half face, mask掉下半部分 V1:输入
|
||||
ref_image = self.prepare_window(ref_window).squeeze()
|
||||
|
||||
|
||||
|
||||
mask = torch.zeros((ref_image.shape[1], ref_image.shape[2]))
|
||||
mask[:ref_image.shape[2]//2,:] = 1
|
||||
image = torch.FloatTensor(image)
|
||||
mask, masked_image = prepare_mask_and_masked_image(image,mask)
|
||||
|
||||
|
||||
|
||||
#音频特征
|
||||
window_index = self.get_frame_id(img_name)
|
||||
sub_folder_name = vidname.split('/')[-1]
|
||||
|
||||
## 根据window_index加载相邻的音频
|
||||
audio_feature_all = []
|
||||
is_index_out_of_range = False
|
||||
if os.path.isdir(os.path.join(self.whisper_path, sub_folder_name)):
|
||||
for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
|
||||
# 判定是否越界
|
||||
audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
|
||||
if not os.path.exists(audio_feat_path):
|
||||
is_index_out_of_range = True
|
||||
break
|
||||
|
||||
try:
|
||||
audio_feature_all.append(np.load(audio_feat_path))
|
||||
except Exception as e:
|
||||
print(f"发生未知错误:{e}")
|
||||
print(f"npy load error {audio_feat_path}")
|
||||
if is_index_out_of_range:
|
||||
continue
|
||||
audio_feature = np.concatenate(audio_feature_all, axis=0)
|
||||
else:
|
||||
continue
|
||||
|
||||
audio_feature = audio_feature.reshape(1, -1, self.whisper_feature_H) #1, -1, 384
|
||||
if audio_feature.shape != (1,self.whisper_feature_concateW, self.whisper_feature_H): #1 50 384
|
||||
print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
|
||||
continue
|
||||
audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
|
||||
return ref_image, image, masked_image, mask, audio_feature
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_root = '...'
|
||||
val_data = Dataset(data_root,
|
||||
'val',
|
||||
use_audio_length_left = 2,
|
||||
use_audio_length_right = 2,
|
||||
whisper_model_type = "tiny"
|
||||
)
|
||||
val_data_loader = data_utils.DataLoader(
|
||||
val_data, batch_size=4, shuffle=True,
|
||||
num_workers=1)
|
||||
|
||||
for i, data in enumerate(val_data_loader):
|
||||
ref_image, image, masked_image, mask, audio_feature = data
|
||||
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
# Data preprocessing
|
||||
|
||||
Create two config yaml files, one for training and other for testing (both in same format as configs/inference/test.yaml)
|
||||
The train yaml file should contain the training video paths and corresponding audio paths
|
||||
The test yaml file should contain the validation video paths and corresponding audio paths
|
||||
|
||||
Run:
|
||||
```
|
||||
./data_new.sh train output train_video1.mp4 train_video2.mp4
|
||||
./data_new.sh test output test_video1.mp4 test_video2.mp4
|
||||
```
|
||||
This creates folders which contain the image frames and npy files. This also creates train.json and val.json which can be used during the training.
|
||||
|
||||
## Data organization
|
||||
```
|
||||
./data/
|
||||
├── images
|
||||
│ └──RD_Radio10_000
|
||||
│ └── 0.png
|
||||
│ └── 1.png
|
||||
│ └── xxx.png
|
||||
│ └──RD_Radio11_000
|
||||
│ └── 0.png
|
||||
│ └── 1.png
|
||||
│ └── xxx.png
|
||||
├── audios
|
||||
│ └──RD_Radio10_000
|
||||
│ └── 0.npy
|
||||
│ └── 1.npy
|
||||
│ └── xxx.npy
|
||||
│ └──RD_Radio11_000
|
||||
│ └── 0.npy
|
||||
│ └── 1.npy
|
||||
│ └── xxx.npy
|
||||
```
|
||||
|
||||
## Training
|
||||
Simply run after preparing the preprocessed data
|
||||
```
|
||||
cd train_codes
|
||||
sh train.sh #--train_json="../train.json" \(Generated in Data preprocessing step.)
|
||||
#--val_json="../val.json" \
|
||||
```
|
||||
## Inference with trained checkpoit
|
||||
Simply run after training the model, the model checkpoints are saved at train_codes/output usually
|
||||
```
|
||||
python -m scripts.finetuned_inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
|
||||
```
|
||||
|
||||
## TODO
|
||||
- [x] release data preprocessing codes
|
||||
- [ ] release some novel designs in training (after technical report)
|
||||
@@ -0,0 +1,322 @@
|
||||
RD_Radio10_000 3501
|
||||
RD_Radio11_000 752
|
||||
RD_Radio11_001 1502
|
||||
RD_Radio12_000 876
|
||||
RD_Radio13_000 877
|
||||
RD_Radio14_000 1251
|
||||
RD_Radio16_000 2377
|
||||
RD_Radio17_000 2176
|
||||
RD_Radio18_000 1827
|
||||
RD_Radio19_000 1876
|
||||
RD_Radio1_000 1876
|
||||
RD_Radio20_000 276
|
||||
RD_Radio21_000 1201
|
||||
RD_Radio22_000 752
|
||||
RD_Radio23_000 1602
|
||||
RD_Radio25_000 2126
|
||||
RD_Radio26_000 1052
|
||||
RD_Radio27_000 1376
|
||||
RD_Radio28_000 2252
|
||||
RD_Radio29_000 2252
|
||||
RD_Radio2_000 2076
|
||||
RD_Radio30_000 2877
|
||||
RD_Radio31_000 1377
|
||||
RD_Radio32_000 1502
|
||||
RD_Radio33_000 2252
|
||||
RD_Radio34_000 702
|
||||
RD_Radio34_001 252
|
||||
RD_Radio34_002 501
|
||||
RD_Radio34_003 326
|
||||
RD_Radio34_004 502
|
||||
RD_Radio34_005 502
|
||||
RD_Radio34_006 502
|
||||
RD_Radio34_007 252
|
||||
RD_Radio34_008 876
|
||||
RD_Radio34_009 752
|
||||
RD_Radio35_000 2127
|
||||
RD_Radio36_000 2377
|
||||
RD_Radio37_000 3127
|
||||
RD_Radio38_000 3752
|
||||
RD_Radio39_000 1502
|
||||
RD_Radio3_000 2377
|
||||
RD_Radio40_000 1252
|
||||
RD_Radio41_000 2127
|
||||
RD_Radio42_000 1752
|
||||
RD_Radio43_000 1502
|
||||
RD_Radio44_000 1127
|
||||
RD_Radio45_000 1628
|
||||
RD_Radio46_000 1877
|
||||
RD_Radio47_000 1502
|
||||
RD_Radio48_000 2127
|
||||
RD_Radio49_000 1502
|
||||
RD_Radio4_000 1002
|
||||
RD_Radio50_000 2252
|
||||
RD_Radio51_000 3253
|
||||
RD_Radio52_000 2752
|
||||
RD_Radio53_000 2627
|
||||
RD_Radio54_000 577
|
||||
RD_Radio56_000 1952
|
||||
RD_Radio57_000 3077
|
||||
RD_Radio59_000 1952
|
||||
RD_Radio5_000 1377
|
||||
RD_Radio7_000 2252
|
||||
RD_Radio8_000 2126
|
||||
RD_Radio9_000 1502
|
||||
WDA_AdamSchiff_000 6877
|
||||
WDA_AdamSmith_000 7327
|
||||
WDA_AlexandriaOcasioCortez_000 2252
|
||||
WDA_AmyKlobuchar0_000 6501
|
||||
WDA_AmyKlobuchar1_000 3253
|
||||
WDA_AmyKlobuchar1_001 877
|
||||
WDA_AmyKlobuchar1_002 1502
|
||||
WDA_AmyKlobuchar1_003 1377
|
||||
WDA_AndyKim_000 6952
|
||||
WDA_AndyLevin_000 4876
|
||||
WDA_AnnieKuster_000 4876
|
||||
WDA_BarackObama_000 3575
|
||||
WDA_BarackObama_001 5625
|
||||
WDA_BarbaraLee0_000 6502
|
||||
WDA_BarbaraLee1_000 4702
|
||||
WDA_BenCardin0_000 8127
|
||||
WDA_BenCardin1_000 7102
|
||||
WDA_BenRayLujn_000 7127
|
||||
WDA_BennieThompson1_000 6375
|
||||
WDA_BennieThompson_000 5375
|
||||
WDA_BernieSanders_000 8451
|
||||
WDA_BettyMcCollum_000 7002
|
||||
WDA_BillPascrell_000 7252
|
||||
WDA_BillRichardson_000 3127
|
||||
WDA_BobCasey0_000 10627
|
||||
WDA_BobCasey1_000 252
|
||||
WDA_BobMenendez_000 3002
|
||||
WDA_BobbyScott_000 2002
|
||||
WDA_BradSchneider_000 6627
|
||||
WDA_BrendaLawrence_000 5627
|
||||
WDA_BrianSchatz0_000 502
|
||||
WDA_BrianSchatz1_000 2627
|
||||
WDA_BrianSchatz2_000 1952
|
||||
WDA_ByronDorgan1_000 4277
|
||||
WDA_CarolynMaloney1_000 8377
|
||||
WDA_CatherineCortezMasto_000 2827
|
||||
WDA_CedricRichmond_000 7250
|
||||
WDA_ChrisCoons1_000 4452
|
||||
WDA_ChrisCoons_000 6827
|
||||
WDA_ChrisMurphy0_000 6877
|
||||
WDA_ChrisMurphy1_000 6377
|
||||
WDA_ChrisVanHollen0_000 8202
|
||||
WDA_ChrisVanHollen1_000 7327
|
||||
WDA_ChuckSchumer0_000 5577
|
||||
WDA_ChuckSchumer1_000 4827
|
||||
WDA_ColinAllred_000 4751
|
||||
WDA_DanKildee1_000 6252
|
||||
WDA_DanKildee_000 2502
|
||||
WDA_DavidCicilline_000 5875
|
||||
WDA_DebHaaland_000 5876
|
||||
WDA_DebbieDingell0_000 7752
|
||||
WDA_DebbieDingell1_000 7251
|
||||
WDA_DebbieStabenow0_000 4577
|
||||
WDA_DebbieStabenow1_000 5201
|
||||
WDA_DebbieWassermanSchultz_000 7625
|
||||
WDA_DianaDeGette0_000 6002
|
||||
WDA_DianaDeGette1_000 2202
|
||||
WDA_DianneFeinstein_000 7453
|
||||
WDA_DickDurbin_000 6326
|
||||
WDA_DonaldMcEachin_000 7627
|
||||
WDA_DonnaShalala1_000 7500
|
||||
WDA_DougJones_000 6077
|
||||
WDA_EdMarkey0_000 4752
|
||||
WDA_EdMarkey1_000 6501
|
||||
WDA_ElijahCummings_000 6377
|
||||
WDA_EliotEngel_000 6876
|
||||
WDA_EmanuelCleaver_000 7001
|
||||
WDA_EricSwalwell_000 6377
|
||||
WDA_FrankPallone0_000 5625
|
||||
WDA_FrankPallone1_000 6502
|
||||
WDA_GerryConnolly_000 6752
|
||||
WDA_HakeemJeffries_000 6002
|
||||
WDA_HaleyStevens_000 5001
|
||||
WDA_HenryWaxman_000 1125
|
||||
WDA_HillaryClinton_000 2500
|
||||
WDA_JackReed0_000 2377
|
||||
WDA_JackReed1_000 4877
|
||||
WDA_JackieSpeier_000 4625
|
||||
WDA_JackyRosen1_000 10077
|
||||
WDA_JackyRosen_000 5502
|
||||
WDA_JamesClyburn1_000 6876
|
||||
WDA_JamesClyburn_000 6875
|
||||
WDA_JanSchakowsky0_000 4128
|
||||
WDA_JanSchakowsky1_000 6251
|
||||
WDA_JeanneShaheen0_000 6702
|
||||
WDA_JeanneShaheen1_000 6577
|
||||
WDA_JeffMerkley1_000 6952
|
||||
WDA_JerryNadler_000 8377
|
||||
WDA_JimHimes_000 7000
|
||||
WDA_JimmyGomez_000 4627
|
||||
WDA_JoaquinCastro_000 5126
|
||||
WDA_JoeCrowley0_000 4877
|
||||
WDA_JoeCrowley1_000 1127
|
||||
WDA_JoeCrowley1_001 627
|
||||
WDA_JoeCrowley1_002 502
|
||||
WDA_JoeCrowley1_003 752
|
||||
WDA_JoeDonnelly_000 1377
|
||||
WDA_JoeKennedy_000 1327
|
||||
WDA_JoeManchin_000 3377
|
||||
WDA_JoeNeguse_000 1628
|
||||
WDA_JoeNeguse_001 1126
|
||||
WDA_JoeNeguse_002 1251
|
||||
WDA_JohnLewis0_000 6252
|
||||
WDA_JohnLewis1_000 7252
|
||||
WDA_JohnSarbanes0_000 6377
|
||||
WDA_JohnSarbanes1_000 1877
|
||||
WDA_JohnYarmuth1_000 5377
|
||||
WDA_JonTester0_000 3252
|
||||
WDA_JonTester1_000 5327
|
||||
WDA_KarenBass_000 7126
|
||||
WDA_KatherineClark_000 1552
|
||||
WDA_KathyCastor1_000 6001
|
||||
WDA_KathyCastor_000 2252
|
||||
WDA_KimSchrier_000 6877
|
||||
WDA_KirstenGillibrand_000 9627
|
||||
WDA_LaurenUnderwood_000 9877
|
||||
WDA_LisaBluntRochester_000 4500
|
||||
WDA_LloydDoggett0_000 7377
|
||||
WDA_LloydDoggett1_000 2252
|
||||
WDA_LucilleRoybal-Allard_000 2877
|
||||
WDA_LucyMcBath_000 4000
|
||||
WDA_MarciaFudge_000 7377
|
||||
WDA_MarkWarner1_000 377
|
||||
WDA_MarkWarner1_001 1327
|
||||
WDA_MarkWarner2_000 377
|
||||
WDA_MarkWarner_000 3127
|
||||
WDA_MartinHeinrich_000 5202
|
||||
WDA_MattCartwright_000 5377
|
||||
WDA_MazieHirono0_000 3752
|
||||
WDA_MichelleLujanGrisham_000 6875
|
||||
WDA_MichelleObama_000 2000
|
||||
WDA_MikeDoyle_000 8750
|
||||
WDA_MikeThompson0_000 4625
|
||||
WDA_MikeThompson1_000 1827
|
||||
WDA_NancyPelosi0_000 10251
|
||||
WDA_NancyPelosi1_000 1377
|
||||
WDA_NancyPelosi3_000 1127
|
||||
WDA_NitaLowey_000 5876
|
||||
WDA_NydiaVelzquez_000 5500
|
||||
WDA_PatrickLeahy0_000 6951
|
||||
WDA_PatrickLeahy1_000 9953
|
||||
WDA_PattyMurray0_000 3627
|
||||
WDA_PattyMurray1_000 5702
|
||||
WDA_PeterDeFazio_000 6628
|
||||
WDA_RaulRuiz_000 5752
|
||||
WDA_RichardBlumenthal_000 7827
|
||||
WDA_RichardNeal0_000 6252
|
||||
WDA_RichardNeal1_000 6877
|
||||
WDA_RobinKelly_000 4377
|
||||
WDA_RonWyden0_000 5152
|
||||
WDA_RonWyden1_000 8327
|
||||
WDA_ScottPeters0_000 7002
|
||||
WDA_ScottPeters1_000 3952
|
||||
WDA_SeanCasten_000 6126
|
||||
WDA_SeanPatrickMaloney_000 6502
|
||||
WDA_SheldonWhitehouse0_000 6327
|
||||
WDA_SheldonWhitehouse1_000 5702
|
||||
WDA_SherrodBrown0_000 7452
|
||||
WDA_SherrodBrown1_000 7327
|
||||
WDA_StenyHoyer_000 3502
|
||||
WDA_StephanieMurphy_000 4000
|
||||
WDA_SuzanDelBene_000 6875
|
||||
WDA_TammyBaldwin0_000 5327
|
||||
WDA_TammyBaldwin1_000 2503
|
||||
WDA_TammyDuckworth_000 5702
|
||||
WDA_TedLieu_000 6877
|
||||
WDA_TerriSewell0_000 2127
|
||||
WDA_TerriSewell1_000 10952
|
||||
WDA_TerriSewell_000 6752
|
||||
WDA_TimWalz_000 6628
|
||||
WDA_TinaSmith_000 5576
|
||||
WDA_TomCarper_000 4701
|
||||
WDA_TomPerez_000 5500
|
||||
WDA_TomUdall_000 4077
|
||||
WDA_VeronicaEscobar0_000 4377
|
||||
WDA_VeronicaEscobar1_000 2453
|
||||
WDA_WhipJimClyburn_000 6876
|
||||
WDA_XavierBecerra_000 877
|
||||
WDA_XavierBecerra_001 627
|
||||
WDA_XavierBecerra_002 1377
|
||||
WDA_ZoeLofgren_000 4625
|
||||
WRA_AdamKinzinger0_000 4251
|
||||
WRA_AdamKinzinger1_000 2751
|
||||
WRA_AdamKinzinger2_000 1002
|
||||
WRA_AdamKinzinger2_001 1001
|
||||
WRA_AdamKinzinger2_002 502
|
||||
WRA_AllenWest_000 6876
|
||||
WRA_AnnMarieBuerkle_000 452
|
||||
WRA_AnnWagner_000 3127
|
||||
WRA_AustinScott0_000 1177
|
||||
WRA_BillCassidy0_000 427
|
||||
WRA_BobCorker_000 1127
|
||||
WRA_BobGoodlatte0_000 1001
|
||||
WRA_BobGoodlatte0_001 752
|
||||
WRA_BobGoodlatte0_002 876
|
||||
WRA_BobGoodlatte0_003 1002
|
||||
WRA_BobbySchilling_001 1000
|
||||
WRA_BobbySchilling_002 1001
|
||||
WRA_CandiceMiller0_000 3001
|
||||
WRA_CarlyFiorina0_000 876
|
||||
WRA_CarlyFiorina_000 1902
|
||||
WRA_CathyMcMorrisRodgers0_000 6077
|
||||
WRA_CathyMcMorrisRodgers1_000 7375
|
||||
WRA_CathyMcMorrisRodgers1_001 7500
|
||||
WRA_CathyMcMorrisRodgers1_002 7500
|
||||
WRA_CathyMcMorrisRodgers2_000 2751
|
||||
WRA_ChuckGrassley_000 4502
|
||||
WRA_CoryGardner0_000 8577
|
||||
WRA_CoryGardner1_000 2002
|
||||
WRA_CoryGardner_000 1252
|
||||
WRA_DanNewhouse_000 3627
|
||||
WRA_DanSullivan_000 6877
|
||||
WRA_DaveCamp_000 752
|
||||
WRA_DaveCamp_001 876
|
||||
WRA_DaveCamp_002 627
|
||||
WRA_DavidVitter_000 2877
|
||||
WRA_DeanHeller_000 377
|
||||
WRA_DebFischer0_000 7577
|
||||
WRA_DebFischer1_000 1577
|
||||
WRA_DebFischer2_000 5202
|
||||
WRA_DianeBlack0_000 451
|
||||
WRA_DianeBlack0_001 301
|
||||
WRA_DianeBlack1_000 250
|
||||
WRA_DuncanHunter_000 727
|
||||
WRA_EricCantor_000 2952
|
||||
WRA_ErikPaulsen_000 876
|
||||
WRA_ErikPaulsen_001 377
|
||||
WRA_ErikPaulsen_002 627
|
||||
WRA_ErikPaulsen_003 752
|
||||
WRA_FredUpton_000 1701
|
||||
WRA_GeoffDavis_000 250
|
||||
WRA_GeorgeLeMieux_000 752
|
||||
WRA_GeorgeLeMieux_001 1001
|
||||
WRA_GregWalden1_000 752
|
||||
WRA_GregWalden1_001 1127
|
||||
WRA_GregWalden1_002 1327
|
||||
WRA_GregWalden_000 377
|
||||
WRA_JaimeHerreraBeutler0_000 952
|
||||
WRA_JebHensarling0_001 1176
|
||||
WRA_JebHensarling1_000 2327
|
||||
WRA_JebHensarling1_001 727
|
||||
WRA_JebHensarling2_000 2302
|
||||
WRA_JebHensarling2_001 877
|
||||
WRA_JebHensarling2_002 2127
|
||||
WRA_JebHensarling2_003 1877
|
||||
WRA_JeffFlake_000 7202
|
||||
WRA_JeffFlake_001 7502
|
||||
WRA_JeffFlake_002 7377
|
||||
WRA_JimInhofe_000 3127
|
||||
WRA_JimRisch_000 1377
|
||||
WRA_JoeHeck1_000 250
|
||||
WRA_JoePitts_000 1077
|
||||
WRA_JohnBarrasso0_000 5077
|
||||
WRA_JohnBarrasso1_000 3452
|
||||
WRA_JohnBoehner0_000 1626
|
||||
WRA_JohnBoehner1_000 2951
|
||||
WRA_JohnBoozman_000 1877
|
||||
WRA_JohnHoeven_000 2577
|
||||
@@ -0,0 +1,30 @@
|
||||
WRA_JohnKasich0_000 1302
|
||||
WRA_JohnKasich1_000 1301
|
||||
WRA_JohnKasich1_001 1127
|
||||
WRA_JohnKasich3_000 1052
|
||||
WRA_JohnThune_000 1452
|
||||
WRA_JohnnyIsakson_000 5951
|
||||
WRA_JohnnyIsakson_001 5000
|
||||
WRA_JonKyl_000 626
|
||||
WRA_JoniErnst0_000 2077
|
||||
WRA_JoniErnst1_000 1452
|
||||
WRA_JuddGregg_000 1252
|
||||
WRA_JuddGregg_001 952
|
||||
WRA_JuddGregg_002 753
|
||||
WRA_KayBaileyHutchison_000 6325
|
||||
WRA_KellyAyotte_000 8077
|
||||
WRA_KevinBrady2_000 1276
|
||||
WRA_KevinBrady3_000 752
|
||||
WRA_KevinBrady_000 2503
|
||||
WRA_KevinMcCarthy0_000 1002
|
||||
WRA_KevinMcCarthy0_001 1127
|
||||
WRA_KristiNoem0_000 727
|
||||
WRA_KristiNoem1_000 452
|
||||
WRA_KristiNoem2_000 5952
|
||||
WRA_KristiNoem2_001 7277
|
||||
WRA_LamarAlexander0_000 1527
|
||||
WRA_LamarAlexander_000 1552
|
||||
WRA_LisaMurkowski0_000 1752
|
||||
WRA_LynnJenkins_000 877
|
||||
WRA_LynnJenkins_001 1002
|
||||
WRA_MarcoRubio_000 526
|
||||
@@ -0,0 +1,36 @@
|
||||
{
|
||||
"_class_name": "UNet2DConditionModel",
|
||||
"_diffusers_version": "0.6.0.dev0",
|
||||
"act_fn": "silu",
|
||||
"attention_head_dim": 8,
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280,
|
||||
1280
|
||||
],
|
||||
"center_input_sample": false,
|
||||
"cross_attention_dim": 384,
|
||||
"down_block_types": [
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"in_channels": 8,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"out_channels": 4,
|
||||
"sample_size": 64,
|
||||
"up_block_types": [
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D"
|
||||
]
|
||||
}
|
||||
Executable
+670
@@ -0,0 +1,670 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from PIL import Image, ImageDraw
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
|
||||
import sys
|
||||
sys.path.append("./")
|
||||
|
||||
from DataLoader import Dataset
|
||||
from utils.utils import preprocess_img_tensor
|
||||
from torch.utils import data as data_utils
|
||||
from utils.model_utils import validation,PositionalEncoding
|
||||
import time
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--unet_config_file",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="the configuration of unet file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reconstruction",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Flag to add prior preservation loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_root",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A folder containing the training data of instance images.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument("--testing_speed", action="store_true", help="Whether to caculate the running time")
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-6,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders")
|
||||
parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
||||
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
|
||||
" using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=(
|
||||
"Conduct validation every X updates."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_out_dir",
|
||||
type=str,
|
||||
default = '',
|
||||
help=(
|
||||
"Conduct validation every X updates."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
||||
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
||||
" for more docs"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_audio_length_left",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of audio length (left).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_audio_length_right",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of audio length (right)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whisper_model_type",
|
||||
type=str,
|
||||
default="landmark_nearest",
|
||||
choices=["tiny","largeV2"],
|
||||
help="Determine whisper feature type",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
return args
|
||||
|
||||
def print_model_dtypes(model, model_name):
|
||||
for name, param in model.named_parameters():
|
||||
if(param.dtype!=torch.float32):
|
||||
print(f"{name}: {param.dtype}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
args.output_dir = f"output/{args.output_dir}"
|
||||
args.val_out_dir = f"val/{args.val_out_dir}"
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
os.makedirs(args.val_out_dir, exist_ok=True)
|
||||
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
|
||||
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
project_config = ProjectConfiguration(
|
||||
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with="tensorboard",
|
||||
project_config=project_config,
|
||||
)
|
||||
|
||||
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
||||
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
||||
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
if args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
|
||||
raise ValueError(
|
||||
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
||||
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
||||
)
|
||||
|
||||
if args.seed is not None:
|
||||
# set_seed(args.seed)
|
||||
set_seed(seed + accelerator.process_index)
|
||||
|
||||
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
with open(args.unet_config_file, 'r') as f:
|
||||
unet_config = json.load(f)
|
||||
|
||||
#text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
|
||||
# Todo:
|
||||
print("Loading AutoencoderKL")
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
|
||||
vae_fp32 = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
print("Loading UNet2DConditionModel")
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
||||
if args.whisper_model_type == "tiny":
|
||||
pe = PositionalEncoding(d_model=384)
|
||||
elif args.whisper_model_type == "largeV2":
|
||||
pe = PositionalEncoding(d_model=1280)
|
||||
else:
|
||||
print(f"not support whisper_model_type {args.whisper_model_type}")
|
||||
|
||||
print("Loading models done...")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
params_to_optimize = (
|
||||
itertools.chain(unet.parameters()))
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
print("loading train_dataset ...")
|
||||
train_dataset = Dataset(args.data_root,
|
||||
args.train_json,
|
||||
use_audio_length_left=args.use_audio_length_left,
|
||||
use_audio_length_right=args.use_audio_length_right,
|
||||
whisper_model_type=args.whisper_model_type
|
||||
)
|
||||
train_data_loader = data_utils.DataLoader(
|
||||
train_dataset, batch_size=args.train_batch_size, shuffle=True,
|
||||
num_workers=8)
|
||||
print("loading val_dataset ...")
|
||||
val_dataset = Dataset(args.data_root,
|
||||
args.val_json,
|
||||
use_audio_length_left=args.use_audio_length_left,
|
||||
use_audio_length_right=args.use_audio_length_right,
|
||||
whisper_model_type=args.whisper_model_type
|
||||
)
|
||||
val_data_loader = data_utils.DataLoader(
|
||||
val_dataset, batch_size=1, shuffle=False,
|
||||
num_workers=8)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
)
|
||||
|
||||
|
||||
|
||||
unet, optimizer, train_data_loader, val_data_loader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_data_loader, val_data_loader,lr_scheduler
|
||||
)
|
||||
|
||||
vae.requires_grad_(False)
|
||||
vae_fp32.requires_grad_(False)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
# weight_dtype = torch.float16
|
||||
vae_fp32.to(accelerator.device, dtype=weight_dtype)
|
||||
vae_fp32.encoder = None
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.decoder = None
|
||||
pe.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth", config=vars(args))
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
print(f" Num batches each epoch = {len(train_data_loader)}")
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num batches each epoch = {len(train_data_loader)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
# path="../models/pytorch_model.bin"
|
||||
#TODO change path
|
||||
# path=None
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
# caluate the elapsed time
|
||||
elapsed_time = []
|
||||
start = time.time()
|
||||
|
||||
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
dataloader_time = time.time() - start
|
||||
start = time.time()
|
||||
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
|
||||
# """
|
||||
# print("=============epoch:{0}=step:{1}=====".format(epoch,step))
|
||||
# print("ref_image: ",ref_image.shape)
|
||||
# print("masks: ", masks.shape)
|
||||
# print("masked_image: ", masked_image.shape)
|
||||
# print("audio feature: ", audio_feature.shape)
|
||||
# print("image: ", image.shape)
|
||||
# """
|
||||
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
|
||||
image = preprocess_img_tensor(image).to(vae.device)
|
||||
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
|
||||
|
||||
img_process_time = time.time() - start
|
||||
start = time.time()
|
||||
with accelerator.accumulate(unet):
|
||||
vae = vae.half()
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Convert masked images to latent space
|
||||
masked_latents = vae.encode(
|
||||
masked_image.reshape(image.shape).to(dtype=weight_dtype) # masked image
|
||||
).latent_dist.sample()
|
||||
masked_latents = masked_latents * vae.config.scaling_factor
|
||||
|
||||
# Convert ref images to latent space
|
||||
ref_latents = vae.encode(
|
||||
ref_image.reshape(image.shape).to(dtype=weight_dtype) # ref image
|
||||
).latent_dist.sample()
|
||||
ref_latents = ref_latents * vae.config.scaling_factor
|
||||
|
||||
vae_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
mask = torch.stack(
|
||||
[
|
||||
torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
|
||||
for mask in masks
|
||||
]
|
||||
)
|
||||
mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
|
||||
|
||||
|
||||
bsz = latents.shape[0]
|
||||
# fix timestep for each image
|
||||
timesteps = torch.tensor([0], device=latents.device)
|
||||
# concatenate the latents with the mask and the masked latents
|
||||
"""
|
||||
print("=============vae latents=====".format(epoch,step))
|
||||
print("ref_latents: ",ref_latents.shape)
|
||||
print("mask: ", mask.shape)
|
||||
print("masked_latents: ", masked_latents.shape)
|
||||
"""
|
||||
|
||||
if unet_config['in_channels'] == 9:
|
||||
latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
|
||||
else:
|
||||
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
|
||||
audio_feature = audio_feature.to(dtype=weight_dtype)
|
||||
audio_feature = pe(audio_feature)
|
||||
|
||||
# Predict the noise residual
|
||||
image_pred = unet(latent_model_input,
|
||||
timesteps,
|
||||
encoder_hidden_states=audio_feature).sample
|
||||
|
||||
if args.reconstruction: # decode the image from the predicted latents
|
||||
image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
|
||||
image_pred_img = vae_fp32.decode(image_pred_img).sample
|
||||
|
||||
# Mask the top half of the image and calculate the loss only for the lower half of the image.
|
||||
image_pred_img = image_pred_img[:, :, image_pred_img.shape[2]//2:, :]
|
||||
image = image[:, :, image.shape[2]//2:, :]
|
||||
loss_lip = F.l1_loss(image_pred_img.float(), image.float(), reduction="mean") # the loss of the decoded images
|
||||
loss_latents = F.l1_loss(image_pred.float(), latents.float(), reduction="mean") # the loss of the latents
|
||||
|
||||
loss = 2.0*loss_lip + loss_latents # add some weight to balance the loss
|
||||
|
||||
else:
|
||||
loss = F.mse_loss(image_pred.float(), latents.float(), reduction="mean")
|
||||
#
|
||||
|
||||
unet_elapsed_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
itertools.chain(unet.parameters())
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
backward_elapsed_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
if args.testing_speed is True and accelerator.is_main_process:
|
||||
elapsed_time.append(
|
||||
[dataloader_time, unet_elapsed_time, vae_time, backward_elapsed_time,img_process_time]
|
||||
)
|
||||
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
|
||||
if global_step % args.validation_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
logger.info(
|
||||
f"Running validation... epoch={epoch}, global_step={global_step}"
|
||||
)
|
||||
print("===========start validation==========")
|
||||
# Use the helper function to check the data types for each model
|
||||
vae_new = vae.float()
|
||||
print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
|
||||
print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
|
||||
print_model_dtypes(accelerator.unwrap_model(unet), "UNET")
|
||||
|
||||
print(f"weight_dtype: {weight_dtype}")
|
||||
print(f"epoch type: {type(epoch)}")
|
||||
print(f"global_step type: {type(global_step)}")
|
||||
validation(
|
||||
# vae=accelerator.unwrap_model(vae),
|
||||
vae=accelerator.unwrap_model(vae_new),
|
||||
vae_fp32=accelerator.unwrap_model(vae_fp32),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
unet_config=unet_config,
|
||||
# weight_dtype=weight_dtype,
|
||||
weight_dtype=torch.float32,
|
||||
epoch=epoch,
|
||||
global_step=global_step,
|
||||
val_data_loader=val_data_loader,
|
||||
output_dir = args.val_out_dir,
|
||||
whisper_model_type = args.whisper_model_type
|
||||
)
|
||||
logger.info(f"Saved samples to images/val")
|
||||
start = time.time()
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0],
|
||||
"unet": unet_elapsed_time,
|
||||
"backward": backward_elapsed_time,
|
||||
"data": dataloader_time,
|
||||
"img_process":img_process_time,
|
||||
"vae":vae_time
|
||||
}
|
||||
progress_bar.set_postfix(**logs)
|
||||
# accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.log(
|
||||
{
|
||||
"loss/step_loss": logs["loss"],
|
||||
"parameter/lr": logs["lr"],
|
||||
"time/unet_forward_time": unet_elapsed_time,
|
||||
"time/unet_backward_time": backward_elapsed_time,
|
||||
"time/data_time": dataloader_time,
|
||||
"time/img_process_time":img_process_time,
|
||||
"time/vae_time": vae_time
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,29 @@
|
||||
export VAE_MODEL="../models/sd-vae-ft-mse/"
|
||||
export DATASET="../data"
|
||||
export UNET_CONFIG="../models/musetalk/musetalk.json"
|
||||
|
||||
accelerate launch train.py \
|
||||
--mixed_precision="fp16" \
|
||||
--unet_config_file=$UNET_CONFIG \
|
||||
--pretrained_model_name_or_path=$VAE_MODEL \
|
||||
--data_root=$DATASET \
|
||||
--train_batch_size=256 \
|
||||
--gradient_accumulation_steps=16 \
|
||||
--gradient_checkpointing \
|
||||
--max_train_steps=100000 \
|
||||
--learning_rate=5e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="output" \
|
||||
--val_out_dir='val' \
|
||||
--testing_speed \
|
||||
--checkpointing_steps=2000 \
|
||||
--validation_steps=2000 \
|
||||
--reconstruction \
|
||||
--resume_from_checkpoint="latest" \
|
||||
--use_audio_length_left=2 \
|
||||
--use_audio_length_right=2 \
|
||||
--whisper_model_type="tiny" \
|
||||
--train_json="../train.json" \
|
||||
--val_json="../val.json" \
|
||||
--lr_scheduler="cosine" \
|
||||
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
import math
|
||||
from utils.utils import decode_latents, preprocess_img_tensor
|
||||
import os
|
||||
from PIL import Image
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
import logging
|
||||
import json
|
||||
|
||||
RESIZED_IMG = 256
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
"""
|
||||
Transformer 中的位置编码(positional encoding)
|
||||
"""
|
||||
def __init__(self, d_model=384, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
b, seq_len, d_model = x.size()
|
||||
pe = self.pe[:, :seq_len, :]
|
||||
#print(b, seq_len, d_model)
|
||||
x = x + pe.to(x.device)
|
||||
return x
|
||||
|
||||
def validation(vae: torch.nn.Module,
|
||||
vae_fp32: torch.nn.Module,
|
||||
unet:torch.nn.Module,
|
||||
unet_config,
|
||||
weight_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
val_data_loader,
|
||||
output_dir,
|
||||
whisper_model_type,
|
||||
UNet2DConditionModel=UNet2DConditionModel
|
||||
):
|
||||
|
||||
# Get the validation pipeline
|
||||
unet_copy = UNet2DConditionModel(**unet_config)
|
||||
|
||||
unet_copy.load_state_dict(unet.state_dict())
|
||||
unet_copy.to(vae.device).to(dtype=weight_dtype)
|
||||
unet_copy.eval()
|
||||
|
||||
if whisper_model_type == "tiny":
|
||||
pe = PositionalEncoding(d_model=384)
|
||||
elif whisper_model_type == "largeV2":
|
||||
pe = PositionalEncoding(d_model=1280)
|
||||
elif whisper_model_type == "tiny-conv":
|
||||
pe = PositionalEncoding(d_model=384)
|
||||
print(f" whisper_model_type: {whisper_model_type} Validation does not need PE")
|
||||
else:
|
||||
print(f"not support whisper_model_type {whisper_model_type}")
|
||||
pe.to(vae.device, dtype=weight_dtype)
|
||||
|
||||
start = time.time()
|
||||
with torch.no_grad():
|
||||
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(val_data_loader):
|
||||
|
||||
|
||||
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
|
||||
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
|
||||
image = preprocess_img_tensor(image).to(vae.device)
|
||||
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
|
||||
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
|
||||
latents = latents * vae.config.scaling_factor
|
||||
# Convert masked images to latent space
|
||||
masked_latents = vae.encode(
|
||||
masked_image.reshape(image.shape).to(dtype=weight_dtype) # masked image
|
||||
).latent_dist.sample()
|
||||
masked_latents = masked_latents * vae.config.scaling_factor
|
||||
# Convert ref images to latent space
|
||||
ref_latents = vae.encode(
|
||||
ref_image.reshape(image.shape).to(dtype=weight_dtype) # ref image
|
||||
).latent_dist.sample()
|
||||
ref_latents = ref_latents * vae.config.scaling_factor
|
||||
|
||||
mask = torch.stack(
|
||||
[
|
||||
torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
|
||||
for mask in masks
|
||||
]
|
||||
)
|
||||
mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
|
||||
bsz = latents.shape[0]
|
||||
timesteps = torch.tensor([0], device=latents.device)
|
||||
|
||||
if unet_config['in_channels'] == 9:
|
||||
latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
|
||||
else:
|
||||
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
|
||||
image_pred = unet_copy(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample
|
||||
|
||||
image = Image.new('RGB', (RESIZED_IMG*4, RESIZED_IMG))
|
||||
image.paste(decode_latents(vae_fp32,masked_latents), (0, 0))
|
||||
image.paste(decode_latents(vae_fp32, ref_latents), (RESIZED_IMG, 0))
|
||||
image.paste(decode_latents(vae_fp32, latents), (RESIZED_IMG*2, 0))
|
||||
image.paste(decode_latents(vae_fp32, image_pred), (RESIZED_IMG*3, 0))
|
||||
|
||||
val_img_dir = f"images/{output_dir}/{global_step}"
|
||||
if not os.path.exists(val_img_dir):
|
||||
os.makedirs(val_img_dir)
|
||||
image.save('{0}/val_epoch_{1}_{2}_image.png'.format(val_img_dir, global_step,step))
|
||||
|
||||
print("valtion in step:{0}, time:{1}".format(step,time.time()-start))
|
||||
|
||||
print("valtion_done in epoch:{0}, time:{1}".format(epoch,time.time()-start))
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import PIL
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as TF
|
||||
from einops import rearrange
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
import matplotlib.pyplot as plt
|
||||
import PIL
|
||||
import os
|
||||
import cv2
|
||||
from glob import glob
|
||||
|
||||
|
||||
def preprocess_img_tensor(image_tensor):
|
||||
# 假设输入是一个形状为 (N, C, H, W) 的 PyTorch 张量
|
||||
N, C, H, W = image_tensor.shape
|
||||
# 计算新的宽度和高度,使其为 32 的整数倍
|
||||
new_w = W - W % 32
|
||||
new_h = H - H % 32
|
||||
# 使用 torchvision.transforms 库中的方法进行缩放和重采样
|
||||
transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
# 对每个图像应用变换,并将结果存储在一个新的张量中
|
||||
preprocessed_images = torch.empty((N, C, new_h, new_w), dtype=torch.float32)
|
||||
for i in range(N):
|
||||
# 使用 F.interpolate 替换 transforms.Resize
|
||||
resized_image = F.interpolate(image_tensor[i].unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False)
|
||||
preprocessed_images[i] = transform(resized_image.squeeze(0))
|
||||
|
||||
return preprocessed_images
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image_tensor, mask_tensor):
|
||||
# 假设输入 image_tensor 的形状为 [C, H, W],输入 mask_tensor 的形状为 [H, W]
|
||||
# # 对图像张量进行归一化
|
||||
image_tensor_ori = (image_tensor.to(dtype=torch.float32) / 127.5) - 1.0
|
||||
# # 对遮罩张量进行归一化和二值化
|
||||
# mask_tensor = (mask_tensor.to(dtype=torch.float32) / 255.0).unsqueeze(0)
|
||||
mask_tensor[mask_tensor < 0.5] = 0
|
||||
mask_tensor[mask_tensor >= 0.5] = 1
|
||||
# 创建遮罩后的图像
|
||||
masked_image_tensor = image_tensor * (mask_tensor > 0.5)
|
||||
|
||||
return mask_tensor, masked_image_tensor
|
||||
|
||||
|
||||
def encode_latents(vae, image):
|
||||
# init_image = preprocess_image(image)
|
||||
init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
|
||||
init_latents = 0.18215 * init_latent_dist.sample()
|
||||
return init_latents
|
||||
|
||||
def decode_latents(vae, latents, ref_images=None):
|
||||
latents = (1/ 0.18215) * latents
|
||||
image = vae.decode(latents.to(vae.dtype)).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
image = (image * 255).round().astype("uint8")
|
||||
if ref_images is not None:
|
||||
ref_images = ref_images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
ref_images = (ref_images * 255).round().astype("uint8")
|
||||
h = image.shape[1]
|
||||
image[:, :h//2] = ref_images[:, :h//2]
|
||||
image = [Image.fromarray(im) for im in image]
|
||||
|
||||
return image[0]
|
||||
|
||||
Reference in New Issue
Block a user