Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4724ed8a78 | |||
| b3c30c3561 | |||
| 4f69a9cfdd | |||
| 5bd772d7da | |||
| 98f0e6f2b1 | |||
| 1de8261491 | |||
| b968548131 | |||
| af82f3b00f | |||
| d74c4c098b | |||
| b4a592d7f3 | |||
| 6d19f3c0c8 | |||
| 7254ca6306 | |||
| 30dcd5237f | |||
| d73daf1808 |
+3
-11
@@ -4,15 +4,7 @@
|
||||
.vscode/
|
||||
*.pyc
|
||||
.ipynb_checkpoints
|
||||
models
|
||||
results/
|
||||
models/
|
||||
**/__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
dataset/
|
||||
ffmpeg*
|
||||
ffmprobe*
|
||||
ffplay*
|
||||
debug
|
||||
exp_out
|
||||
.gradio
|
||||
data/audio/*.wav
|
||||
data/video/*.mp4
|
||||
|
||||
@@ -1,31 +1,20 @@
|
||||
# MuseTalk
|
||||
|
||||
<strong>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</strong>
|
||||
|
||||
MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting
|
||||
</br>
|
||||
Yue Zhang <sup>\*</sup>,
|
||||
Zhizhou Zhong<sup>\*</sup>,
|
||||
Minhao Liu<sup>\*</sup>,
|
||||
Zhaokang Chen,
|
||||
Bin Wu<sup>†</sup>,
|
||||
Yubin Zeng,
|
||||
Chao Zhan,
|
||||
Junxin Huang,
|
||||
Yingjie He,
|
||||
Chao Zhan,
|
||||
Wenjiang Zhou
|
||||
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)
|
||||
|
||||
Lyra Lab, Tencent Music Entertainment
|
||||
|
||||
**[github](https://github.com/TMElyralab/MuseTalk)** **[huggingface](https://huggingface.co/TMElyralab/MuseTalk)** **[space](https://huggingface.co/spaces/TMElyralab/MuseTalk)** **[Technical report](https://arxiv.org/abs/2410.10122)**
|
||||
**[github](https://github.com/TMElyralab/MuseTalk)** **[huggingface](https://huggingface.co/TMElyralab/MuseTalk)** **[space](https://huggingface.co/spaces/TMElyralab/MuseTalk)** **Project (comming soon)** **Technical report (comming soon)**
|
||||
|
||||
We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+ on an NVIDIA Tesla V100). MuseTalk can be applied with input videos, e.g., generated by [MuseV](https://github.com/TMElyralab/MuseV), as a complete virtual human solution.
|
||||
|
||||
## 🔥 Updates
|
||||
We're excited to unveil MuseTalk 1.5.
|
||||
This version **(1)** integrates training with perceptual loss, GAN loss, and sync loss, significantly boosting its overall performance. **(2)** We've implemented a two-stage training strategy and a spatio-temporal data sampling approach to strike a balance between visual quality and lip-sync accuracy.
|
||||
Learn more details [here](https://arxiv.org/abs/2410.10122).
|
||||
**The inference codes, training codes and model weights of MuseTalk 1.5 are all available now!** 🚀
|
||||
|
||||
# Overview
|
||||
`MuseTalk` is a real-time high quality audio-driven lip-syncing model trained in the latent space of `ft-mse-vae`, which
|
||||
|
||||
@@ -33,384 +22,23 @@ Learn more details [here](https://arxiv.org/abs/2410.10122).
|
||||
1. supports audio in various languages, such as Chinese, English, and Japanese.
|
||||
1. supports real-time inference with 30fps+ on an NVIDIA Tesla V100.
|
||||
1. supports modification of the center point of the face region proposes, which **SIGNIFICANTLY** affects generation results.
|
||||
1. checkpoint available trained on the HDTF and private dataset.
|
||||
1. checkpoint available trained on the HDTF dataset.
|
||||
1. training codes (comming soon).
|
||||
|
||||
# News
|
||||
- [04/05/2025] :mega: We are excited to announce that the training code is now open-sourced! You can now train your own MuseTalk model using our provided training scripts and configurations.
|
||||
- [03/28/2025] We are thrilled to announce the release of our 1.5 version. This version is a significant improvement over the 1.0 version, with enhanced clarity, identity consistency, and precise lip-speech synchronization. We update the [technical report](https://arxiv.org/abs/2410.10122) with more details.
|
||||
- [10/18/2024] We release the [technical report](https://arxiv.org/abs/2410.10122v2). Our report details a superior model to the open-source L1 loss version. It includes GAN and perceptual losses for improved clarity, and sync loss for enhanced performance.
|
||||
- [04/17/2024] We release a pipeline that utilizes MuseTalk for real-time inference.
|
||||
- [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant)
|
||||
- [04/02/2024] Release MuseTalk project and pretrained models.
|
||||
|
||||
- [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant)
|
||||
- [04/17/2024] :mega: We release a pipeline that utilizes MuseTalk for real-time inference.
|
||||
- [04/30/2024] We release an initial version of training codes in `train_codes`.
|
||||
|
||||
## Model
|
||||

|
||||

|
||||
MuseTalk was trained in latent spaces, where the images were encoded by a freezed VAE. The audio was encoded by a freezed `whisper-tiny` model. The architecture of the generation network was borrowed from the UNet of the `stable-diffusion-v1-4`, where the audio embeddings were fused to the image embeddings by cross-attention.
|
||||
|
||||
Note that although we use a very similar architecture as Stable Diffusion, MuseTalk is distinct in that it is **NOT** a diffusion model. Instead, MuseTalk operates by inpainting in the latent space with a single step.
|
||||
|
||||
## Cases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="33%">
|
||||
|
||||
### Input Video
|
||||
---
|
||||
https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/1ce3e850-90ac-4a31-a45f-8dfa4f2960ac
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/fa3b13a1-ae26-4d1d-899e-87435f8d22b3
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/15800692-39d1-4f4c-99f2-aef044dc3251
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/a843f9c9-136d-4ed4-9303-4a7269787a60
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/6eb4e70e-9e19-48e9-85a9-bbfa589c5fcb
|
||||
|
||||
</td>
|
||||
<td width="33%">
|
||||
|
||||
### MuseTalk 1.0
|
||||
---
|
||||
https://github.com/user-attachments/assets/c04f3cd5-9f77-40e9-aafd-61978380d0ef
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/2051a388-1cef-4c1d-b2a2-3c1ceee5dc99
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/b5f56f71-5cdc-4e2e-a519-454242000d32
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/a5843835-04ab-4c31-989f-0995cfc22f34
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/3dc7f1d7-8747-4733-bbdd-97874af0c028
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/3c78064e-faad-4637-83ae-28452a22b09a
|
||||
|
||||
</td>
|
||||
<td width="33%">
|
||||
|
||||
### MuseTalk 1.5
|
||||
---
|
||||
https://github.com/user-attachments/assets/999a6f5b-61dd-48e1-b902-bb3f9cbc7247
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/d26a5c9a-003c-489d-a043-c9a331456e75
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/471290d7-b157-4cf6-8a6d-7e899afa302c
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/1ee77c4c-8c70-4add-b6db-583a12faa7dc
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/370510ea-624c-43b7-bbb0-ab5333e0fcc4
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
# TODO:
|
||||
- [x] trained models and inference codes.
|
||||
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
|
||||
- [x] codes for real-time inference.
|
||||
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
|
||||
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
|
||||
- [x] realtime inference code for 1.5 version.
|
||||
- [x] training and data preprocessing codes.
|
||||
- [ ] **always** welcome to submit issues and PRs to improve this repository! 😊
|
||||
|
||||
|
||||
# Getting Started
|
||||
We provide a detailed tutorial about the installation and the basic usage of MuseTalk for new users:
|
||||
|
||||
## Third party integration
|
||||
Thanks for the third-party integration, which makes installation and use more convenient for everyone.
|
||||
We also hope you note that we have not verified, maintained, or updated third-party. Please refer to this project for specific results.
|
||||
|
||||
### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseTalk)
|
||||
|
||||
## Installation
|
||||
To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
|
||||
|
||||
### Build environment
|
||||
We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
|
||||
|
||||
```shell
|
||||
conda create -n MuseTalk python==3.10
|
||||
conda activate MuseTalk
|
||||
```
|
||||
|
||||
### Install PyTorch 2.0.1
|
||||
Choose one of the following installation methods:
|
||||
|
||||
```shell
|
||||
# Option 1: Using pip
|
||||
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
# Option 2: Using conda
|
||||
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
|
||||
```
|
||||
|
||||
### Install Dependencies
|
||||
Install the remaining required packages:
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Install MMLab Packages
|
||||
Install the MMLab ecosystem packages:
|
||||
|
||||
```bash
|
||||
pip install --no-cache-dir -U openmim
|
||||
mim install mmengine
|
||||
mim install "mmcv==2.0.1"
|
||||
mim install "mmdet==3.1.0"
|
||||
mim install "mmpose==1.1.0"
|
||||
```
|
||||
|
||||
### Setup FFmpeg
|
||||
1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package
|
||||
|
||||
2. Configure FFmpeg based on your operating system:
|
||||
|
||||
For Linux:
|
||||
```bash
|
||||
export FFMPEG_PATH=/path/to/ffmpeg
|
||||
# Example:
|
||||
export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
|
||||
```
|
||||
|
||||
For Windows:
|
||||
Add the `ffmpeg-xxx\bin` directory to your system's PATH environment variable. Verify the installation by running `ffmpeg -version` in the command prompt - it should display the ffmpeg version information.
|
||||
|
||||
### Download weights
|
||||
You can download weights in two ways:
|
||||
|
||||
#### Option 1: Using Download Scripts
|
||||
We provide two scripts for automatic downloading:
|
||||
|
||||
For Linux:
|
||||
```bash
|
||||
sh ./download_weights.sh
|
||||
```
|
||||
|
||||
For Windows:
|
||||
```batch
|
||||
# Run the script
|
||||
download_weights.bat
|
||||
```
|
||||
|
||||
#### Option 2: Manual Download
|
||||
You can also download the weights manually from the following links:
|
||||
|
||||
1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk/tree/main)
|
||||
2. Download the weights of other components:
|
||||
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
|
||||
- [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
|
||||
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
|
||||
- [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
|
||||
- [face-parse-bisent](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view?pli=1)
|
||||
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
|
||||
|
||||
Finally, these weights should be organized in `models` as follows:
|
||||
```
|
||||
./models/
|
||||
├── musetalk
|
||||
│ └── musetalk.json
|
||||
│ └── pytorch_model.bin
|
||||
├── musetalkV15
|
||||
│ └── musetalk.json
|
||||
│ └── unet.pth
|
||||
├── syncnet
|
||||
│ └── latentsync_syncnet.pt
|
||||
├── dwpose
|
||||
│ └── dw-ll_ucoco_384.pth
|
||||
├── face-parse-bisent
|
||||
│ ├── 79999_iter.pth
|
||||
│ └── resnet18-5c106cde.pth
|
||||
├── sd-vae
|
||||
│ ├── config.json
|
||||
│ └── diffusion_pytorch_model.bin
|
||||
└── whisper
|
||||
├── config.json
|
||||
├── pytorch_model.bin
|
||||
└── preprocessor_config.json
|
||||
|
||||
```
|
||||
## Quickstart
|
||||
|
||||
### Inference
|
||||
We provide inference scripts for both versions of MuseTalk:
|
||||
|
||||
#### Prerequisites
|
||||
Before running inference, please ensure ffmpeg is installed and accessible:
|
||||
```bash
|
||||
# Check ffmpeg installation
|
||||
ffmpeg -version
|
||||
```
|
||||
If ffmpeg is not found, please install it first:
|
||||
- Windows: Download from [ffmpeg-static](https://github.com/BtbN/FFmpeg-Builds/releases) and add to PATH
|
||||
- Linux: `sudo apt-get install ffmpeg`
|
||||
|
||||
#### Normal Inference
|
||||
##### Linux Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
sh inference.sh v1.5 normal
|
||||
|
||||
# MuseTalk 1.0
|
||||
sh inference.sh v1.0 normal
|
||||
```
|
||||
|
||||
##### Windows Environment
|
||||
|
||||
Please ensure that you set the `ffmpeg_path` to match the actual location of your FFmpeg installation.
|
||||
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
python -m scripts.inference --inference_config configs\inference\test.yaml --result_dir results\test --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
|
||||
# For MuseTalk 1.0, change:
|
||||
# - models\musetalkV15 -> models\musetalk
|
||||
# - unet.pth -> pytorch_model.bin
|
||||
# - --version v15 -> --version v1
|
||||
```
|
||||
|
||||
#### Real-time Inference
|
||||
##### Linux Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
sh inference.sh v1.5 realtime
|
||||
|
||||
# MuseTalk 1.0
|
||||
sh inference.sh v1.0 realtime
|
||||
```
|
||||
|
||||
##### Windows Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
python -m scripts.realtime_inference --inference_config configs\inference\realtime.yaml --result_dir results\realtime --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --fps 25 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
|
||||
# For MuseTalk 1.0, change:
|
||||
# - models\musetalkV15 -> models\musetalk
|
||||
# - unet.pth -> pytorch_model.bin
|
||||
# - --version v15 -> --version v1
|
||||
```
|
||||
|
||||
The configuration file `configs/inference/test.yaml` contains the inference settings, including:
|
||||
- `video_path`: Path to the input video, image file, or directory of images
|
||||
- `audio_path`: Path to the input audio file
|
||||
|
||||
Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
|
||||
|
||||
Important notes for real-time inference:
|
||||
1. Set `preparation` to `True` when processing a new avatar
|
||||
2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
|
||||
3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
|
||||
4. Set `preparation` to `False` for generating more videos with the same avatar
|
||||
|
||||
For faster generation without saving images, you can use:
|
||||
```bash
|
||||
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
|
||||
```
|
||||
|
||||
## Gradio Demo
|
||||
We provide an intuitive web interface through Gradio for users to easily adjust input parameters. To optimize inference time, users can generate only the **first frame** to fine-tune the best lip-sync parameters, which helps reduce facial artifacts in the final output.
|
||||

|
||||
For minimum hardware requirements, we tested the system on a Windows environment using an NVIDIA GeForce RTX 3050 Ti Laptop GPU with 4GB VRAM. In fp16 mode, generating an 8-second video takes approximately 5 minutes. 
|
||||
|
||||
Both Linux and Windows users can launch the demo using the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
|
||||
|
||||
```bash
|
||||
# You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
|
||||
python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Data Preparation
|
||||
To train MuseTalk, you need to prepare your dataset following these steps:
|
||||
|
||||
1. **Place your source videos**
|
||||
|
||||
For example, if you're using the HDTF dataset, place all your video files in `./dataset/HDTF/source`.
|
||||
|
||||
2. **Run the preprocessing script**
|
||||
```bash
|
||||
python -m scripts.preprocess --config ./configs/training/preprocess.yaml
|
||||
```
|
||||
This script will:
|
||||
- Extract frames from videos
|
||||
- Detect and align faces
|
||||
- Generate audio features
|
||||
- Create the necessary data structure for training
|
||||
|
||||
### Training Process
|
||||
After data preprocessing, you can start the training process:
|
||||
|
||||
1. **First Stage**
|
||||
```bash
|
||||
sh train.sh stage1
|
||||
```
|
||||
|
||||
2. **Second Stage**
|
||||
```bash
|
||||
sh train.sh stage2
|
||||
```
|
||||
|
||||
### Configuration Adjustment
|
||||
Before starting the training, you should adjust the configuration files according to your hardware and requirements:
|
||||
|
||||
1. **GPU Configuration** (`configs/training/gpu.yaml`):
|
||||
- `gpu_ids`: Specify the GPU IDs you want to use (e.g., "0,1,2,3")
|
||||
- `num_processes`: Set this to match the number of GPUs you're using
|
||||
|
||||
2. **Stage 1 Configuration** (`configs/training/stage1.yaml`):
|
||||
- `data.train_bs`: Adjust batch size based on your GPU memory (default: 32)
|
||||
- `data.n_sample_frames`: Number of sampled frames per video (default: 1)
|
||||
|
||||
3. **Stage 2 Configuration** (`configs/training/stage2.yaml`):
|
||||
- `random_init_unet`: Must be set to `False` to use the model from stage 1
|
||||
- `data.train_bs`: Smaller batch size due to high GPU memory cost (default: 2)
|
||||
- `data.n_sample_frames`: Higher value for temporal consistency (default: 16)
|
||||
- `solver.gradient_accumulation_steps`: Increase to simulate larger batch sizes (default: 8)
|
||||
|
||||
|
||||
### GPU Memory Requirements
|
||||
Based on our testing on a machine with 8 NVIDIA H20 GPUs:
|
||||
|
||||
#### Stage 1 Memory Usage
|
||||
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|
||||
|:----------:|:----------------------:|:--------------:|:--------------:|
|
||||
| 8 | 1 | ~32GB | |
|
||||
| 16 | 1 | ~45GB | |
|
||||
| 32 | 1 | ~74GB | ✓ |
|
||||
|
||||
#### Stage 2 Memory Usage
|
||||
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|
||||
|:----------:|:----------------------:|:--------------:|:--------------:|
|
||||
| 1 | 8 | ~54GB | |
|
||||
| 2 | 2 | ~80GB | |
|
||||
| 2 | 8 | ~85GB | ✓ |
|
||||
|
||||
<details close>
|
||||
## TestCases For 1.0
|
||||
### MuseV + MuseTalk make human photos alive!
|
||||
<table class="center">
|
||||
<tr style="font-weight: bolder;text-align:center;">
|
||||
<td width="33%">Image</td>
|
||||
@@ -496,7 +124,132 @@ Based on our testing on a machine with 8 NVIDIA H20 GPUs:
|
||||
</tr>
|
||||
</table >
|
||||
|
||||
#### Use of bbox_shift to have adjustable results(For 1.0)
|
||||
* The character of the last two rows, `Xinying Sun`, is a supermodel KOL. You can follow her on [douyin](https://www.douyin.com/user/MS4wLjABAAAAWDThbMPN_6Xmm_JgXexbOii1K-httbu2APdG8DvDyM8).
|
||||
|
||||
## Video dubbing
|
||||
<table class="center">
|
||||
<tr style="font-weight: bolder;text-align:center;">
|
||||
<td width="70%">MuseTalk</td>
|
||||
<td width="30%">Original videos</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/4d7c5fa1-3550-4d52-8ed2-52f158150f24 controls preload></video>
|
||||
</td>
|
||||
<td>
|
||||
<a href="//www.bilibili.com/video/BV1wT411b7HU">Link</a>
|
||||
<href src=""></href>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
* For video dubbing, we applied a self-developed tool which can identify the talking person.
|
||||
|
||||
## Some interesting videos!
|
||||
<table class="center">
|
||||
<tr style="font-weight: bolder;text-align:center;">
|
||||
<td width="50%">Image</td>
|
||||
<td width="50%">MuseV + MuseTalk</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/video1/video1.png width="95%">
|
||||
</td>
|
||||
<td>
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/1f02f9c6-8b98-475e-86b8-82ebee82fe0d controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
# TODO:
|
||||
- [x] trained models and inference codes.
|
||||
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
|
||||
- [x] codes for real-time inference.
|
||||
- [ ] technical report.
|
||||
- [x] training codes.
|
||||
- [ ] a better model (may take longer).
|
||||
|
||||
|
||||
# Getting Started
|
||||
We provide a detailed tutorial about the installation and the basic usage of MuseTalk for new users:
|
||||
|
||||
## Third party integration
|
||||
Thanks for the third-party integration, which makes installation and use more convenient for everyone.
|
||||
We also hope you note that we have not verified, maintained, or updated third-party. Please refer to this project for specific results.
|
||||
|
||||
### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseTalk)
|
||||
|
||||
## Installation
|
||||
To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
|
||||
### Build environment
|
||||
|
||||
We recommend a python version >=3.10 and cuda version =11.7. Then build environment as follows:
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### mmlab packages
|
||||
```bash
|
||||
pip install --no-cache-dir -U openmim
|
||||
mim install mmengine
|
||||
mim install "mmcv>=2.0.1"
|
||||
mim install "mmdet>=3.1.0"
|
||||
mim install "mmpose>=1.1.0"
|
||||
```
|
||||
|
||||
### Download ffmpeg-static
|
||||
Download the ffmpeg-static and
|
||||
```
|
||||
export FFMPEG_PATH=/path/to/ffmpeg
|
||||
```
|
||||
for example:
|
||||
```
|
||||
export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
|
||||
```
|
||||
### Download weights
|
||||
You can download weights manually as follows:
|
||||
|
||||
1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk).
|
||||
|
||||
2. Download the weights of other components:
|
||||
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
|
||||
- [whisper](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt)
|
||||
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
|
||||
- [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
|
||||
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
|
||||
|
||||
|
||||
Finally, these weights should be organized in `models` as follows:
|
||||
```
|
||||
./models/
|
||||
├── musetalk
|
||||
│ └── musetalk.json
|
||||
│ └── pytorch_model.bin
|
||||
├── dwpose
|
||||
│ └── dw-ll_ucoco_384.pth
|
||||
├── face-parse-bisent
|
||||
│ ├── 79999_iter.pth
|
||||
│ └── resnet18-5c106cde.pth
|
||||
├── sd-vae-ft-mse
|
||||
│ ├── config.json
|
||||
│ └── diffusion_pytorch_model.bin
|
||||
└── whisper
|
||||
└── tiny.pt
|
||||
```
|
||||
## Quickstart
|
||||
|
||||
### Inference
|
||||
Here, we provide the inference script.
|
||||
```
|
||||
python -m scripts.inference --inference_config configs/inference/test.yaml
|
||||
```
|
||||
configs/inference/test.yaml is the path to the inference configuration file, including video_path and audio_path.
|
||||
The video_path should be either a video file, an image file or a directory of images.
|
||||
|
||||
You are recommended to input video with `25fps`, the same fps used when training the model. If your video is far less than 25fps, you are recommended to apply frame interpolation or directly convert the video to 25fps using ffmpeg.
|
||||
|
||||
#### Use of bbox_shift to have adjustable results
|
||||
:mag_right: We have found that upper-bound of the mask has an important impact on mouth openness. Thus, to control the mask region, we suggest using the `bbox_shift` parameter. Positive values (moving towards the lower half) increase mouth openness, while negative values (moving towards the upper half) decrease mouth openness.
|
||||
|
||||
You can start by running with the default configuration to obtain the adjustable value range, and then re-run the script within this range.
|
||||
@@ -507,13 +260,36 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
|
||||
```
|
||||
:pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
|
||||
|
||||
|
||||
#### Combining MuseV and MuseTalk
|
||||
|
||||
As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
|
||||
|
||||
#### :new: Real-time inference
|
||||
|
||||
Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encode in advance. During inference, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
|
||||
|
||||
```
|
||||
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --batch_size 4
|
||||
```
|
||||
configs/inference/realtime.yaml is the path to the real-time inference configuration file, including `preparation`, `video_path` , `bbox_shift` and `audio_clips`.
|
||||
|
||||
1. Set `preparation` to `True` in `realtime.yaml` to prepare the materials for a new `avatar`. (If the `bbox_shift` has changed, you also need to re-prepare the materials.)
|
||||
1. After that, the `avatar` will use an audio clip selected from `audio_clips` to generate video.
|
||||
```
|
||||
Inferring using: data/audio/yongen.wav
|
||||
```
|
||||
1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to the users. The generation process can achieve 30fps+ on an NVIDIA Tesla V100.
|
||||
1. Set `preparation` to `False` and run this script if you want to genrate more videos using the same avatar.
|
||||
|
||||
##### Note for Real-time inference
|
||||
1. If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
|
||||
1. In the previous script, the generation time is also limited by I/O (e.g. saving images). If you just want to test the generation speed without saving the images, you can run
|
||||
```
|
||||
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
|
||||
```
|
||||
|
||||
# Acknowledgement
|
||||
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch) and [LatentSync](https://huggingface.co/ByteDance/LatentSync/tree/main).
|
||||
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
|
||||
1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
|
||||
1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.
|
||||
|
||||
@@ -530,10 +306,10 @@ If you need higher resolution, you could apply super resolution models such as [
|
||||
# Citation
|
||||
```bib
|
||||
@article{musetalk,
|
||||
title={MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling},
|
||||
author={Zhang, Yue and Zhong, Zhizhou and Liu, Minhao and Chen, Zhaokang and Wu, Bin and Zeng, Yubin and Zhan, Chao and He, Yingjie and Huang, Junxin and Zhou, Wenjiang},
|
||||
title={MuseTalk: Real-Time High Quality Lip Synchorization with Latent Space Inpainting},
|
||||
author={Zhang, Yue and Liu, Minhao and Chen, Zhaokang and Wu, Bin and He, Yingjie and Zhan, Chao and Zhou, Wenjiang},
|
||||
journal={arxiv},
|
||||
year={2025}
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
# Disclaimer/License
|
||||
|
||||
@@ -4,6 +4,7 @@ import pdb
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
import spaces
|
||||
import numpy as np
|
||||
import sys
|
||||
import subprocess
|
||||
@@ -27,101 +28,11 @@ import gdown
|
||||
import imageio
|
||||
import ffmpeg
|
||||
from moviepy.editor import *
|
||||
from transformers import WhisperModel
|
||||
|
||||
|
||||
ProjectDir = os.path.abspath(os.path.dirname(__file__))
|
||||
CheckpointsDir = os.path.join(ProjectDir, "models")
|
||||
|
||||
@torch.no_grad()
|
||||
def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
|
||||
left_cheek_width=90, right_cheek_width=90):
|
||||
"""Debug inpainting parameters, only process the first frame"""
|
||||
# Set default parameters
|
||||
args_dict = {
|
||||
"result_dir": './results/debug',
|
||||
"fps": 25,
|
||||
"batch_size": 1,
|
||||
"output_vid_name": '',
|
||||
"use_saved_coord": False,
|
||||
"audio_padding_length_left": 2,
|
||||
"audio_padding_length_right": 2,
|
||||
"version": "v15",
|
||||
"extra_margin": extra_margin,
|
||||
"parsing_mode": parsing_mode,
|
||||
"left_cheek_width": left_cheek_width,
|
||||
"right_cheek_width": right_cheek_width
|
||||
}
|
||||
args = Namespace(**args_dict)
|
||||
|
||||
# Create debug directory
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
|
||||
# Read first frame
|
||||
if get_file_type(video_path) == "video":
|
||||
reader = imageio.get_reader(video_path)
|
||||
first_frame = reader.get_data(0)
|
||||
reader.close()
|
||||
else:
|
||||
first_frame = cv2.imread(video_path)
|
||||
first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Save first frame
|
||||
debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
|
||||
cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
|
||||
|
||||
# Get face coordinates
|
||||
coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
|
||||
bbox = coord_list[0]
|
||||
frame = frame_list[0]
|
||||
|
||||
if bbox == coord_placeholder:
|
||||
return None, "No face detected, please adjust bbox_shift parameter"
|
||||
|
||||
# Initialize face parser
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
|
||||
# Process first frame
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
|
||||
# Generate random audio features
|
||||
random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
|
||||
audio_feature = pe(random_audio)
|
||||
|
||||
# Get latents
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
latents = latents.to(dtype=weight_dtype)
|
||||
|
||||
# Generate prediction results
|
||||
pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
|
||||
# Inpaint back to original image
|
||||
res_frame = recon[0]
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
|
||||
|
||||
# Save results (no need to convert color space again since get_image already returns RGB format)
|
||||
debug_result_path = os.path.join(args.result_dir, "debug_result.png")
|
||||
cv2.imwrite(debug_result_path, combine_frame)
|
||||
|
||||
# Create information text
|
||||
info_text = f"Parameter information:\n" + \
|
||||
f"bbox_shift: {bbox_shift}\n" + \
|
||||
f"extra_margin: {extra_margin}\n" + \
|
||||
f"parsing_mode: {parsing_mode}\n" + \
|
||||
f"left_cheek_width: {left_cheek_width}\n" + \
|
||||
f"right_cheek_width: {right_cheek_width}\n" + \
|
||||
f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"
|
||||
|
||||
return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text
|
||||
|
||||
def print_directory_contents(path):
|
||||
for child in os.listdir(path):
|
||||
child_path = os.path.join(path, child)
|
||||
@@ -129,107 +40,119 @@ def print_directory_contents(path):
|
||||
print(child_path)
|
||||
|
||||
def download_model():
|
||||
# 检查必需的模型文件是否存在
|
||||
required_models = {
|
||||
"MuseTalk": f"{CheckpointsDir}/musetalkV15/unet.pth",
|
||||
"MuseTalk": f"{CheckpointsDir}/musetalkV15/musetalk.json",
|
||||
"SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
|
||||
"Whisper": f"{CheckpointsDir}/whisper/config.json",
|
||||
"DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
|
||||
"SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
|
||||
"Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
|
||||
"ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
|
||||
}
|
||||
|
||||
missing_models = []
|
||||
for model_name, model_path in required_models.items():
|
||||
if not os.path.exists(model_path):
|
||||
missing_models.append(model_name)
|
||||
|
||||
if missing_models:
|
||||
# 全用英文
|
||||
print("The following required model files are missing:")
|
||||
for model in missing_models:
|
||||
print(f"- {model}")
|
||||
print("\nPlease run the download script to download the missing models:")
|
||||
if sys.platform == "win32":
|
||||
print("Windows: Run download_weights.bat")
|
||||
if not os.path.exists(CheckpointsDir):
|
||||
os.makedirs(CheckpointsDir)
|
||||
print("Checkpoint Not Downloaded, start downloading...")
|
||||
tic = time.time()
|
||||
snapshot_download(
|
||||
repo_id="TMElyralab/MuseTalk",
|
||||
local_dir=CheckpointsDir,
|
||||
max_workers=8,
|
||||
local_dir_use_symlinks=True,
|
||||
force_download=True, resume_download=False
|
||||
)
|
||||
# weight
|
||||
os.makedirs(f"{CheckpointsDir}/sd-vae-ft-mse/")
|
||||
snapshot_download(
|
||||
repo_id="stabilityai/sd-vae-ft-mse",
|
||||
local_dir=CheckpointsDir+'/sd-vae-ft-mse',
|
||||
max_workers=8,
|
||||
local_dir_use_symlinks=True,
|
||||
force_download=True, resume_download=False
|
||||
)
|
||||
#dwpose
|
||||
os.makedirs(f"{CheckpointsDir}/dwpose/")
|
||||
snapshot_download(
|
||||
repo_id="yzd-v/DWPose",
|
||||
local_dir=CheckpointsDir+'/dwpose',
|
||||
max_workers=8,
|
||||
local_dir_use_symlinks=True,
|
||||
force_download=True, resume_download=False
|
||||
)
|
||||
#vae
|
||||
url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
|
||||
response = requests.get(url)
|
||||
# 确保请求成功
|
||||
if response.status_code == 200:
|
||||
# 指定文件保存的位置
|
||||
file_path = f"{CheckpointsDir}/whisper/tiny.pt"
|
||||
os.makedirs(f"{CheckpointsDir}/whisper/")
|
||||
# 将文件内容写入指定位置
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
print("Linux/Mac: Run ./download_weights.sh")
|
||||
sys.exit(1)
|
||||
print(f"请求失败,状态码:{response.status_code}")
|
||||
#gdown face parse
|
||||
url = "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812"
|
||||
os.makedirs(f"{CheckpointsDir}/face-parse-bisent/")
|
||||
file_path = f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth"
|
||||
gdown.download(url, file_path, quiet=False)
|
||||
#resnet
|
||||
url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
|
||||
response = requests.get(url)
|
||||
# 确保请求成功
|
||||
if response.status_code == 200:
|
||||
# 指定文件保存的位置
|
||||
file_path = f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
|
||||
# 将文件内容写入指定位置
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
print("All required model files exist.")
|
||||
print(f"请求失败,状态码:{response.status_code}")
|
||||
|
||||
|
||||
toc = time.time()
|
||||
|
||||
print(f"download cost {toc-tic} seconds")
|
||||
print_directory_contents(CheckpointsDir)
|
||||
|
||||
else:
|
||||
print("Already download the model.")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
download_model() # for huggingface deployment.
|
||||
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
||||
|
||||
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder,get_bbox_range
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.utils import load_all_model
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
@spaces.GPU(duration=600)
|
||||
@torch.no_grad()
|
||||
def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
|
||||
left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
|
||||
# Set default parameters, aligned with inference.py
|
||||
args_dict = {
|
||||
"result_dir": './results/output',
|
||||
"fps": 25,
|
||||
"batch_size": 8,
|
||||
"output_vid_name": '',
|
||||
"use_saved_coord": False,
|
||||
"audio_padding_length_left": 2,
|
||||
"audio_padding_length_right": 2,
|
||||
"version": "v15", # Fixed use v15 version
|
||||
"extra_margin": extra_margin,
|
||||
"parsing_mode": parsing_mode,
|
||||
"left_cheek_width": left_cheek_width,
|
||||
"right_cheek_width": right_cheek_width
|
||||
}
|
||||
def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=True)):
|
||||
args_dict={"result_dir":'./results/output', "fps":25, "batch_size":8, "output_vid_name":'', "use_saved_coord":False}#same with inferenece script
|
||||
args = Namespace(**args_dict)
|
||||
|
||||
# Check ffmpeg
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
|
||||
# Create temporary directory
|
||||
temp_dir = os.path.join(args.result_dir, f"{args.version}")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# Set result save path
|
||||
result_img_save_path = os.path.join(temp_dir, output_basename)
|
||||
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
|
||||
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
|
||||
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
|
||||
os.makedirs(result_img_save_path,exist_ok =True)
|
||||
|
||||
if args.output_vid_name=="":
|
||||
output_vid_name = os.path.join(temp_dir, output_basename+".mp4")
|
||||
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
|
||||
|
||||
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
|
||||
############################################## extract frames from source video ##############################################
|
||||
if get_file_type(video_path)=="video":
|
||||
save_dir_full = os.path.join(temp_dir, input_basename)
|
||||
save_dir_full = os.path.join(args.result_dir, input_basename)
|
||||
os.makedirs(save_dir_full,exist_ok = True)
|
||||
# Read video
|
||||
# cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
||||
# os.system(cmd)
|
||||
# 读取视频
|
||||
reader = imageio.get_reader(video_path)
|
||||
|
||||
# Save images
|
||||
# 保存图片
|
||||
for i, im in enumerate(reader):
|
||||
imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
@@ -238,21 +161,10 @@ def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode=
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
|
||||
#print(input_img_list)
|
||||
############################################## extract audio feature ##############################################
|
||||
# Extract audio features
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=args.audio_padding_length_left,
|
||||
audio_padding_length_right=args.audio_padding_length_right,
|
||||
)
|
||||
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
############################################## preprocess input image ##############################################
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("using extracted coordinates")
|
||||
@@ -265,21 +177,12 @@ def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode=
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
bbox_shift_text=get_bbox_range(input_img_list, bbox_shift)
|
||||
|
||||
# Initialize face parser
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
@@ -289,23 +192,17 @@ def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode=
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
|
||||
############################################## inference batch by batch ##############################################
|
||||
print("start inference")
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=0,
|
||||
device=device,
|
||||
)
|
||||
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
||||
res_frame_list = []
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
audio_feature_batch = pe(whisper_batch)
|
||||
# Ensure latent_batch is consistent with model weight type
|
||||
latent_batch = latent_batch.to(dtype=weight_dtype)
|
||||
|
||||
tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
|
||||
audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
@@ -318,24 +215,25 @@ def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode=
|
||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
except:
|
||||
# print(bbox)
|
||||
continue
|
||||
|
||||
# Use v15 version blending
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
|
||||
|
||||
combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
|
||||
|
||||
# Frame rate
|
||||
# cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p temp.mp4"
|
||||
# print(cmd_img2video)
|
||||
# os.system(cmd_img2video)
|
||||
# 帧率
|
||||
fps = 25
|
||||
# Output video path
|
||||
# 图片路径
|
||||
# 输出视频路径
|
||||
output_video = 'temp.mp4'
|
||||
|
||||
# Read images
|
||||
# 读取图片
|
||||
def is_valid_image(file):
|
||||
pattern = re.compile(r'\d{8}\.png')
|
||||
return pattern.match(file)
|
||||
@@ -349,9 +247,13 @@ def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode=
|
||||
images.append(imageio.imread(filename))
|
||||
|
||||
|
||||
# Save video
|
||||
# 保存视频
|
||||
imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')
|
||||
|
||||
# cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
# print(cmd_combine_audio)
|
||||
# os.system(cmd_combine_audio)
|
||||
|
||||
input_video = './temp.mp4'
|
||||
# Check if the input_video and audio_path exist
|
||||
if not os.path.exists(input_video):
|
||||
@@ -359,15 +261,40 @@ def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode=
|
||||
if not os.path.exists(audio_path):
|
||||
raise FileNotFoundError(f"Audio file not found: {audio_path}")
|
||||
|
||||
# Read video
|
||||
# 读取视频
|
||||
reader = imageio.get_reader(input_video)
|
||||
fps = reader.get_meta_data()['fps'] # Get original video frame rate
|
||||
reader.close() # Otherwise, error on win11: PermissionError: [WinError 32] Another program is using this file, process cannot access. : 'temp.mp4'
|
||||
# Store frames in list
|
||||
fps = reader.get_meta_data()['fps'] # 获取原视频的帧率
|
||||
|
||||
# 将帧存储在列表中
|
||||
frames = images
|
||||
|
||||
# 保存视频并添加音频
|
||||
# imageio.mimwrite(output_vid_name, frames, 'FFMPEG', fps=fps, codec='libx264', audio_codec='aac', input_params=['-i', audio_path])
|
||||
|
||||
# input_video = ffmpeg.input(input_video)
|
||||
|
||||
# input_audio = ffmpeg.input(audio_path)
|
||||
|
||||
print(len(frames))
|
||||
|
||||
# imageio.mimwrite(
|
||||
# output_video,
|
||||
# frames,
|
||||
# 'FFMPEG',
|
||||
# fps=25,
|
||||
# codec='libx264',
|
||||
# audio_codec='aac',
|
||||
# input_params=['-i', audio_path],
|
||||
# output_params=['-y'], # Add the '-y' flag to overwrite the output file if it exists
|
||||
# )
|
||||
# writer = imageio.get_writer(output_vid_name, fps = 25, codec='libx264', quality=10, pixelformat='yuvj444p')
|
||||
# for im in frames:
|
||||
# writer.append_data(im)
|
||||
# writer.close()
|
||||
|
||||
|
||||
|
||||
|
||||
# Load the video
|
||||
video_clip = VideoFileClip(input_video)
|
||||
|
||||
@@ -388,45 +315,11 @@ def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode=
|
||||
|
||||
|
||||
# load model weights
|
||||
audio_processor,vae,unet,pe = load_all_model()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
vae, unet, pe = load_all_model(
|
||||
unet_model_path="./models/musetalkV15/unet.pth",
|
||||
vae_type="sd-vae",
|
||||
unet_config="./models/musetalkV15/musetalk.json",
|
||||
device=device
|
||||
)
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
|
||||
parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
|
||||
parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
|
||||
parser.add_argument("--share", action="store_true", help="Create a public link")
|
||||
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set data type
|
||||
if args.use_float16:
|
||||
# Convert models to half precision for better performance
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
weight_dtype = torch.float16
|
||||
else:
|
||||
weight_dtype = torch.float32
|
||||
|
||||
# Move models to specified device
|
||||
pe = pe.to(device)
|
||||
vae.vae = vae.vae.to(device)
|
||||
unet.model = unet.model.to(device)
|
||||
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
# Initialize audio processor and Whisper model
|
||||
audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
|
||||
whisper = WhisperModel.from_pretrained("./models/whisper")
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
|
||||
|
||||
def check_video(video):
|
||||
@@ -447,6 +340,9 @@ def check_video(video):
|
||||
output_video = os.path.join('./results/input', output_file_name)
|
||||
|
||||
|
||||
# # Run the ffmpeg command to change the frame rate to 25fps
|
||||
# command = f"ffmpeg -i {video} -r 25 -vcodec libx264 -vtag hvc1 -pix_fmt yuv420p crf 18 {output_video} -y"
|
||||
|
||||
# read video
|
||||
reader = imageio.get_reader(video)
|
||||
fps = reader.get_meta_data()['fps'] # get fps from original video
|
||||
@@ -478,44 +374,33 @@ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024p
|
||||
|
||||
with gr.Blocks(css=css) as demo:
|
||||
gr.Markdown(
|
||||
"""<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
|
||||
"<div align='center'> <h1>MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting </span> </h1> \
|
||||
<h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
|
||||
</br>\
|
||||
Yue Zhang <sup>*</sup>,\
|
||||
Zhizhou Zhong <sup>*</sup>,\
|
||||
Minhao Liu<sup>*</sup>,\
|
||||
Yue Zhang <sup>\*</sup>,\
|
||||
Minhao Liu<sup>\*</sup>,\
|
||||
Zhaokang Chen,\
|
||||
Bin Wu<sup>†</sup>,\
|
||||
Yubin Zeng,\
|
||||
Chao Zhang,\
|
||||
Yingjie He,\
|
||||
Junxin Huang,\
|
||||
Wenjiang Zhou <br>\
|
||||
Chao Zhan,\
|
||||
Wenjiang Zhou\
|
||||
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
|
||||
Lyra Lab, Tencent Music Entertainment\
|
||||
</h2> \
|
||||
<a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
|
||||
<a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
|
||||
<a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
|
||||
<a style='font-size:18px;color: #000000' href=''> [Technical report(Coming Soon)] </a>\
|
||||
<a style='font-size:18px;color: #000000' href=''> [Project Page(Coming Soon)] </a> </div>"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
audio = gr.Audio(label="Drving Audio",type="filepath")
|
||||
audio = gr.Audio(label="Driven Audio",type="filepath")
|
||||
video = gr.Video(label="Reference Video",sources=['upload'])
|
||||
bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
|
||||
extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
|
||||
parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
|
||||
left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
|
||||
right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
|
||||
bbox_shift_scale = gr.Textbox(label="'left_cheek_width' and 'right_cheek_width' parameters determine the range of left and right cheeks editing when parsing model is 'jaw'. The 'extra_margin' parameter determines the movement range of the jaw. Users can freely adjust these three parameters to obtain better inpainting results.")
|
||||
bbox_shift_scale = gr.Textbox(label="BBox_shift recommend value lower bound,The corresponding bbox range is generated after the initial result is generated. \n If the result is not good, it can be adjusted according to this reference value", value="",interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
debug_btn = gr.Button("1. Test Inpainting ")
|
||||
btn = gr.Button("2. Generate")
|
||||
with gr.Column():
|
||||
debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
|
||||
debug_info = gr.Textbox(label="Parameter Information", lines=5)
|
||||
btn = gr.Button("Generate")
|
||||
out1 = gr.Video()
|
||||
|
||||
video.change(
|
||||
@@ -527,44 +412,15 @@ with gr.Blocks(css=css) as demo:
|
||||
audio,
|
||||
video,
|
||||
bbox_shift,
|
||||
extra_margin,
|
||||
parsing_mode,
|
||||
left_cheek_width,
|
||||
right_cheek_width
|
||||
],
|
||||
outputs=[out1,bbox_shift_scale]
|
||||
)
|
||||
debug_btn.click(
|
||||
fn=debug_inpainting,
|
||||
inputs=[
|
||||
video,
|
||||
bbox_shift,
|
||||
extra_margin,
|
||||
parsing_mode,
|
||||
left_cheek_width,
|
||||
right_cheek_width
|
||||
],
|
||||
outputs=[debug_image, debug_info]
|
||||
)
|
||||
|
||||
# Check ffmpeg and add to PATH
|
||||
if not fast_check_ffmpeg():
|
||||
print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
|
||||
# According to operating system, choose path separator
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
# Set the IP and port
|
||||
ip_address = "0.0.0.0" # Replace with your desired IP address
|
||||
port_number = 7860 # Replace with your desired port number
|
||||
|
||||
# Solve asynchronous IO issues on Windows
|
||||
if sys.platform == 'win32':
|
||||
import asyncio
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
# Start Gradio application
|
||||
demo.queue().launch(
|
||||
share=args.share,
|
||||
debug=True,
|
||||
server_name=args.ip,
|
||||
server_port=args.port
|
||||
share=False , debug=True, server_name=ip_address, server_port=port_number
|
||||
)
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 14 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 73 KiB |
@@ -1,10 +1,10 @@
|
||||
avator_1:
|
||||
preparation: True # your can set it to False if you want to use the existing avator, it will save time
|
||||
preparation: False
|
||||
bbox_shift: 5
|
||||
video_path: "data/video/yongen.mp4"
|
||||
video_path: "data/video/sun.mp4"
|
||||
audio_clips:
|
||||
audio_0: "data/audio/yongen.wav"
|
||||
audio_1: "data/audio/eng.wav"
|
||||
audio_1: "data/audio/sun.wav"
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ task_0:
|
||||
audio_path: "data/audio/yongen.wav"
|
||||
|
||||
task_1:
|
||||
video_path: "data/video/yongen.mp4"
|
||||
audio_path: "data/audio/eng.wav"
|
||||
video_path: "data/video/sun.mp4"
|
||||
audio_path: "data/audio/sun.wav"
|
||||
bbox_shift: -7
|
||||
|
||||
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: True
|
||||
deepspeed_config:
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: False
|
||||
zero_stage: 2
|
||||
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: "5, 7" # modify this according to your GPU number
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
num_machines: 1
|
||||
num_processes: 2 # it should be the same as the number of GPUs
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,31 +0,0 @@
|
||||
clip_len_second: 30 # the length of the video clip
|
||||
video_root_raw: "./dataset/HDTF/source/" # the path of the original video
|
||||
val_list_hdtf:
|
||||
- RD_Radio7_000
|
||||
- RD_Radio8_000
|
||||
- RD_Radio9_000
|
||||
- WDA_TinaSmith_000
|
||||
- WDA_TomCarper_000
|
||||
- WDA_TomPerez_000
|
||||
- WDA_TomUdall_000
|
||||
- WDA_VeronicaEscobar0_000
|
||||
- WDA_VeronicaEscobar1_000
|
||||
- WDA_WhipJimClyburn_000
|
||||
- WDA_XavierBecerra_000
|
||||
- WDA_XavierBecerra_001
|
||||
- WDA_XavierBecerra_002
|
||||
- WDA_ZoeLofgren_000
|
||||
- WRA_SteveScalise1_000
|
||||
- WRA_TimScott_000
|
||||
- WRA_ToddYoung_000
|
||||
- WRA_TomCotton_000
|
||||
- WRA_TomPrice_000
|
||||
- WRA_VickyHartzler_000
|
||||
|
||||
# following dir will be automatically generated
|
||||
video_root_25fps: "./dataset/HDTF/video_root_25fps/"
|
||||
video_file_list: "./dataset/HDTF/video_file_list.txt"
|
||||
video_audio_clip_root: "./dataset/HDTF/video_audio_clip_root/"
|
||||
meta_root: "./dataset/HDTF/meta/"
|
||||
video_clip_file_list_train: "./dataset/HDTF/train.txt"
|
||||
video_clip_file_list_val: "./dataset/HDTF/val.txt"
|
||||
@@ -1,89 +0,0 @@
|
||||
exp_name: 'test' # Name of the experiment
|
||||
output_dir: './exp_out/stage1/' # Directory to save experiment outputs
|
||||
unet_sub_folder: musetalk # Subfolder name for UNet model
|
||||
random_init_unet: True # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
|
||||
whisper_path: "./models/whisper" # Path to the Whisper model
|
||||
pretrained_model_name_or_path: "./models" # Path to pretrained models
|
||||
resume_from_checkpoint: True # Whether to resume training from a checkpoint
|
||||
padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
|
||||
vae_type: "sd-vae" # Type of VAE model to use
|
||||
# Validation parameters
|
||||
num_images_to_keep: 8 # Number of validation images to keep
|
||||
ref_dropout_rate: 0 # Dropout rate for reference images
|
||||
syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
|
||||
use_adapted_weight: False # Whether to use adapted weights for loss calculation
|
||||
cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
|
||||
cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
|
||||
crop_type: "crop_resize" # Type of cropping method
|
||||
random_margin_method: "normal" # Method for random margin generation
|
||||
num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
|
||||
|
||||
data:
|
||||
dataset_key: "HDTF" # Dataset to use for training
|
||||
train_bs: 32 # Training batch size (actual batch size is train_bs*n_sample_frames)
|
||||
image_size: 256 # Size of input images
|
||||
n_sample_frames: 1 # Number of frames to sample per batch
|
||||
num_workers: 8 # Number of data loading workers
|
||||
audio_padding_length_left: 2 # Left padding length for audio features
|
||||
audio_padding_length_right: 2 # Right padding length for audio features
|
||||
sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
|
||||
top_k_ratio: 0.51 # Ratio for top-k sampling
|
||||
contorl_face_min_size: True # Whether to control minimum face size
|
||||
min_face_size: 150 # Minimum face size in pixels
|
||||
|
||||
loss_params:
|
||||
l1_loss: 1.0 # Weight for L1 loss
|
||||
vgg_loss: 0.01 # Weight for VGG perceptual loss
|
||||
vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
|
||||
pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
|
||||
gan_loss: 0 # Weight for GAN loss
|
||||
fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
|
||||
sync_loss: 0 # Weight for sync loss
|
||||
mouth_gan_loss: 0 # Weight for mouth-specific GAN loss
|
||||
|
||||
model_params:
|
||||
discriminator_params:
|
||||
scales: [1] # Scales for discriminator
|
||||
block_expansion: 32 # Expansion factor for discriminator blocks
|
||||
max_features: 512 # Maximum number of features in discriminator
|
||||
num_blocks: 4 # Number of blocks in discriminator
|
||||
sn: True # Whether to use spectral normalization
|
||||
image_channel: 3 # Number of image channels
|
||||
estimate_jacobian: False # Whether to estimate Jacobian
|
||||
|
||||
discriminator_train_params:
|
||||
lr: 0.000005 # Learning rate for discriminator
|
||||
eps: 0.00000001 # Epsilon for optimizer
|
||||
weight_decay: 0.01 # Weight decay for optimizer
|
||||
patch_size: 1 # Size of patches for discriminator
|
||||
betas: [0.5, 0.999] # Beta parameters for Adam optimizer
|
||||
epochs: 10000 # Number of training epochs
|
||||
start_gan: 1000 # Step to start GAN training
|
||||
|
||||
solver:
|
||||
gradient_accumulation_steps: 1 # Number of steps for gradient accumulation
|
||||
uncond_steps: 10 # Number of unconditional steps
|
||||
mixed_precision: 'fp32' # Precision mode for training
|
||||
enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
|
||||
gradient_checkpointing: True # Whether to use gradient checkpointing
|
||||
max_train_steps: 250000 # Maximum number of training steps
|
||||
max_grad_norm: 1.0 # Maximum gradient norm for clipping
|
||||
# Learning rate parameters
|
||||
learning_rate: 2.0e-5 # Base learning rate
|
||||
scale_lr: False # Whether to scale learning rate
|
||||
lr_warmup_steps: 1000 # Number of warmup steps for learning rate
|
||||
lr_scheduler: "linear" # Type of learning rate scheduler
|
||||
# Optimizer parameters
|
||||
use_8bit_adam: False # Whether to use 8-bit Adam optimizer
|
||||
adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
|
||||
adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
|
||||
adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
|
||||
adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
|
||||
|
||||
total_limit: 10 # Maximum number of checkpoints to keep
|
||||
save_model_epoch_interval: 250000 # Interval between model saves
|
||||
checkpointing_steps: 10000 # Number of steps between checkpoints
|
||||
val_freq: 2000 # Frequency of validation
|
||||
|
||||
seed: 41 # Random seed for reproducibility
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
exp_name: 'test' # Name of the experiment
|
||||
output_dir: './exp_out/stage2/' # Directory to save experiment outputs
|
||||
unet_sub_folder: musetalk # Subfolder name for UNet model
|
||||
random_init_unet: False # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
|
||||
whisper_path: "./models/whisper" # Path to the Whisper model
|
||||
pretrained_model_name_or_path: "./models" # Path to pretrained models
|
||||
resume_from_checkpoint: True # Whether to resume training from a checkpoint
|
||||
padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
|
||||
vae_type: "sd-vae" # Type of VAE model to use
|
||||
# Validation parameters
|
||||
num_images_to_keep: 8 # Number of validation images to keep
|
||||
ref_dropout_rate: 0 # Dropout rate for reference images
|
||||
syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
|
||||
use_adapted_weight: False # Whether to use adapted weights for loss calculation
|
||||
cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
|
||||
cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
|
||||
crop_type: "dynamic_margin_crop_resize" # Type of cropping method
|
||||
random_margin_method: "normal" # Method for random margin generation
|
||||
num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
|
||||
|
||||
data:
|
||||
dataset_key: "HDTF" # Dataset to use for training
|
||||
train_bs: 2 # Training batch size (actual batch size is train_bs*n_sample_frames)
|
||||
image_size: 256 # Size of input images
|
||||
n_sample_frames: 16 # Number of frames to sample per batch
|
||||
num_workers: 8 # Number of data loading workers
|
||||
audio_padding_length_left: 2 # Left padding length for audio features
|
||||
audio_padding_length_right: 2 # Right padding length for audio features
|
||||
sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
|
||||
top_k_ratio: 0.51 # Ratio for top-k sampling
|
||||
contorl_face_min_size: True # Whether to control minimum face size
|
||||
min_face_size: 200 # Minimum face size in pixels
|
||||
|
||||
loss_params:
|
||||
l1_loss: 1.0 # Weight for L1 loss
|
||||
vgg_loss: 0.01 # Weight for VGG perceptual loss
|
||||
vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
|
||||
pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
|
||||
gan_loss: 0.01 # Weight for GAN loss
|
||||
fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
|
||||
sync_loss: 0.05 # Weight for sync loss
|
||||
mouth_gan_loss: 0.01 # Weight for mouth-specific GAN loss
|
||||
|
||||
model_params:
|
||||
discriminator_params:
|
||||
scales: [1] # Scales for discriminator
|
||||
block_expansion: 32 # Expansion factor for discriminator blocks
|
||||
max_features: 512 # Maximum number of features in discriminator
|
||||
num_blocks: 4 # Number of blocks in discriminator
|
||||
sn: True # Whether to use spectral normalization
|
||||
image_channel: 3 # Number of image channels
|
||||
estimate_jacobian: False # Whether to estimate Jacobian
|
||||
|
||||
discriminator_train_params:
|
||||
lr: 0.000005 # Learning rate for discriminator
|
||||
eps: 0.00000001 # Epsilon for optimizer
|
||||
weight_decay: 0.01 # Weight decay for optimizer
|
||||
patch_size: 1 # Size of patches for discriminator
|
||||
betas: [0.5, 0.999] # Beta parameters for Adam optimizer
|
||||
epochs: 10000 # Number of training epochs
|
||||
start_gan: 1000 # Step to start GAN training
|
||||
|
||||
solver:
|
||||
gradient_accumulation_steps: 8 # Number of steps for gradient accumulation
|
||||
uncond_steps: 10 # Number of unconditional steps
|
||||
mixed_precision: 'fp32' # Precision mode for training
|
||||
enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
|
||||
gradient_checkpointing: True # Whether to use gradient checkpointing
|
||||
max_train_steps: 250000 # Maximum number of training steps
|
||||
max_grad_norm: 1.0 # Maximum gradient norm for clipping
|
||||
# Learning rate parameters
|
||||
learning_rate: 5.0e-6 # Base learning rate
|
||||
scale_lr: False # Whether to scale learning rate
|
||||
lr_warmup_steps: 1000 # Number of warmup steps for learning rate
|
||||
lr_scheduler: "linear" # Type of learning rate scheduler
|
||||
# Optimizer parameters
|
||||
use_8bit_adam: False # Whether to use 8-bit Adam optimizer
|
||||
adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
|
||||
adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
|
||||
adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
|
||||
adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
|
||||
|
||||
total_limit: 10 # Maximum number of checkpoints to keep
|
||||
save_model_epoch_interval: 250000 # Interval between model saves
|
||||
checkpointing_steps: 2000 # Number of steps between checkpoints
|
||||
val_freq: 2000 # Frequency of validation
|
||||
|
||||
seed: 41 # Random seed for reproducibility
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/configs/training/syncnet_16_pixel.yaml).
|
||||
model:
|
||||
audio_encoder: # input (1, 80, 52)
|
||||
in_channels: 1
|
||||
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
|
||||
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
|
||||
attn_blocks: [0, 0, 0, 0, 0, 0, 0]
|
||||
dropout: 0.0
|
||||
visual_encoder: # input (48, 128, 256)
|
||||
in_channels: 48
|
||||
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
|
||||
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
|
||||
attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
|
||||
dropout: 0.0
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: ""
|
||||
inference_ckpt_path: ./models/syncnet/latentsync_syncnet.pt # this pretrained model is from LatentSync (https://huggingface.co/ByteDance/LatentSync/tree/main)
|
||||
save_ckpt_steps: 2500
|
||||
Binary file not shown.
Executable
+77
@@ -0,0 +1,77 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Function to extract video and audio sections
|
||||
extract_sections() {
|
||||
input_video=$1
|
||||
base_name=$(basename "$input_video" .mp4)
|
||||
output_dir=$2
|
||||
split=$3
|
||||
duration=$(ffmpeg -i "$input_video" 2>&1 | grep Duration | awk '{print $2}' | tr -d ,)
|
||||
IFS=: read -r hours minutes seconds <<< "$duration"
|
||||
total_seconds=$((10#${hours}*3600 + 10#${minutes}*60 + 10#${seconds%.*}))
|
||||
chunk_size=180 # 3 minutes in seconds
|
||||
index=0
|
||||
|
||||
mkdir -p "$output_dir"
|
||||
|
||||
while [ $((index * chunk_size)) -lt $total_seconds ]; do
|
||||
start_time=$((index * chunk_size))
|
||||
section_video="${output_dir}/${base_name}_part${index}.mp4"
|
||||
section_audio="${output_dir}/${base_name}_part${index}.mp3"
|
||||
|
||||
ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -c copy "$section_video"
|
||||
ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -q:a 0 -map a "$section_audio"
|
||||
|
||||
# Create and update the config.yaml file
|
||||
echo "task_0:" > config.yaml
|
||||
echo " video_path: \"$section_video\"" >> config.yaml
|
||||
echo " audio_path: \"$section_audio\"" >> config.yaml
|
||||
|
||||
# Run the Python script with the current config.yaml
|
||||
python -m scripts.data --inference_config config.yaml --folder_name "$base_name"
|
||||
|
||||
index=$((index + 1))
|
||||
done
|
||||
|
||||
# Clean up save folder
|
||||
rm -rf $output_dir
|
||||
}
|
||||
|
||||
# Main script
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: $0 <train/test> <output_directory> <input_videos...>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
split=$1
|
||||
output_dir=$2
|
||||
shift 2
|
||||
input_videos=("$@")
|
||||
|
||||
# Initialize JSON array
|
||||
json_array="["
|
||||
|
||||
for input_video in "${input_videos[@]}"; do
|
||||
base_name=$(basename "$input_video" .mp4)
|
||||
|
||||
# Extract sections and run the Python script for each section
|
||||
extract_sections "$input_video" "$output_dir" "$split"
|
||||
|
||||
# Add entry to JSON array
|
||||
json_array+="\"../data/images/$base_name\","
|
||||
done
|
||||
|
||||
# Remove trailing comma and close JSON array
|
||||
json_array="${json_array%,}]"
|
||||
|
||||
# Write JSON array to the correct file
|
||||
if [ "$split" == "train" ]; then
|
||||
echo "$json_array" > train.json
|
||||
elif [ "$split" == "test" ]; then
|
||||
echo "$json_array" > test.json
|
||||
else
|
||||
echo "Invalid split: $split. Must be 'train' or 'test'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Processing complete."
|
||||
@@ -1,45 +0,0 @@
|
||||
@echo off
|
||||
setlocal
|
||||
|
||||
:: Set the checkpoints directory
|
||||
set CheckpointsDir=models
|
||||
|
||||
:: Create necessary directories
|
||||
mkdir %CheckpointsDir%\musetalk
|
||||
mkdir %CheckpointsDir%\musetalkV15
|
||||
mkdir %CheckpointsDir%\syncnet
|
||||
mkdir %CheckpointsDir%\dwpose
|
||||
mkdir %CheckpointsDir%\face-parse-bisent
|
||||
mkdir %CheckpointsDir%\sd-vae-ft-mse
|
||||
mkdir %CheckpointsDir%\whisper
|
||||
|
||||
:: Install required packages
|
||||
pip install -U "huggingface_hub[cli]"
|
||||
pip install gdown
|
||||
|
||||
:: Set HuggingFace endpoint
|
||||
set HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
:: Download MuseTalk weights
|
||||
huggingface-cli download TMElyralab/MuseTalk --local-dir %CheckpointsDir%
|
||||
|
||||
:: Download SD VAE weights
|
||||
huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir %CheckpointsDir%\sd-vae --include "config.json" "diffusion_pytorch_model.bin"
|
||||
|
||||
:: Download Whisper weights
|
||||
huggingface-cli download openai/whisper-tiny --local-dir %CheckpointsDir%\whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"
|
||||
|
||||
:: Download DWPose weights
|
||||
huggingface-cli download yzd-v/DWPose --local-dir %CheckpointsDir%\dwpose --include "dw-ll_ucoco_384.pth"
|
||||
|
||||
:: Download SyncNet weights
|
||||
huggingface-cli download ByteDance/LatentSync --local-dir %CheckpointsDir%\syncnet --include "latentsync_syncnet.pt"
|
||||
|
||||
:: Download Face Parse Bisent weights (using gdown)
|
||||
gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O %CheckpointsDir%\face-parse-bisent\79999_iter.pth
|
||||
|
||||
:: Download ResNet weights
|
||||
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o %CheckpointsDir%\face-parse-bisent\resnet18-5c106cde.pth
|
||||
|
||||
echo All weights have been downloaded successfully!
|
||||
endlocal
|
||||
@@ -1,37 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Set the checkpoints directory
|
||||
CheckpointsDir="models"
|
||||
|
||||
# Create necessary directories
|
||||
mkdir -p $CheckpointsDir/{musetalk,musetalkV15,syncnet,dwpose,face-parse-bisent,sd-vae-ft-mse,whisper}
|
||||
|
||||
# Install required packages
|
||||
pip install -U "huggingface_hub[cli]"
|
||||
pip install gdown
|
||||
|
||||
# Set HuggingFace endpoint
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
# Download MuseTalk weights
|
||||
huggingface-cli download TMElyralab/MuseTalk --local-dir $CheckpointsDir
|
||||
|
||||
# Download SD VAE weights
|
||||
huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir $CheckpointsDir/sd-vae --include "config.json" "diffusion_pytorch_model.bin"
|
||||
|
||||
# Download Whisper weights
|
||||
huggingface-cli download openai/whisper-tiny --local-dir $CheckpointsDir/whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"
|
||||
|
||||
# Download DWPose weights
|
||||
huggingface-cli download yzd-v/DWPose --local-dir $CheckpointsDir/dwpose --include "dw-ll_ucoco_384.pth"
|
||||
|
||||
# Download SyncNet weights
|
||||
huggingface-cli download ByteDance/LatentSync --local-dir $CheckpointsDir/syncnet --include "latentsync_syncnet.pt"
|
||||
|
||||
# Download Face Parse Bisent weights (using gdown)
|
||||
gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O $CheckpointsDir/face-parse-bisent/79999_iter.pth
|
||||
|
||||
# Download ResNet weights
|
||||
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o $CheckpointsDir/face-parse-bisent/resnet18-5c106cde.pth
|
||||
|
||||
echo "All weights have been downloaded successfully!"
|
||||
@@ -1,72 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script runs inference based on the version and mode specified by the user.
|
||||
# Usage:
|
||||
# To run v1.0 inference: sh inference.sh v1.0 [normal|realtime]
|
||||
# To run v1.5 inference: sh inference.sh v1.5 [normal|realtime]
|
||||
|
||||
# Check if the correct number of arguments is provided
|
||||
if [ "$#" -ne 2 ]; then
|
||||
echo "Usage: $0 <version> <mode>"
|
||||
echo "Example: $0 v1.0 normal or $0 v1.5 realtime"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get the version and mode from the user input
|
||||
version=$1
|
||||
mode=$2
|
||||
|
||||
# Validate mode
|
||||
if [ "$mode" != "normal" ] && [ "$mode" != "realtime" ]; then
|
||||
echo "Invalid mode specified. Please use 'normal' or 'realtime'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set config path based on mode
|
||||
if [ "$mode" = "normal" ]; then
|
||||
config_path="./configs/inference/test.yaml"
|
||||
result_dir="./results/test"
|
||||
else
|
||||
config_path="./configs/inference/realtime.yaml"
|
||||
result_dir="./results/realtime"
|
||||
fi
|
||||
|
||||
# Define the model paths based on the version
|
||||
if [ "$version" = "v1.0" ]; then
|
||||
model_dir="./models/musetalk"
|
||||
unet_model_path="$model_dir/pytorch_model.bin"
|
||||
unet_config="$model_dir/musetalk.json"
|
||||
version_arg="v1"
|
||||
elif [ "$version" = "v1.5" ]; then
|
||||
model_dir="./models/musetalkV15"
|
||||
unet_model_path="$model_dir/unet.pth"
|
||||
unet_config="$model_dir/musetalk.json"
|
||||
version_arg="v15"
|
||||
else
|
||||
echo "Invalid version specified. Please use v1.0 or v1.5."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set script name based on mode
|
||||
if [ "$mode" = "normal" ]; then
|
||||
script_name="scripts.inference"
|
||||
else
|
||||
script_name="scripts.realtime_inference"
|
||||
fi
|
||||
|
||||
# Base command arguments
|
||||
cmd_args="--inference_config $config_path \
|
||||
--result_dir $result_dir \
|
||||
--unet_model_path $unet_model_path \
|
||||
--unet_config $unet_config \
|
||||
--version $version_arg"
|
||||
|
||||
# Add realtime-specific arguments if in realtime mode
|
||||
if [ "$mode" = "realtime" ]; then
|
||||
cmd_args="$cmd_args \
|
||||
--fps 25 \
|
||||
--version $version_arg"
|
||||
fi
|
||||
|
||||
# Run inference
|
||||
python3 -m $script_name $cmd_args
|
||||
@@ -1,168 +0,0 @@
|
||||
import librosa
|
||||
import librosa.filters
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from scipy.io import wavfile
|
||||
|
||||
class HParams:
|
||||
# copy from wav2lip
|
||||
def __init__(self):
|
||||
self.n_fft = 800
|
||||
self.hop_size = 200
|
||||
self.win_size = 800
|
||||
self.sample_rate = 16000
|
||||
self.frame_shift_ms = None
|
||||
self.signal_normalization = True
|
||||
|
||||
self.allow_clipping_in_normalization = True
|
||||
self.symmetric_mels = True
|
||||
self.max_abs_value = 4.0
|
||||
self.preemphasize = True
|
||||
self.preemphasis = 0.97
|
||||
self.min_level_db = -100
|
||||
self.ref_level_db = 20
|
||||
self.fmin = 55
|
||||
self.fmax=7600
|
||||
|
||||
self.use_lws=False
|
||||
self.num_mels=80 # Number of mel-spectrogram channels and local conditioning dimensionality
|
||||
self.rescale=True # Whether to rescale audio prior to preprocessing
|
||||
self.rescaling_max=0.9 # Rescaling value
|
||||
self.use_lws=False
|
||||
|
||||
|
||||
hp = HParams()
|
||||
|
||||
def load_wav(path, sr):
|
||||
return librosa.core.load(path, sr=sr)[0]
|
||||
#def load_wav(path, sr):
|
||||
# audio, sr_native = sf.read(path)
|
||||
# if sr != sr_native:
|
||||
# audio = librosa.resample(audio.T, sr_native, sr).T
|
||||
# return audio
|
||||
|
||||
def save_wav(wav, path, sr):
|
||||
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
||||
#proposed by @dsmiller
|
||||
wavfile.write(path, sr, wav.astype(np.int16))
|
||||
|
||||
def save_wavenet_wav(wav, path, sr):
|
||||
librosa.output.write_wav(path, wav, sr=sr)
|
||||
|
||||
def preemphasis(wav, k, preemphasize=True):
|
||||
if preemphasize:
|
||||
return signal.lfilter([1, -k], [1], wav)
|
||||
return wav
|
||||
|
||||
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
||||
if inv_preemphasize:
|
||||
return signal.lfilter([1], [1, -k], wav)
|
||||
return wav
|
||||
|
||||
def get_hop_size():
|
||||
hop_size = hp.hop_size
|
||||
if hop_size is None:
|
||||
assert hp.frame_shift_ms is not None
|
||||
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
|
||||
return hop_size
|
||||
|
||||
def linearspectrogram(wav):
|
||||
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
||||
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
|
||||
|
||||
if hp.signal_normalization:
|
||||
return _normalize(S)
|
||||
return S
|
||||
|
||||
def melspectrogram(wav):
|
||||
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
||||
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
|
||||
|
||||
if hp.signal_normalization:
|
||||
return _normalize(S)
|
||||
return S
|
||||
|
||||
def _lws_processor():
|
||||
import lws
|
||||
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
|
||||
|
||||
def _stft(y):
|
||||
if hp.use_lws:
|
||||
return _lws_processor(hp).stft(y).T
|
||||
else:
|
||||
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
|
||||
|
||||
##########################################################
|
||||
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
||||
def num_frames(length, fsize, fshift):
|
||||
"""Compute number of time frames of spectrogram
|
||||
"""
|
||||
pad = (fsize - fshift)
|
||||
if length % fshift == 0:
|
||||
M = (length + pad * 2 - fsize) // fshift + 1
|
||||
else:
|
||||
M = (length + pad * 2 - fsize) // fshift + 2
|
||||
return M
|
||||
|
||||
|
||||
def pad_lr(x, fsize, fshift):
|
||||
"""Compute left and right padding
|
||||
"""
|
||||
M = num_frames(len(x), fsize, fshift)
|
||||
pad = (fsize - fshift)
|
||||
T = len(x) + 2 * pad
|
||||
r = (M - 1) * fshift + fsize - T
|
||||
return pad, pad + r
|
||||
##########################################################
|
||||
#Librosa correct padding
|
||||
def librosa_pad_lr(x, fsize, fshift):
|
||||
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
||||
|
||||
# Conversions
|
||||
_mel_basis = None
|
||||
|
||||
def _linear_to_mel(spectogram):
|
||||
global _mel_basis
|
||||
if _mel_basis is None:
|
||||
_mel_basis = _build_mel_basis()
|
||||
return np.dot(_mel_basis, spectogram)
|
||||
|
||||
def _build_mel_basis():
|
||||
assert hp.fmax <= hp.sample_rate // 2
|
||||
return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
|
||||
fmin=hp.fmin, fmax=hp.fmax)
|
||||
|
||||
def _amp_to_db(x):
|
||||
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
|
||||
return 20 * np.log10(np.maximum(min_level, x))
|
||||
|
||||
def _db_to_amp(x):
|
||||
return np.power(10.0, (x) * 0.05)
|
||||
|
||||
def _normalize(S):
|
||||
if hp.allow_clipping_in_normalization:
|
||||
if hp.symmetric_mels:
|
||||
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
|
||||
-hp.max_abs_value, hp.max_abs_value)
|
||||
else:
|
||||
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
|
||||
|
||||
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
|
||||
if hp.symmetric_mels:
|
||||
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
|
||||
else:
|
||||
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
|
||||
|
||||
def _denormalize(D):
|
||||
if hp.allow_clipping_in_normalization:
|
||||
if hp.symmetric_mels:
|
||||
return (((np.clip(D, -hp.max_abs_value,
|
||||
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
|
||||
+ hp.min_level_db)
|
||||
else:
|
||||
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
||||
|
||||
if hp.symmetric_mels:
|
||||
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
|
||||
else:
|
||||
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
||||
@@ -1,607 +0,0 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import random
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
import torchvision.transforms as transforms
|
||||
from transformers import AutoFeatureExtractor
|
||||
import librosa
|
||||
import time
|
||||
import json
|
||||
import math
|
||||
from decord import AudioReader, VideoReader
|
||||
from decord.ndarray import cpu
|
||||
|
||||
from musetalk.data.sample_method import get_src_idx, shift_landmarks_to_face_coordinates, resize_landmark
|
||||
from musetalk.data import audio
|
||||
|
||||
syncnet_mel_step_size = math.ceil(16 / 5 * 16) # latentsync
|
||||
|
||||
|
||||
class FaceDataset(Dataset):
|
||||
"""Dataset class for loading and processing video data
|
||||
|
||||
Each video can be represented as:
|
||||
- Concatenated frame images
|
||||
- '.mp4' or '.gif' files
|
||||
- Folder containing all frames
|
||||
"""
|
||||
def __init__(self,
|
||||
cfg,
|
||||
list_paths,
|
||||
root_path='./dataset/',
|
||||
repeats=None):
|
||||
# Initialize dataset paths
|
||||
meta_paths = []
|
||||
if repeats is None:
|
||||
repeats = [1] * len(list_paths)
|
||||
assert len(repeats) == len(list_paths)
|
||||
|
||||
# Load data list
|
||||
for list_path, repeat_time in zip(list_paths, repeats):
|
||||
with open(list_path, 'r') as f:
|
||||
num = 0
|
||||
f.readline() # Skip header line
|
||||
for line in f.readlines():
|
||||
line_info = line.strip()
|
||||
meta = line_info.split()
|
||||
meta = meta[0]
|
||||
meta_paths.extend([os.path.join(root_path, meta)] * repeat_time)
|
||||
num += 1
|
||||
print(f'{list_path}: {num} x {repeat_time} = {num * repeat_time} samples')
|
||||
|
||||
# Set basic attributes
|
||||
self.meta_paths = meta_paths
|
||||
self.root_path = root_path
|
||||
self.image_size = cfg['image_size']
|
||||
self.min_face_size = cfg['min_face_size']
|
||||
self.T = cfg['T']
|
||||
self.sample_method = cfg['sample_method']
|
||||
self.top_k_ratio = cfg['top_k_ratio']
|
||||
self.max_attempts = 200
|
||||
self.padding_pixel_mouth = cfg['padding_pixel_mouth']
|
||||
|
||||
# Cropping related parameters
|
||||
self.crop_type = cfg['crop_type']
|
||||
self.jaw2edge_margin_mean = cfg['cropping_jaw2edge_margin_mean']
|
||||
self.jaw2edge_margin_std = cfg['cropping_jaw2edge_margin_std']
|
||||
self.random_margin_method = cfg['random_margin_method']
|
||||
|
||||
# Image transformations
|
||||
self.to_tensor = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
self.pose_to_tensor = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
# Feature extractor
|
||||
self.feature_extractor = AutoFeatureExtractor.from_pretrained(cfg['whisper_path'])
|
||||
self.contorl_face_min_size = cfg["contorl_face_min_size"]
|
||||
|
||||
print("The sample method is: ", self.sample_method)
|
||||
print(f"only use face size > {self.min_face_size}", self.contorl_face_min_size)
|
||||
|
||||
def generate_random_value(self):
|
||||
"""Generate random value
|
||||
|
||||
Returns:
|
||||
float: Generated random value
|
||||
"""
|
||||
if self.random_margin_method == "uniform":
|
||||
random_value = np.random.uniform(
|
||||
self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
|
||||
self.jaw2edge_margin_mean + self.jaw2edge_margin_std
|
||||
)
|
||||
elif self.random_margin_method == "normal":
|
||||
random_value = np.random.normal(
|
||||
loc=self.jaw2edge_margin_mean,
|
||||
scale=self.jaw2edge_margin_std
|
||||
)
|
||||
random_value = np.clip(
|
||||
random_value,
|
||||
self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
|
||||
self.jaw2edge_margin_mean + self.jaw2edge_margin_std,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid random margin method: {self.random_margin_method}")
|
||||
return max(0, random_value)
|
||||
|
||||
def dynamic_margin_crop(self, img, original_bbox, extra_margin=None):
|
||||
"""Dynamically crop image with dynamic margin
|
||||
|
||||
Args:
|
||||
img: Input image
|
||||
original_bbox: Original bounding box
|
||||
extra_margin: Extra margin
|
||||
|
||||
Returns:
|
||||
tuple: (x1, y1, x2, y2, extra_margin)
|
||||
"""
|
||||
if extra_margin is None:
|
||||
extra_margin = self.generate_random_value()
|
||||
w, h = img.size
|
||||
x1, y1, x2, y2 = original_bbox
|
||||
y2 = min(y2 + int(extra_margin), h)
|
||||
return x1, y1, x2, y2, extra_margin
|
||||
|
||||
def crop_resize_img(self, img, bbox, crop_type='crop_resize', extra_margin=None):
|
||||
"""Crop and resize image
|
||||
|
||||
Args:
|
||||
img: Input image
|
||||
bbox: Bounding box
|
||||
crop_type: Type of cropping
|
||||
extra_margin: Extra margin
|
||||
|
||||
Returns:
|
||||
tuple: (Processed image, extra_margin, mask_scaled_factor)
|
||||
"""
|
||||
mask_scaled_factor = 1.
|
||||
if crop_type == 'crop_resize':
|
||||
x1, y1, x2, y2 = bbox
|
||||
img = img.crop((x1, y1, x2, y2))
|
||||
img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
|
||||
elif crop_type == 'dynamic_margin_crop_resize':
|
||||
x1, y1, x2, y2, extra_margin = self.dynamic_margin_crop(img, bbox, extra_margin)
|
||||
w_original, _ = img.size
|
||||
img = img.crop((x1, y1, x2, y2))
|
||||
w_cropped, _ = img.size
|
||||
mask_scaled_factor = w_cropped / w_original
|
||||
img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
|
||||
elif crop_type == 'resize':
|
||||
w, h = img.size
|
||||
scale = np.sqrt(self.image_size ** 2 / (h * w))
|
||||
new_w = int(w * scale) / 64 * 64
|
||||
new_h = int(h * scale) / 64 * 64
|
||||
img = img.resize((new_w, new_h), Image.LANCZOS)
|
||||
return img, extra_margin, mask_scaled_factor
|
||||
|
||||
def get_audio_file(self, wav_path, start_index):
|
||||
"""Get audio file features
|
||||
|
||||
Args:
|
||||
wav_path: Audio file path
|
||||
start_index: Starting index
|
||||
|
||||
Returns:
|
||||
tuple: (Audio features, start index)
|
||||
"""
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
audio_input_librosa, sampling_rate = librosa.load(wav_path, sr=16000)
|
||||
assert sampling_rate == 16000
|
||||
|
||||
while start_index >= 25 * 30:
|
||||
audio_input = audio_input_librosa[16000*30:]
|
||||
start_index -= 25 * 30
|
||||
if start_index + 2 * 25 >= 25 * 30:
|
||||
start_index -= 4 * 25
|
||||
audio_input = audio_input_librosa[16000*4:16000*34]
|
||||
else:
|
||||
audio_input = audio_input_librosa[:16000*30]
|
||||
|
||||
assert 2 * (start_index) >= 0
|
||||
assert 2 * (start_index + 2 * 25) <= 1500
|
||||
|
||||
audio_input = self.feature_extractor(
|
||||
audio_input,
|
||||
return_tensors="pt",
|
||||
sampling_rate=sampling_rate
|
||||
).input_features
|
||||
return audio_input, start_index
|
||||
|
||||
def get_audio_file_mel(self, wav_path, start_index):
|
||||
"""Get mel spectrogram of audio file
|
||||
|
||||
Args:
|
||||
wav_path: Audio file path
|
||||
start_index: Starting index
|
||||
|
||||
Returns:
|
||||
tuple: (Mel spectrogram, start index)
|
||||
"""
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
|
||||
audio_input, sampling_rate = librosa.load(wav_path, sr=16000)
|
||||
assert sampling_rate == 16000
|
||||
|
||||
audio_input = self.mel_feature_extractor(audio_input)
|
||||
return audio_input, start_index
|
||||
|
||||
def mel_feature_extractor(self, audio_input):
|
||||
"""Extract mel spectrogram features
|
||||
|
||||
Args:
|
||||
audio_input: Input audio
|
||||
|
||||
Returns:
|
||||
ndarray: Mel spectrogram features
|
||||
"""
|
||||
orig_mel = audio.melspectrogram(audio_input)
|
||||
return orig_mel.T
|
||||
|
||||
def crop_audio_window(self, spec, start_frame_num, fps=25):
|
||||
"""Crop audio window
|
||||
|
||||
Args:
|
||||
spec: Spectrogram
|
||||
start_frame_num: Starting frame number
|
||||
fps: Frames per second
|
||||
|
||||
Returns:
|
||||
ndarray: Cropped spectrogram
|
||||
"""
|
||||
start_idx = int(80. * (start_frame_num / float(fps)))
|
||||
end_idx = start_idx + syncnet_mel_step_size
|
||||
return spec[start_idx: end_idx, :]
|
||||
|
||||
def get_syncnet_input(self, video_path):
|
||||
"""Get SyncNet input features
|
||||
|
||||
Args:
|
||||
video_path: Video file path
|
||||
|
||||
Returns:
|
||||
ndarray: SyncNet input features
|
||||
"""
|
||||
ar = AudioReader(video_path, sample_rate=16000)
|
||||
original_mel = audio.melspectrogram(ar[:].asnumpy().squeeze(0))
|
||||
return original_mel.T
|
||||
|
||||
def get_resized_mouth_mask(
|
||||
self,
|
||||
img_resized,
|
||||
landmark_array,
|
||||
face_shape,
|
||||
padding_pixel_mouth=0,
|
||||
image_size=256,
|
||||
crop_margin=0
|
||||
):
|
||||
landmark_array = np.array(landmark_array)
|
||||
resized_landmark = resize_landmark(
|
||||
landmark_array, w=face_shape[0], h=face_shape[1], new_w=image_size, new_h=image_size)
|
||||
|
||||
landmark_array = np.array(resized_landmark[48 : 67]) # the lip landmarks in 68 landmarks format
|
||||
min_x, min_y = np.min(landmark_array, axis=0)
|
||||
max_x, max_y = np.max(landmark_array, axis=0)
|
||||
min_x = min_x - padding_pixel_mouth
|
||||
max_x = max_x + padding_pixel_mouth
|
||||
|
||||
# Calculate x-axis length and use it for y-axis
|
||||
width = max_x - min_x
|
||||
|
||||
# Calculate old center point
|
||||
center_y = (max_y + min_y) / 2
|
||||
|
||||
# Determine new min_y and max_y based on width
|
||||
min_y = center_y - width / 4
|
||||
max_y = center_y + width / 4
|
||||
|
||||
# Adjust mask position for dynamic crop, shift y-axis
|
||||
min_y = min_y - crop_margin
|
||||
max_y = max_y - crop_margin
|
||||
|
||||
# Prevent out of bounds
|
||||
min_x = max(min_x, 0)
|
||||
min_y = max(min_y, 0)
|
||||
max_x = min(max_x, face_shape[0])
|
||||
max_y = min(max_y, face_shape[1])
|
||||
|
||||
mask = np.zeros_like(np.array(img_resized))
|
||||
mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255
|
||||
return Image.fromarray(mask)
|
||||
|
||||
def __len__(self):
|
||||
return 100000
|
||||
|
||||
def __getitem__(self, idx):
|
||||
attempts = 0
|
||||
while attempts < self.max_attempts:
|
||||
try:
|
||||
meta_path = random.sample(self.meta_paths, k=1)[0]
|
||||
with open(meta_path, 'r') as f:
|
||||
meta_data = json.load(f)
|
||||
except Exception as e:
|
||||
print(f"meta file error:{meta_path}")
|
||||
print(e)
|
||||
attempts += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
video_path = meta_data["mp4_path"]
|
||||
wav_path = meta_data["wav_path"]
|
||||
bbox_list = meta_data["face_list"]
|
||||
landmark_list = meta_data["landmark_list"]
|
||||
T = self.T
|
||||
|
||||
s = 0
|
||||
e = meta_data["frames"]
|
||||
len_valid_clip = e - s
|
||||
|
||||
if len_valid_clip < T * 10:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has less than {T * 10} frames")
|
||||
continue
|
||||
|
||||
try:
|
||||
cap = VideoReader(video_path, fault_tol=1, ctx=cpu(0))
|
||||
total_frames = len(cap)
|
||||
assert total_frames == len(landmark_list)
|
||||
assert total_frames == len(bbox_list)
|
||||
landmark_shape = np.array(landmark_list).shape
|
||||
if landmark_shape != (total_frames, 68, 2):
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid landmark shape: {landmark_shape}, expected: {(total_frames, 68, 2)}") # we use 68 landmarks
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"video file error:{video_path}")
|
||||
print(e)
|
||||
attempts += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
shift_landmarks, bbox_list_union, face_shapes = shift_landmarks_to_face_coordinates(
|
||||
landmark_list,
|
||||
bbox_list
|
||||
)
|
||||
if self.contorl_face_min_size and face_shapes[0][0] < self.min_face_size:
|
||||
print(f"video {video_path} has face size {face_shapes[0][0]} less than minimum required {self.min_face_size}")
|
||||
attempts += 1
|
||||
continue
|
||||
|
||||
step = 1
|
||||
drive_idx_start = random.randint(s, e - T * step)
|
||||
drive_idx_list = list(
|
||||
range(drive_idx_start, drive_idx_start + T * step, step))
|
||||
assert len(drive_idx_list) == T
|
||||
|
||||
src_idx_list = []
|
||||
list_index_out_of_range = False
|
||||
for drive_idx in drive_idx_list:
|
||||
src_idx = get_src_idx(
|
||||
drive_idx, T, self.sample_method, shift_landmarks, face_shapes, self.top_k_ratio)
|
||||
if src_idx is None:
|
||||
list_index_out_of_range = True
|
||||
break
|
||||
src_idx = min(src_idx, e - 1)
|
||||
src_idx = max(src_idx, s)
|
||||
src_idx_list.append(src_idx)
|
||||
|
||||
if list_index_out_of_range:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid source index for drive frames")
|
||||
continue
|
||||
|
||||
ref_face_valid_flag = True
|
||||
extra_margin = self.generate_random_value()
|
||||
|
||||
# Get reference images
|
||||
ref_imgs = []
|
||||
for src_idx in src_idx_list:
|
||||
imSrc = Image.fromarray(cap[src_idx].asnumpy())
|
||||
bbox_s = bbox_list_union[src_idx]
|
||||
imSrc, _, _ = self.crop_resize_img(
|
||||
imSrc,
|
||||
bbox_s,
|
||||
self.crop_type,
|
||||
extra_margin=None
|
||||
)
|
||||
if self.contorl_face_min_size and min(imSrc.size[0], imSrc.size[1]) < self.min_face_size:
|
||||
ref_face_valid_flag = False
|
||||
break
|
||||
ref_imgs.append(imSrc)
|
||||
|
||||
if not ref_face_valid_flag:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has reference face size smaller than minimum required {self.min_face_size}")
|
||||
continue
|
||||
|
||||
# Get target images and masks
|
||||
imSameIDs = []
|
||||
bboxes = []
|
||||
face_masks = []
|
||||
face_mask_valid = True
|
||||
target_face_valid_flag = True
|
||||
|
||||
for drive_idx in drive_idx_list:
|
||||
imSameID = Image.fromarray(cap[drive_idx].asnumpy())
|
||||
bbox_s = bbox_list_union[drive_idx]
|
||||
imSameID, _ , mask_scaled_factor = self.crop_resize_img(
|
||||
imSameID,
|
||||
bbox_s,
|
||||
self.crop_type,
|
||||
extra_margin=extra_margin
|
||||
)
|
||||
if self.contorl_face_min_size and min(imSameID.size[0], imSameID.size[1]) < self.min_face_size:
|
||||
target_face_valid_flag = False
|
||||
break
|
||||
crop_margin = extra_margin * mask_scaled_factor
|
||||
face_mask = self.get_resized_mouth_mask(
|
||||
imSameID,
|
||||
shift_landmarks[drive_idx],
|
||||
face_shapes[drive_idx],
|
||||
self.padding_pixel_mouth,
|
||||
self.image_size,
|
||||
crop_margin=crop_margin
|
||||
)
|
||||
if np.count_nonzero(face_mask) == 0:
|
||||
face_mask_valid = False
|
||||
break
|
||||
|
||||
if face_mask.size[1] == 0 or face_mask.size[0] == 0:
|
||||
print(f"video {video_path} has invalid face mask size at frame {drive_idx}")
|
||||
face_mask_valid = False
|
||||
break
|
||||
|
||||
imSameIDs.append(imSameID)
|
||||
bboxes.append(bbox_s)
|
||||
face_masks.append(face_mask)
|
||||
|
||||
if not face_mask_valid:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid face mask")
|
||||
continue
|
||||
|
||||
if not target_face_valid_flag:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has target face size smaller than minimum required {self.min_face_size}")
|
||||
continue
|
||||
|
||||
# Process audio features
|
||||
audio_offset = drive_idx_list[0]
|
||||
audio_step = step
|
||||
fps = 25.0 / step
|
||||
|
||||
try:
|
||||
audio_feature, audio_offset = self.get_audio_file(wav_path, audio_offset)
|
||||
_, audio_offset = self.get_audio_file_mel(wav_path, audio_offset)
|
||||
audio_feature_mel = self.get_syncnet_input(video_path)
|
||||
except Exception as e:
|
||||
print(f"audio file error:{wav_path}")
|
||||
print(e)
|
||||
attempts += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
mel = self.crop_audio_window(audio_feature_mel, audio_offset)
|
||||
if mel.shape[0] != syncnet_mel_step_size:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid mel spectrogram shape: {mel.shape}, expected: {syncnet_mel_step_size}")
|
||||
continue
|
||||
|
||||
mel = torch.FloatTensor(mel.T).unsqueeze(0)
|
||||
|
||||
# Build sample dictionary
|
||||
sample = dict(
|
||||
pixel_values_vid=torch.stack(
|
||||
[self.to_tensor(imSameID) for imSameID in imSameIDs], dim=0),
|
||||
pixel_values_ref_img=torch.stack(
|
||||
[self.to_tensor(ref_img) for ref_img in ref_imgs], dim=0),
|
||||
pixel_values_face_mask=torch.stack(
|
||||
[self.pose_to_tensor(face_mask) for face_mask in face_masks], dim=0),
|
||||
audio_feature=audio_feature[0],
|
||||
audio_offset=audio_offset,
|
||||
audio_step=audio_step,
|
||||
mel=mel,
|
||||
wav_path=wav_path,
|
||||
fps=fps,
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
raise ValueError("Unable to find a valid sample after maximum attempts.")
|
||||
|
||||
class HDTFDataset(FaceDataset):
|
||||
"""HDTF dataset class"""
|
||||
def __init__(self, cfg):
|
||||
root_path = './dataset/HDTF/meta'
|
||||
list_paths = [
|
||||
'./dataset/HDTF/train.txt',
|
||||
]
|
||||
|
||||
|
||||
repeats = [10]
|
||||
super().__init__(cfg, list_paths, root_path, repeats)
|
||||
print('HDTFDataset: ', len(self))
|
||||
|
||||
class VFHQDataset(FaceDataset):
|
||||
"""VFHQ dataset class"""
|
||||
def __init__(self, cfg):
|
||||
root_path = './dataset/VFHQ/meta'
|
||||
list_paths = [
|
||||
'./dataset/VFHQ/train.txt',
|
||||
]
|
||||
repeats = [1]
|
||||
super().__init__(cfg, list_paths, root_path, repeats)
|
||||
print('VFHQDataset: ', len(self))
|
||||
|
||||
def PortraitDataset(cfg=None):
|
||||
"""Return dataset based on configuration
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Dataset: Combined dataset
|
||||
"""
|
||||
if cfg["dataset_key"] == "HDTF":
|
||||
return ConcatDataset([HDTFDataset(cfg)])
|
||||
elif cfg["dataset_key"] == "VFHQ":
|
||||
return ConcatDataset([VFHQDataset(cfg)])
|
||||
else:
|
||||
print("############ use all dataset ############ ")
|
||||
return ConcatDataset([HDTFDataset(cfg), VFHQDataset(cfg)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Set random seeds for reproducibility
|
||||
seed = 42
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# Create dataset with configuration parameters
|
||||
dataset = PortraitDataset(cfg={
|
||||
'T': 1, # Number of frames to process at once
|
||||
'random_margin_method': "normal", # Method for generating random margins: "normal" or "uniform"
|
||||
'dataset_key': "HDTF", # Dataset to use: "HDTF", "VFHQ", or None for both
|
||||
'image_size': 256, # Size of processed images (height and width)
|
||||
'sample_method': 'pose_similarity_and_mouth_dissimilarity', # Method for selecting reference frames
|
||||
'top_k_ratio': 0.51, # Ratio for top-k selection in reference frame sampling
|
||||
'contorl_face_min_size': True, # Whether to enforce minimum face size
|
||||
'padding_pixel_mouth': 10, # Padding pixels around mouth region in mask
|
||||
'min_face_size': 200, # Minimum face size requirement for dataset
|
||||
'whisper_path': "./models/whisper", # Path to Whisper model
|
||||
'cropping_jaw2edge_margin_mean': 10, # Mean margin for jaw-to-edge cropping
|
||||
'cropping_jaw2edge_margin_std': 10, # Standard deviation for jaw-to-edge cropping
|
||||
'crop_type': "dynamic_margin_crop_resize", # Type of cropping: "crop_resize", "dynamic_margin_crop_resize", or "resize"
|
||||
})
|
||||
print(len(dataset))
|
||||
|
||||
import torchvision
|
||||
os.makedirs('debug', exist_ok=True)
|
||||
for i in range(10): # Check 10 samples
|
||||
sample = dataset[0]
|
||||
print(f"processing {i}")
|
||||
|
||||
# Get images and mask
|
||||
ref_img = (sample['pixel_values_ref_img'] + 1.0) / 2 # (b, c, h, w)
|
||||
target_img = (sample['pixel_values_vid'] + 1.0) / 2
|
||||
face_mask = sample['pixel_values_face_mask']
|
||||
|
||||
# Print dimension information
|
||||
print(f"ref_img shape: {ref_img.shape}")
|
||||
print(f"target_img shape: {target_img.shape}")
|
||||
print(f"face_mask shape: {face_mask.shape}")
|
||||
|
||||
# Create visualization images
|
||||
b, c, h, w = ref_img.shape
|
||||
|
||||
# Apply mask only to target image
|
||||
target_mask = face_mask
|
||||
|
||||
# Keep reference image unchanged
|
||||
ref_with_mask = ref_img.clone()
|
||||
|
||||
# Create mask overlay for target image
|
||||
target_with_mask = target_img.clone()
|
||||
target_with_mask = target_with_mask * (1 - target_mask) + target_mask # Apply mask only to target
|
||||
|
||||
# Save original images, mask, and overlay results
|
||||
# First row: original images
|
||||
# Second row: mask
|
||||
# Third row: overlay effect
|
||||
concatenated_img = torch.cat((
|
||||
ref_img, target_img, # Original images
|
||||
torch.zeros_like(ref_img), target_mask, # Mask (black for ref)
|
||||
ref_with_mask, target_with_mask # Overlay effect
|
||||
), dim=3)
|
||||
|
||||
torchvision.utils.save_image(
|
||||
concatenated_img, f'debug/mask_check_{i}.jpg', nrow=2)
|
||||
@@ -1,233 +0,0 @@
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
def summarize_tensor(x):
|
||||
return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"
|
||||
|
||||
def calculate_mouth_open_similarity(landmarks_list, select_idx,top_k=50,ascending=True):
|
||||
num_landmarks = len(landmarks_list)
|
||||
mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
|
||||
print(np.shape(landmarks_list))
|
||||
## Calculate mouth opening ratios
|
||||
for i, landmarks in enumerate(landmarks_list):
|
||||
# Assuming landmarks are in the format [x, y] and accessible by index
|
||||
mouth_top = landmarks[165] # Adjust index according to your landmarks format
|
||||
mouth_bottom = landmarks[147] # Adjust index according to your landmarks format
|
||||
mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
|
||||
mouth_open_ratios[i] = mouth_open_ratio
|
||||
|
||||
# Calculate differences matrix
|
||||
differences_matrix = np.abs(mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx])
|
||||
differences_matrix_with_signs = mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx]
|
||||
print(differences_matrix.shape)
|
||||
# Find top_k similar indices for each landmark set
|
||||
if ascending:
|
||||
top_indices = np.argsort(differences_matrix[i])[:top_k]
|
||||
else:
|
||||
top_indices = np.argsort(-differences_matrix[i])[:top_k]
|
||||
similar_landmarks_indices = top_indices.tolist()
|
||||
similar_landmarks_distances = differences_matrix_with_signs[i].tolist() #注意这里不要排序
|
||||
|
||||
return similar_landmarks_indices, similar_landmarks_distances
|
||||
#############################################################################################
|
||||
def get_closed_mouth(landmarks_list,ascending=True,top_k=50):
|
||||
num_landmarks = len(landmarks_list)
|
||||
|
||||
mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
|
||||
## Calculate mouth opening ratios
|
||||
#print("landmarks shape",np.shape(landmarks_list))
|
||||
for i, landmarks in enumerate(landmarks_list):
|
||||
# Assuming landmarks are in the format [x, y] and accessible by index
|
||||
#print(landmarks[165])
|
||||
mouth_top = np.array(landmarks[165])# Adjust index according to your landmarks format
|
||||
mouth_bottom = np.array(landmarks[147]) # Adjust index according to your landmarks format
|
||||
mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
|
||||
mouth_open_ratios[i] = mouth_open_ratio
|
||||
|
||||
# Find top_k similar indices for each landmark set
|
||||
if ascending:
|
||||
top_indices = np.argsort(mouth_open_ratios)[:top_k]
|
||||
else:
|
||||
top_indices = np.argsort(-mouth_open_ratios)[:top_k]
|
||||
return top_indices
|
||||
|
||||
def calculate_landmarks_similarity(selected_idx, landmarks_list,image_shapes, start_index, end_index, top_k=50,ascending=True):
|
||||
"""
|
||||
Calculate the similarity between sets of facial landmarks and return the indices of the most similar faces.
|
||||
|
||||
Parameters:
|
||||
landmarks_list (list): A list containing sets of facial landmarks, each element is a set of landmarks.
|
||||
image_shapes (list): A list containing the shape of each image, each element is a (width, height) tuple.
|
||||
start_index (int): The starting index of the facial landmarks.
|
||||
end_index (int): The ending index of the facial landmarks.
|
||||
top_k (int): The number of most similar landmark sets to return. Default is 50.
|
||||
ascending (bool): Controls the sorting order. If True, sort in ascending order; If False, sort in descending order. Default is True.
|
||||
|
||||
Returns:
|
||||
similar_landmarks_indices (list): A list containing the indices of the most similar facial landmarks for each face.
|
||||
resized_landmarks (list): A list containing the resized facial landmarks.
|
||||
"""
|
||||
num_landmarks = len(landmarks_list)
|
||||
resized_landmarks = []
|
||||
|
||||
# Preprocess landmarks
|
||||
for i in range(num_landmarks):
|
||||
landmark_array = np.array(landmarks_list[i])
|
||||
selected_landmarks = landmark_array[start_index:end_index]
|
||||
resized_landmark = resize_landmark(selected_landmarks, w=image_shapes[i][0], h=image_shapes[i][1],new_w=256,new_h=256)
|
||||
resized_landmarks.append(resized_landmark)
|
||||
|
||||
resized_landmarks_array = np.array(resized_landmarks) # Convert list to array for easier manipulation
|
||||
|
||||
# Calculate similarity
|
||||
distances = np.linalg.norm(resized_landmarks_array - resized_landmarks_array[selected_idx][np.newaxis, :], axis=2)
|
||||
overall_distances = np.mean(distances, axis=1) # Calculate mean distance for each set of landmarks
|
||||
|
||||
if ascending:
|
||||
sorted_indices = np.argsort(overall_distances)
|
||||
similar_landmarks_indices = sorted_indices[1:top_k+1].tolist() # Exclude self and take top_k
|
||||
else:
|
||||
sorted_indices = np.argsort(-overall_distances)
|
||||
similar_landmarks_indices = sorted_indices[0:top_k].tolist()
|
||||
|
||||
return similar_landmarks_indices
|
||||
|
||||
def process_bbox_musetalk(face_array, landmark_array):
|
||||
x_min_face, y_min_face, x_max_face, y_max_face = map(int, face_array)
|
||||
x_min_lm = min([int(x) for x, y in landmark_array])
|
||||
y_min_lm = min([int(y) for x, y in landmark_array])
|
||||
x_max_lm = max([int(x) for x, y in landmark_array])
|
||||
y_max_lm = max([int(y) for x, y in landmark_array])
|
||||
x_min = min(x_min_face, x_min_lm)
|
||||
y_min = min(y_min_face, y_min_lm)
|
||||
x_max = max(x_max_face, x_max_lm)
|
||||
y_max = max(y_max_face, y_max_lm)
|
||||
|
||||
x_min = max(x_min, 0)
|
||||
y_min = max(y_min, 0)
|
||||
|
||||
return [x_min, y_min, x_max, y_max]
|
||||
|
||||
def shift_landmarks_to_face_coordinates(landmark_list, face_list):
|
||||
"""
|
||||
Translates the data in landmark_list to the coordinates of the cropped larger face.
|
||||
|
||||
Parameters:
|
||||
landmark_list (list): A list containing multiple sets of facial landmarks.
|
||||
face_list (list): A list containing multiple facial images.
|
||||
|
||||
Returns:
|
||||
landmark_list_shift (list): The list of translated landmarks.
|
||||
bbox_union (list): The list of union bounding boxes.
|
||||
face_shapes (list): The list of facial shapes.
|
||||
"""
|
||||
landmark_list_shift = []
|
||||
bbox_union = []
|
||||
face_shapes = []
|
||||
|
||||
for i in range(len(face_list)):
|
||||
landmark_array = np.array(landmark_list[i]) # 转换为numpy数组并创建副本
|
||||
face_array = face_list[i]
|
||||
f_landmark_bbox = process_bbox_musetalk(face_array, landmark_array)
|
||||
x_min, y_min, x_max, y_max = f_landmark_bbox
|
||||
landmark_array[:, 0] = landmark_array[:, 0] - f_landmark_bbox[0]
|
||||
landmark_array[:, 1] = landmark_array[:, 1] - f_landmark_bbox[1]
|
||||
landmark_list_shift.append(landmark_array)
|
||||
bbox_union.append(f_landmark_bbox)
|
||||
face_shapes.append((x_max - x_min, y_max - y_min))
|
||||
|
||||
return landmark_list_shift, bbox_union, face_shapes
|
||||
|
||||
def resize_landmark(landmark, w, h, new_w, new_h):
|
||||
landmark_norm = landmark / [w, h]
|
||||
landmark_resized = landmark_norm * [new_w, new_h]
|
||||
|
||||
return landmark_resized
|
||||
|
||||
def get_src_idx(drive_idx, T, sample_method,landmarks_list,image_shapes,top_k_ratio):
|
||||
"""
|
||||
Calculate the source index (src_idx) based on the given drive index, T, s, e, and sampling method.
|
||||
|
||||
Parameters:
|
||||
- drive_idx (int): The current drive index.
|
||||
- T (int): Total number of frames or a specific range limit.
|
||||
- sample_method (str): Sampling method, which can be "random" or other methods.
|
||||
- landmarks_list (list): List of facial landmarks.
|
||||
- image_shapes (list): List of image shapes.
|
||||
- top_k_ratio (float): Ratio for selecting top k similar frames.
|
||||
|
||||
Returns:
|
||||
- src_idx (int): The calculated source index.
|
||||
"""
|
||||
if sample_method == "random":
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
elif sample_method == "pose_similarity":
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
try:
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
# facial contour
|
||||
landmark_start_idx = 0
|
||||
landmark_end_idx = 16
|
||||
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
|
||||
src_idx = random.choice(pose_similarity_list)
|
||||
while abs(src_idx-drive_idx)<5:
|
||||
src_idx = random.choice(pose_similarity_list)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
elif sample_method=="pose_similarity_and_closed_mouth":
|
||||
# facial contour
|
||||
landmark_start_idx = 0
|
||||
landmark_end_idx = 16
|
||||
try:
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
closed_mouth_list = get_closed_mouth(landmarks_list, ascending=True,top_k=top_k)
|
||||
#print("closed_mouth_list",closed_mouth_list)
|
||||
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
|
||||
#print("pose_similarity_list",pose_similarity_list)
|
||||
common_list = list(set(closed_mouth_list).intersection(set(pose_similarity_list)))
|
||||
if len(common_list) == 0:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
else:
|
||||
src_idx = random.choice(common_list)
|
||||
|
||||
while abs(src_idx-drive_idx) <5:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
elif sample_method=="pose_similarity_and_mouth_dissimilarity":
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
try:
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
|
||||
# facial contour for 68 landmarks format
|
||||
landmark_start_idx = 0
|
||||
landmark_end_idx = 16
|
||||
|
||||
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
|
||||
|
||||
# Mouth inner coutour for 68 landmarks format
|
||||
landmark_start_idx = 60
|
||||
landmark_end_idx = 67
|
||||
|
||||
mouth_dissimilarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=False)
|
||||
|
||||
common_list = list(set(pose_similarity_list).intersection(set(mouth_dissimilarity_list)))
|
||||
if len(common_list) == 0:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
else:
|
||||
src_idx = random.choice(common_list)
|
||||
|
||||
while abs(src_idx-drive_idx) <5:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown sample_method: {sample_method}")
|
||||
return src_idx
|
||||
@@ -1,81 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, optim
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from musetalk.loss.discriminator import MultiScaleDiscriminator,DiscriminatorFullModel
|
||||
import musetalk.loss.vgg_face as vgg_face
|
||||
|
||||
class Interpolate(nn.Module):
|
||||
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
|
||||
super(Interpolate, self).__init__()
|
||||
self.size = size
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, input):
|
||||
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
|
||||
|
||||
def set_requires_grad(net, requires_grad=False):
|
||||
if net is not None:
|
||||
for param in net.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = OmegaConf.load("config/audio_adapter/E7.yaml")
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
pyramid_scale = [1, 0.5, 0.25, 0.125]
|
||||
vgg_IN = vgg_face.Vgg19().to(device)
|
||||
pyramid = vgg_face.ImagePyramide(cfg.loss_params.pyramid_scale, 3).to(device)
|
||||
vgg_IN.eval()
|
||||
downsampler = Interpolate(size=(224, 224), mode='bilinear', align_corners=False)
|
||||
|
||||
image = torch.rand(8, 3, 256, 256).to(device)
|
||||
image_pred = torch.rand(8, 3, 256, 256).to(device)
|
||||
pyramide_real = pyramid(downsampler(image))
|
||||
pyramide_generated = pyramid(downsampler(image_pred))
|
||||
|
||||
|
||||
loss_IN = 0
|
||||
for scale in cfg.loss_params.pyramid_scale:
|
||||
x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
|
||||
y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
|
||||
for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
|
||||
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
||||
loss_IN += weight * value
|
||||
loss_IN /= sum(cfg.loss_params.vgg_layer_weight) # 对vgg不同层取均值,金字塔loss是每层叠
|
||||
print(loss_IN)
|
||||
|
||||
#print(cfg.model_params.discriminator_params)
|
||||
|
||||
discriminator = MultiScaleDiscriminator(**cfg.model_params.discriminator_params).to(device)
|
||||
discriminator_full = DiscriminatorFullModel(discriminator)
|
||||
disc_scales = cfg.model_params.discriminator_params.scales
|
||||
# Prepare optimizer and loss function
|
||||
optimizer_D = optim.AdamW(discriminator.parameters(),
|
||||
lr=cfg.discriminator_train_params.lr,
|
||||
weight_decay=cfg.discriminator_train_params.weight_decay,
|
||||
betas=cfg.discriminator_train_params.betas,
|
||||
eps=cfg.discriminator_train_params.eps)
|
||||
scheduler_D = CosineAnnealingLR(optimizer_D,
|
||||
T_max=cfg.discriminator_train_params.epochs,
|
||||
eta_min=1e-6)
|
||||
|
||||
discriminator.train()
|
||||
|
||||
set_requires_grad(discriminator, False)
|
||||
|
||||
loss_G = 0.
|
||||
discriminator_maps_generated = discriminator(pyramide_generated)
|
||||
discriminator_maps_real = discriminator(pyramide_real)
|
||||
|
||||
for scale in disc_scales:
|
||||
key = 'prediction_map_%s' % scale
|
||||
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
|
||||
loss_G += value
|
||||
|
||||
print(loss_G)
|
||||
@@ -1,44 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||
nn.BatchNorm2d(cout)
|
||||
)
|
||||
self.act = nn.ReLU()
|
||||
self.residual = residual
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
if self.residual:
|
||||
out += x
|
||||
return self.act(out)
|
||||
|
||||
class nonorm_Conv2d(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||
)
|
||||
self.act = nn.LeakyReLU(0.01, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
return self.act(out)
|
||||
|
||||
class Conv2dTranspose(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
|
||||
nn.BatchNorm2d(cout)
|
||||
)
|
||||
self.act = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
return self.act(out)
|
||||
@@ -1,145 +0,0 @@
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
from musetalk.loss.vgg_face import ImagePyramide
|
||||
|
||||
class DownBlock2d(nn.Module):
|
||||
"""
|
||||
Simple block for processing video (encoder).
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
|
||||
super(DownBlock2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
|
||||
|
||||
if sn:
|
||||
self.conv = nn.utils.spectral_norm(self.conv)
|
||||
|
||||
if norm:
|
||||
self.norm = nn.InstanceNorm2d(out_features, affine=True)
|
||||
else:
|
||||
self.norm = None
|
||||
self.pool = pool
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
out = self.conv(out)
|
||||
if self.norm:
|
||||
out = self.norm(out)
|
||||
out = F.leaky_relu(out, 0.2)
|
||||
if self.pool:
|
||||
out = F.avg_pool2d(out, (2, 2))
|
||||
return out
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
"""
|
||||
Discriminator similar to Pix2Pix
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
|
||||
sn=False, **kwargs):
|
||||
super(Discriminator, self).__init__()
|
||||
|
||||
down_blocks = []
|
||||
for i in range(num_blocks):
|
||||
down_blocks.append(
|
||||
DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
||||
min(max_features, block_expansion * (2 ** (i + 1))),
|
||||
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
|
||||
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
|
||||
if sn:
|
||||
self.conv = nn.utils.spectral_norm(self.conv)
|
||||
|
||||
def forward(self, x):
|
||||
feature_maps = []
|
||||
out = x
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
feature_maps.append(down_block(out))
|
||||
out = feature_maps[-1]
|
||||
prediction_map = self.conv(out)
|
||||
|
||||
return feature_maps, prediction_map
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(nn.Module):
|
||||
"""
|
||||
Multi-scale (scale) discriminator
|
||||
"""
|
||||
|
||||
def __init__(self, scales=(), **kwargs):
|
||||
super(MultiScaleDiscriminator, self).__init__()
|
||||
self.scales = scales
|
||||
discs = {}
|
||||
for scale in scales:
|
||||
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
|
||||
self.discs = nn.ModuleDict(discs)
|
||||
|
||||
def forward(self, x):
|
||||
out_dict = {}
|
||||
for scale, disc in self.discs.items():
|
||||
scale = str(scale).replace('-', '.')
|
||||
key = 'prediction_' + scale
|
||||
#print(key)
|
||||
#print(x)
|
||||
feature_maps, prediction_map = disc(x[key])
|
||||
out_dict['feature_maps_' + scale] = feature_maps
|
||||
out_dict['prediction_map_' + scale] = prediction_map
|
||||
return out_dict
|
||||
|
||||
|
||||
|
||||
class DiscriminatorFullModel(torch.nn.Module):
|
||||
"""
|
||||
Merge all discriminator related updates into single model for better multi-gpu usage
|
||||
"""
|
||||
|
||||
def __init__(self, discriminator):
|
||||
super(DiscriminatorFullModel, self).__init__()
|
||||
self.discriminator = discriminator
|
||||
self.scales = self.discriminator.scales
|
||||
print("scales",self.scales)
|
||||
self.pyramid = ImagePyramide(self.scales, 3)
|
||||
if torch.cuda.is_available():
|
||||
self.pyramid = self.pyramid.cuda()
|
||||
|
||||
self.zero_tensor = None
|
||||
|
||||
def get_zero_tensor(self, input):
|
||||
if self.zero_tensor is None:
|
||||
self.zero_tensor = torch.FloatTensor(1).fill_(0).cuda()
|
||||
self.zero_tensor.requires_grad_(False)
|
||||
return self.zero_tensor.expand_as(input)
|
||||
|
||||
def forward(self, x, generated, gan_mode='ls'):
|
||||
pyramide_real = self.pyramid(x)
|
||||
pyramide_generated = self.pyramid(generated.detach())
|
||||
|
||||
discriminator_maps_generated = self.discriminator(pyramide_generated)
|
||||
discriminator_maps_real = self.discriminator(pyramide_real)
|
||||
|
||||
value_total = 0
|
||||
for scale in self.scales:
|
||||
key = 'prediction_map_%s' % scale
|
||||
if gan_mode == 'hinge':
|
||||
value = -torch.mean(torch.min(discriminator_maps_real[key]-1, self.get_zero_tensor(discriminator_maps_real[key]))) - torch.mean(torch.min(-discriminator_maps_generated[key]-1, self.get_zero_tensor(discriminator_maps_generated[key])))
|
||||
elif gan_mode == 'ls':
|
||||
value = ((1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2).mean()
|
||||
else:
|
||||
raise ValueError('Unexpected gan_mode {}'.format(self.train_params['gan_mode']))
|
||||
|
||||
value_total += value
|
||||
|
||||
return value_total
|
||||
|
||||
def main():
|
||||
discriminator = MultiScaleDiscriminator(scales=[1],
|
||||
block_expansion=32,
|
||||
max_features=512,
|
||||
num_blocks=4,
|
||||
sn=True,
|
||||
image_channel=3,
|
||||
estimate_jacobian=False)
|
||||
@@ -1,152 +0,0 @@
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
__all__ = ['ResNet', 'resnet50']
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, include_top=True):
|
||||
self.inplanes = 64
|
||||
super(ResNet, self).__init__()
|
||||
self.include_top = include_top
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
|
||||
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
self.avgpool = nn.AvgPool2d(7, stride=1)
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 255.
|
||||
x = x.flip(1)
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
|
||||
if not self.include_top:
|
||||
return x
|
||||
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
def resnet50(**kwargs):
|
||||
"""Constructs a ResNet-50 model.
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
return model
|
||||
@@ -1,95 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .conv import Conv2d
|
||||
|
||||
logloss = nn.BCELoss(reduction="none")
|
||||
def cosine_loss(a, v, y):
|
||||
d = nn.functional.cosine_similarity(a, v)
|
||||
d = d.clamp(0,1) # cosine_similarity的取值范围是【-1,1】,BCE如果输入负数会报错RuntimeError: CUDA error: device-side assert triggered
|
||||
loss = logloss(d.unsqueeze(1), y).squeeze()
|
||||
loss = loss.mean()
|
||||
return loss, d
|
||||
|
||||
def get_sync_loss(
|
||||
audio_embed,
|
||||
gt_frames,
|
||||
pred_frames,
|
||||
syncnet,
|
||||
adapted_weight,
|
||||
frames_left_index=0,
|
||||
frames_right_index=16,
|
||||
):
|
||||
# 跟gt_frames做随机的插入交换,节省显存开销
|
||||
assert pred_frames.shape[1] == (frames_right_index - frames_left_index) * 3
|
||||
# 3通道图像
|
||||
frames_sync_loss = torch.cat(
|
||||
[gt_frames[:, :3 * frames_left_index, ...], pred_frames, gt_frames[:, 3 * frames_right_index:, ...]],
|
||||
axis=1
|
||||
)
|
||||
vision_embed = syncnet.get_image_embed(frames_sync_loss)
|
||||
y = torch.ones(frames_sync_loss.size(0), 1).float().to(audio_embed.device)
|
||||
loss, score = cosine_loss(audio_embed, vision_embed, y)
|
||||
return loss, score
|
||||
|
||||
class SyncNet_color(nn.Module):
|
||||
def __init__(self):
|
||||
super(SyncNet_color, self).__init__()
|
||||
|
||||
self.face_encoder = nn.Sequential(
|
||||
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
|
||||
|
||||
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
|
||||
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
||||
|
||||
self.audio_encoder = nn.Sequential(
|
||||
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
||||
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
||||
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
||||
|
||||
def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
|
||||
face_embedding = self.face_encoder(face_sequences)
|
||||
audio_embedding = self.audio_encoder(audio_sequences)
|
||||
|
||||
audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
|
||||
face_embedding = face_embedding.view(face_embedding.size(0), -1)
|
||||
|
||||
audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
|
||||
face_embedding = F.normalize(face_embedding, p=2, dim=1)
|
||||
|
||||
|
||||
return audio_embedding, face_embedding
|
||||
@@ -1,237 +0,0 @@
|
||||
'''
|
||||
This part of code contains a pretrained vgg_face model.
|
||||
ref link: https://github.com/prlz77/vgg-face.pytorch
|
||||
'''
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.model_zoo
|
||||
import pickle
|
||||
from musetalk.loss import resnet as ResNet
|
||||
|
||||
|
||||
MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
|
||||
VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'
|
||||
|
||||
# It was 93.5940, 104.7624, 129.1863 before dividing by 255
|
||||
MEAN_RGB = [
|
||||
0.367035294117647,
|
||||
0.41083294117647057,
|
||||
0.5066129411764705
|
||||
]
|
||||
def load_state_dict(model, fname):
|
||||
"""
|
||||
Set parameters converted from Caffe models authors of VGGFace2 provide.
|
||||
See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.
|
||||
|
||||
Arguments:
|
||||
model: model
|
||||
fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
|
||||
"""
|
||||
with open(fname, 'rb') as f:
|
||||
weights = pickle.load(f, encoding='latin1')
|
||||
|
||||
own_state = model.state_dict()
|
||||
for name, param in weights.items():
|
||||
if name in own_state:
|
||||
try:
|
||||
own_state[name].copy_(torch.from_numpy(param))
|
||||
except Exception:
|
||||
raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\
|
||||
'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size()))
|
||||
else:
|
||||
raise KeyError('unexpected key "{}" in state_dict'.format(name))
|
||||
|
||||
|
||||
def vggface2(pretrained=True):
|
||||
vggface = ResNet.resnet50(num_classes=8631, include_top=True)
|
||||
load_state_dict(vggface, VGG_FACE_PATH)
|
||||
return vggface
|
||||
|
||||
def vggface(pretrained=False, **kwargs):
|
||||
"""VGGFace model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns pre-trained model
|
||||
"""
|
||||
model = VggFace(**kwargs)
|
||||
if pretrained:
|
||||
state = torch.utils.model_zoo.load_url(MODEL_URL)
|
||||
model.load_state_dict(state)
|
||||
return model
|
||||
|
||||
|
||||
class VggFace(torch.nn.Module):
|
||||
def __init__(self, classes=2622):
|
||||
"""VGGFace model.
|
||||
|
||||
Face recognition network. It takes as input a Bx3x224x224
|
||||
batch of face images and gives as output a BxC score vector
|
||||
(C is the number of identities).
|
||||
Input images need to be scaled in the 0-1 range and then
|
||||
normalized with respect to the mean RGB used during training.
|
||||
|
||||
Args:
|
||||
classes (int): number of identities recognized by the
|
||||
network
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.conv1 = _ConvBlock(3, 64, 64)
|
||||
self.conv2 = _ConvBlock(64, 128, 128)
|
||||
self.conv3 = _ConvBlock(128, 256, 256, 256)
|
||||
self.conv4 = _ConvBlock(256, 512, 512, 512)
|
||||
self.conv5 = _ConvBlock(512, 512, 512, 512)
|
||||
self.dropout = torch.nn.Dropout(0.5)
|
||||
self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
|
||||
self.fc2 = torch.nn.Linear(4096, 4096)
|
||||
self.fc3 = torch.nn.Linear(4096, classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
x = self.conv5(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.dropout(F.relu(self.fc1(x)))
|
||||
x = self.dropout(F.relu(self.fc2(x)))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
class _ConvBlock(torch.nn.Module):
|
||||
"""A Convolutional block."""
|
||||
|
||||
def __init__(self, *units):
|
||||
"""Create a block with len(units) - 1 convolutions.
|
||||
|
||||
convolution number i transforms the number of channels from
|
||||
units[i - 1] to units[i] channels.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.convs = torch.nn.ModuleList([
|
||||
torch.nn.Conv2d(in_, out, 3, 1, 1)
|
||||
for in_, out in zip(units[:-1], units[1:])
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
# Each convolution is followed by a ReLU, then the block is
|
||||
# concluded by a max pooling.
|
||||
for c in self.convs:
|
||||
x = F.relu(c(x))
|
||||
return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)
|
||||
|
||||
|
||||
|
||||
import numpy as np
|
||||
from torchvision import models
|
||||
class Vgg19(torch.nn.Module):
|
||||
"""
|
||||
Vgg19 network for perceptual loss.
|
||||
"""
|
||||
def __init__(self, requires_grad=False):
|
||||
super(Vgg19, self).__init__()
|
||||
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
for x in range(2):
|
||||
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(2, 7):
|
||||
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(7, 12):
|
||||
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(12, 21):
|
||||
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(21, 30):
|
||||
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
||||
|
||||
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
|
||||
requires_grad=False)
|
||||
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
|
||||
requires_grad=False)
|
||||
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
X = (X - self.mean) / self.std
|
||||
h_relu1 = self.slice1(X)
|
||||
h_relu2 = self.slice2(h_relu1)
|
||||
h_relu3 = self.slice3(h_relu2)
|
||||
h_relu4 = self.slice4(h_relu3)
|
||||
h_relu5 = self.slice5(h_relu4)
|
||||
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
||||
return out
|
||||
|
||||
|
||||
from torch import nn
|
||||
class AntiAliasInterpolation2d(nn.Module):
|
||||
"""
|
||||
Band-limited downsampling, for better preservation of the input signal.
|
||||
"""
|
||||
def __init__(self, channels, scale):
|
||||
super(AntiAliasInterpolation2d, self).__init__()
|
||||
sigma = (1 / scale - 1) / 2
|
||||
kernel_size = 2 * round(sigma * 4) + 1
|
||||
self.ka = kernel_size // 2
|
||||
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
|
||||
|
||||
kernel_size = [kernel_size, kernel_size]
|
||||
sigma = [sigma, sigma]
|
||||
# The gaussian kernel is the product of the
|
||||
# gaussian function of each dimension.
|
||||
kernel = 1
|
||||
meshgrids = torch.meshgrid(
|
||||
[
|
||||
torch.arange(size, dtype=torch.float32)
|
||||
for size in kernel_size
|
||||
]
|
||||
)
|
||||
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
||||
mean = (size - 1) / 2
|
||||
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
|
||||
|
||||
# Make sure sum of values in gaussian kernel equals 1.
|
||||
kernel = kernel / torch.sum(kernel)
|
||||
# Reshape to depthwise convolutional weight
|
||||
kernel = kernel.view(1, 1, *kernel.size())
|
||||
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
||||
|
||||
self.register_buffer('weight', kernel)
|
||||
self.groups = channels
|
||||
self.scale = scale
|
||||
inv_scale = 1 / scale
|
||||
self.int_inv_scale = int(inv_scale)
|
||||
|
||||
def forward(self, input):
|
||||
if self.scale == 1.0:
|
||||
return input
|
||||
|
||||
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
|
||||
out = F.conv2d(out, weight=self.weight, groups=self.groups)
|
||||
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ImagePyramide(torch.nn.Module):
|
||||
"""
|
||||
Create image pyramide for computing pyramide perceptual loss.
|
||||
"""
|
||||
def __init__(self, scales, num_channels):
|
||||
super(ImagePyramide, self).__init__()
|
||||
downs = {}
|
||||
for scale in scales:
|
||||
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
|
||||
self.downs = nn.ModuleDict(downs)
|
||||
|
||||
def forward(self, x):
|
||||
out_dict = {}
|
||||
for scale, down_module in self.downs.items():
|
||||
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
|
||||
return out_dict
|
||||
@@ -1,240 +0,0 @@
|
||||
"""
|
||||
This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/models/stable_syncnet.py).
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
from torch.nn import functional as F
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.models.attention import Attention as CrossAttention, FeedForward
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class SyncNet(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.audio_encoder = DownEncoder2D(
|
||||
in_channels=config["audio_encoder"]["in_channels"],
|
||||
block_out_channels=config["audio_encoder"]["block_out_channels"],
|
||||
downsample_factors=config["audio_encoder"]["downsample_factors"],
|
||||
dropout=config["audio_encoder"]["dropout"],
|
||||
attn_blocks=config["audio_encoder"]["attn_blocks"],
|
||||
)
|
||||
|
||||
self.visual_encoder = DownEncoder2D(
|
||||
in_channels=config["visual_encoder"]["in_channels"],
|
||||
block_out_channels=config["visual_encoder"]["block_out_channels"],
|
||||
downsample_factors=config["visual_encoder"]["downsample_factors"],
|
||||
dropout=config["visual_encoder"]["dropout"],
|
||||
attn_blocks=config["visual_encoder"]["attn_blocks"],
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
def forward(self, image_sequences, audio_sequences):
|
||||
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
||||
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
||||
|
||||
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
||||
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
# Make them unit vectors
|
||||
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
||||
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
||||
|
||||
return vision_embeds, audio_embeds
|
||||
|
||||
def get_image_embed(self, image_sequences):
|
||||
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
||||
|
||||
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
# Make them unit vectors
|
||||
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
||||
|
||||
return vision_embeds
|
||||
|
||||
def get_audio_embed(self, audio_sequences):
|
||||
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
||||
|
||||
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
||||
|
||||
return audio_embeds
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
act_fn: str = "silu",
|
||||
downsample_factor=2,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if act_fn == "relu":
|
||||
self.act_fn = nn.ReLU()
|
||||
elif act_fn == "silu":
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = None
|
||||
|
||||
if isinstance(downsample_factor, list):
|
||||
downsample_factor = tuple(downsample_factor)
|
||||
|
||||
if downsample_factor == 1:
|
||||
self.downsample_conv = None
|
||||
else:
|
||||
self.downsample_conv = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
|
||||
)
|
||||
self.pad = (0, 1, 0, 1)
|
||||
if isinstance(downsample_factor, tuple):
|
||||
if downsample_factor[0] == 1:
|
||||
self.pad = (0, 1, 1, 1) # The padding order is from back to front
|
||||
elif downsample_factor[1] == 1:
|
||||
self.pad = (1, 1, 0, 1)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
hidden_states += input_tensor
|
||||
|
||||
if self.downsample_conv is not None:
|
||||
hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
|
||||
hidden_states = self.downsample_conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionBlock2D(nn.Module):
|
||||
def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
|
||||
super().__init__()
|
||||
if not is_xformers_available():
|
||||
raise ModuleNotFoundError(
|
||||
"You have to install xformers to enable memory efficient attetion", name="xformers"
|
||||
)
|
||||
# inner_dim = dim_head * heads
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm3 = nn.LayerNorm(query_dim)
|
||||
|
||||
self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
|
||||
|
||||
self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
||||
self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
|
||||
self.attn._use_memory_efficient_attention_xformers = True
|
||||
|
||||
def forward(self, hidden_states):
|
||||
assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
|
||||
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
|
||||
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DownEncoder2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4 * 16,
|
||||
block_out_channels=[64, 128, 256, 256],
|
||||
downsample_factors=[2, 2, 2, 2],
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
attn_blocks=[1, 1, 1, 1],
|
||||
dropout: float = 0.0,
|
||||
act_fn="silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
# in
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# down
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
output_channels = block_out_channels[0]
|
||||
for i, block_out_channel in enumerate(block_out_channels):
|
||||
input_channels = output_channels
|
||||
output_channels = block_out_channel
|
||||
# is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = ResnetBlock2D(
|
||||
in_channels=input_channels,
|
||||
out_channels=output_channels,
|
||||
downsample_factor=downsample_factors[i],
|
||||
norm_num_groups=norm_num_groups,
|
||||
dropout=dropout,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
if attn_blocks[i] == 1:
|
||||
attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
|
||||
self.down_blocks.append(attention_block)
|
||||
|
||||
# out
|
||||
self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.act_fn_out = nn.ReLU()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
# post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.act_fn_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
@@ -31,15 +31,11 @@ class UNet():
|
||||
unet_config,
|
||||
model_path,
|
||||
use_float16=False,
|
||||
device=None
|
||||
):
|
||||
with open(unet_config, 'r') as f:
|
||||
unet_config = json.load(f)
|
||||
self.model = UNet2DConditionModel(**unet_config)
|
||||
self.pe = PositionalEncoding(d_model=384)
|
||||
if device != None:
|
||||
self.device = device
|
||||
else:
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
|
||||
self.model.load_state_dict(weights)
|
||||
|
||||
Executable → Regular
@@ -1,102 +0,0 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
class AudioProcessor:
|
||||
def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
|
||||
self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
|
||||
|
||||
def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
|
||||
assert sampling_rate == 16000
|
||||
# Split audio into 30s segments
|
||||
segment_length = 30 * sampling_rate
|
||||
segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
|
||||
|
||||
features = []
|
||||
for segment in segments:
|
||||
audio_feature = self.feature_extractor(
|
||||
segment,
|
||||
return_tensors="pt",
|
||||
sampling_rate=sampling_rate
|
||||
).input_features
|
||||
if weight_dtype is not None:
|
||||
audio_feature = audio_feature.to(dtype=weight_dtype)
|
||||
features.append(audio_feature)
|
||||
|
||||
return features, len(librosa_output)
|
||||
|
||||
def get_whisper_chunk(
|
||||
self,
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=25,
|
||||
audio_padding_length_left=2,
|
||||
audio_padding_length_right=2,
|
||||
):
|
||||
audio_feature_length_per_frame = 2 * (audio_padding_length_left + audio_padding_length_right + 1)
|
||||
whisper_feature = []
|
||||
# Process multiple 30s mel input features
|
||||
for input_feature in whisper_input_features:
|
||||
input_feature = input_feature.to(device).to(weight_dtype)
|
||||
audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
|
||||
audio_feats = torch.stack(audio_feats, dim=2)
|
||||
whisper_feature.append(audio_feats)
|
||||
|
||||
whisper_feature = torch.cat(whisper_feature, dim=1)
|
||||
# Trim the last segment to remove padding
|
||||
sr = 16000
|
||||
audio_fps = 50
|
||||
fps = int(fps)
|
||||
whisper_idx_multiplier = audio_fps / fps
|
||||
num_frames = math.floor((librosa_length / sr) * fps)
|
||||
actual_length = math.floor((librosa_length / sr) * audio_fps)
|
||||
whisper_feature = whisper_feature[:,:actual_length,...]
|
||||
|
||||
# Calculate padding amount
|
||||
padding_nums = math.ceil(whisper_idx_multiplier)
|
||||
# Add padding at start and end
|
||||
whisper_feature = torch.cat([
|
||||
torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
|
||||
whisper_feature,
|
||||
# Add extra padding to prevent out of bounds
|
||||
torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
|
||||
], 1)
|
||||
|
||||
audio_prompts = []
|
||||
for frame_index in range(num_frames):
|
||||
try:
|
||||
audio_index = math.floor(frame_index * whisper_idx_multiplier)
|
||||
audio_clip = whisper_feature[:, audio_index: audio_index + audio_feature_length_per_frame]
|
||||
assert audio_clip.shape[1] == audio_feature_length_per_frame
|
||||
audio_prompts.append(audio_clip)
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
print(f"whisper_feature.shape: {whisper_feature.shape}")
|
||||
print(f"audio_clip.shape: {audio_clip.shape}")
|
||||
print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
|
||||
print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
|
||||
exit()
|
||||
|
||||
audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
|
||||
audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
|
||||
return audio_prompts
|
||||
|
||||
if __name__ == "__main__":
|
||||
audio_processor = AudioProcessor()
|
||||
wav_path = "./2.wav"
|
||||
audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
|
||||
print("Audio Feature shape:", audio_feature.shape)
|
||||
print("librosa_feature_length:", librosa_feature_length)
|
||||
|
||||
Executable → Regular
+53
-89
@@ -1,8 +1,9 @@
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import cv2
|
||||
import copy
|
||||
from face_parsing import FaceParsing
|
||||
|
||||
fp = FaceParsing()
|
||||
|
||||
def get_crop_box(box, expand):
|
||||
x, y, x1, y1 = box
|
||||
@@ -12,86 +13,76 @@ def get_crop_box(box, expand):
|
||||
crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
|
||||
return crop_box, s
|
||||
|
||||
|
||||
def face_seg(image, mode="raw", fp=None):
|
||||
"""
|
||||
对图像进行面部解析,生成面部区域的掩码。
|
||||
|
||||
Args:
|
||||
image (PIL.Image): 输入图像。
|
||||
|
||||
Returns:
|
||||
PIL.Image: 面部区域的掩码图像。
|
||||
"""
|
||||
seg_image = fp(image, mode=mode) # 使用 FaceParsing 模型解析面部
|
||||
def face_seg(image):
|
||||
seg_image = fp(image)
|
||||
if seg_image is None:
|
||||
print("error, no person_segment") # 如果没有检测到面部,返回错误
|
||||
print("error, no person_segment")
|
||||
return None
|
||||
|
||||
seg_image = seg_image.resize(image.size) # 将掩码图像调整为输入图像的大小
|
||||
seg_image = seg_image.resize(image.size)
|
||||
return seg_image
|
||||
|
||||
def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
|
||||
#print(image.shape)
|
||||
#print(face.shape)
|
||||
|
||||
def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode="raw", fp=None):
|
||||
"""
|
||||
将裁剪的面部图像粘贴回原始图像,并进行一些处理。
|
||||
body = Image.fromarray(image[:,:,::-1])
|
||||
face = Image.fromarray(face[:,:,::-1])
|
||||
|
||||
Args:
|
||||
image (numpy.ndarray): 原始图像(身体部分)。
|
||||
face (numpy.ndarray): 裁剪的面部图像。
|
||||
face_box (tuple): 面部边界框的坐标 (x, y, x1, y1)。
|
||||
upper_boundary_ratio (float): 用于控制面部区域的保留比例。
|
||||
expand (float): 扩展因子,用于放大裁剪框。
|
||||
mode: 融合mask构建方式
|
||||
x, y, x1, y1 = face_box
|
||||
#print(x1-x,y1-y)
|
||||
crop_box, s = get_crop_box(face_box, expand)
|
||||
x_s, y_s, x_e, y_e = crop_box
|
||||
face_position = (x, y)
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: 处理后的图像。
|
||||
"""
|
||||
# 将 numpy 数组转换为 PIL 图像
|
||||
body = Image.fromarray(image[:, :, ::-1]) # 身体部分图像(整张图)
|
||||
face = Image.fromarray(face[:, :, ::-1]) # 面部图像
|
||||
|
||||
x, y, x1, y1 = face_box # 获取面部边界框的坐标
|
||||
crop_box, s = get_crop_box(face_box, expand) # 计算扩展后的裁剪框
|
||||
x_s, y_s, x_e, y_e = crop_box # 裁剪框的坐标
|
||||
face_position = (x, y) # 面部在原始图像中的位置
|
||||
|
||||
# 从身体图像中裁剪出扩展后的面部区域(下巴到边界有距离)
|
||||
face_large = body.crop(crop_box)
|
||||
ori_shape = face_large.size
|
||||
|
||||
ori_shape = face_large.size # 裁剪后图像的原始尺寸
|
||||
mask_image = face_seg(face_large)
|
||||
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
mask_image = Image.new('L', ori_shape, 0)
|
||||
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
|
||||
# 对裁剪后的面部区域进行面部解析,生成掩码
|
||||
mask_image = face_seg(face_large, mode=mode, fp=fp)
|
||||
|
||||
mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 裁剪出面部区域的掩码
|
||||
|
||||
mask_image = Image.new('L', ori_shape, 0) # 创建一个全黑的掩码图像
|
||||
mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 将面部掩码粘贴到全黑图像上
|
||||
|
||||
|
||||
# 保留面部区域的上半部分(用于控制说话区域)
|
||||
# keep upper_boundary_ratio of talking area
|
||||
width, height = mask_image.size
|
||||
top_boundary = int(height * upper_boundary_ratio) # 计算上半部分的边界
|
||||
modified_mask_image = Image.new('L', ori_shape, 0) # 创建一个新的全黑掩码图像
|
||||
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) # 粘贴上半部分掩码
|
||||
top_boundary = int(height * upper_boundary_ratio)
|
||||
modified_mask_image = Image.new('L', ori_shape, 0)
|
||||
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
|
||||
|
||||
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
|
||||
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
|
||||
mask_image = Image.fromarray(mask_array)
|
||||
|
||||
# 对掩码进行高斯模糊,使边缘更平滑
|
||||
blur_kernel_size = int(0.05 * ori_shape[0] // 2 * 2) + 1 # 计算模糊核大小
|
||||
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) # 高斯模糊
|
||||
#mask_array = np.array(modified_mask_image)
|
||||
mask_image = Image.fromarray(mask_array) # 将模糊后的掩码转换回 PIL 图像
|
||||
|
||||
# 将裁剪的面部图像粘贴回扩展后的面部区域
|
||||
face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
|
||||
body.paste(face_large, crop_box[:2], mask_image)
|
||||
body = np.array(body)
|
||||
return body[:,:,::-1]
|
||||
|
||||
body = np.array(body) # 将 PIL 图像转换回 numpy 数组
|
||||
def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
|
||||
body = Image.fromarray(image[:,:,::-1])
|
||||
|
||||
return body[:, :, ::-1] # 返回处理后的图像(BGR 转 RGB)
|
||||
x, y, x1, y1 = face_box
|
||||
#print(x1-x,y1-y)
|
||||
crop_box, s = get_crop_box(face_box, expand)
|
||||
x_s, y_s, x_e, y_e = crop_box
|
||||
|
||||
face_large = body.crop(crop_box)
|
||||
ori_shape = face_large.size
|
||||
|
||||
mask_image = face_seg(face_large)
|
||||
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
mask_image = Image.new('L', ori_shape, 0)
|
||||
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
|
||||
# keep upper_boundary_ratio of talking area
|
||||
width, height = mask_image.size
|
||||
top_boundary = int(height * upper_boundary_ratio)
|
||||
modified_mask_image = Image.new('L', ori_shape, 0)
|
||||
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
|
||||
|
||||
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
|
||||
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
|
||||
return mask_array,crop_box
|
||||
|
||||
def get_image_blending(image,face,face_box,mask_array,crop_box):
|
||||
body = Image.fromarray(image[:,:,::-1])
|
||||
@@ -107,30 +98,3 @@ def get_image_blending(image, face, face_box, mask_array, crop_box):
|
||||
body.paste(face_large, crop_box[:2], mask_image)
|
||||
body = np.array(body)
|
||||
return body[:,:,::-1]
|
||||
|
||||
|
||||
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
|
||||
body = Image.fromarray(image[:,:,::-1])
|
||||
|
||||
x, y, x1, y1 = face_box
|
||||
#print(x1-x,y1-y)
|
||||
crop_box, s = get_crop_box(face_box, expand)
|
||||
x_s, y_s, x_e, y_e = crop_box
|
||||
|
||||
face_large = body.crop(crop_box)
|
||||
ori_shape = face_large.size
|
||||
|
||||
mask_image = face_seg(face_large, mode=mode, fp=fp)
|
||||
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
mask_image = Image.new('L', ori_shape, 0)
|
||||
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
|
||||
# keep upper_boundary_ratio of talking area
|
||||
width, height = mask_image.size
|
||||
top_boundary = int(height * upper_boundary_ratio)
|
||||
modified_mask_image = Image.new('L', ori_shape, 0)
|
||||
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
|
||||
|
||||
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
|
||||
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
|
||||
return mask_array, crop_box
|
||||
|
||||
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
@@ -8,53 +8,9 @@ from .model import BiSeNet
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
class FaceParsing():
|
||||
def __init__(self, left_cheek_width=80, right_cheek_width=80):
|
||||
def __init__(self):
|
||||
self.net = self.model_init()
|
||||
self.preprocess = self.image_preprocess()
|
||||
# Ensure all size parameters are integers
|
||||
cone_height = 21
|
||||
tail_height = 12
|
||||
total_size = cone_height + tail_height
|
||||
|
||||
# Create kernel with explicit integer dimensions
|
||||
kernel = np.zeros((total_size, total_size), dtype=np.uint8)
|
||||
center_x = total_size // 2 # Ensure center coordinates are integers
|
||||
|
||||
# Cone part
|
||||
for row in range(cone_height):
|
||||
if row < cone_height//2:
|
||||
continue
|
||||
width = int(2 * (row - cone_height//2) + 1)
|
||||
start = int(center_x - (width // 2))
|
||||
end = int(center_x + (width // 2) + 1)
|
||||
kernel[row, start:end] = 1
|
||||
|
||||
# Vertical extension part
|
||||
if cone_height > 0:
|
||||
base_width = int(kernel[cone_height-1].sum())
|
||||
else:
|
||||
base_width = 1
|
||||
|
||||
for row in range(cone_height, total_size):
|
||||
start = max(0, int(center_x - (base_width//2)))
|
||||
end = min(total_size, int(center_x + (base_width//2) + 1))
|
||||
kernel[row, start:end] = 1
|
||||
self.kernel = kernel
|
||||
|
||||
# Modify cheek erosion kernel to be flatter ellipse
|
||||
self.cheek_kernel = cv2.getStructuringElement(
|
||||
cv2.MORPH_ELLIPSE, (35, 3))
|
||||
|
||||
# Add cheek area mask (protect chin area)
|
||||
self.cheek_mask = self._create_cheek_mask(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
|
||||
|
||||
def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80):
|
||||
"""Create cheek area mask (1/4 area on both sides)"""
|
||||
mask = np.zeros((512, 512), dtype=np.uint8)
|
||||
center = 512 // 2
|
||||
cv2.rectangle(mask, (0, 0), (center - left_cheek_width, 512), 255, -1) # Left cheek
|
||||
cv2.rectangle(mask, (center + right_cheek_width, 0), (512, 512), 255, -1) # Right cheek
|
||||
return mask
|
||||
|
||||
def model_init(self,
|
||||
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
|
||||
@@ -74,7 +30,7 @@ class FaceParsing():
|
||||
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
])
|
||||
|
||||
def __call__(self, image, size=(512, 512), mode="raw"):
|
||||
def __call__(self, image, size=(512, 512)):
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
|
||||
@@ -88,25 +44,8 @@ class FaceParsing():
|
||||
img = torch.unsqueeze(img, 0)
|
||||
out = self.net(img)[0]
|
||||
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
||||
|
||||
# Add 14:neck, remove 10:nose and 7:8:9
|
||||
if mode == "neck":
|
||||
parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
|
||||
parsing[np.where(parsing!=255)] = 0
|
||||
elif mode == "jaw":
|
||||
face_region = np.isin(parsing, [1])*255
|
||||
face_region = face_region.astype(np.uint8)
|
||||
original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
|
||||
eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
|
||||
face_region = cv2.bitwise_and(eroded, self.cheek_mask)
|
||||
face_region = cv2.bitwise_or(face_region, cv2.bitwise_and(original_dilated, ~self.cheek_mask))
|
||||
parsing[(face_region==255) & (~np.isin(parsing, [10]))] = 255
|
||||
parsing[np.isin(parsing, [11, 12, 13])] = 255
|
||||
parsing[np.where(parsing!=255)] = 0
|
||||
else:
|
||||
parsing[np.isin(parsing, [1, 11, 12, 13])] = 255
|
||||
parsing[np.where(parsing!=255)] = 0
|
||||
|
||||
parsing[np.where(parsing>13)] = 0
|
||||
parsing[np.where(parsing>=1)] = 255
|
||||
parsing = Image.fromarray(parsing.astype(np.uint8))
|
||||
return parsing
|
||||
|
||||
|
||||
Executable → Regular
@@ -1,337 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import WhisperModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from omegaconf import OmegaConf
|
||||
from einops import rearrange
|
||||
|
||||
from musetalk.models.syncnet import SyncNet
|
||||
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
|
||||
from musetalk.loss.basic_loss import Interpolate
|
||||
import musetalk.loss.vgg_face as vgg_face
|
||||
from musetalk.data.dataset import PortraitDataset
|
||||
from musetalk.utils.utils import (
|
||||
get_image_pred,
|
||||
process_audio_features,
|
||||
process_and_save_images
|
||||
)
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_latents,
|
||||
timesteps,
|
||||
audio_prompts,
|
||||
):
|
||||
model_pred = self.unet(
|
||||
input_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=audio_prompts
|
||||
).sample
|
||||
return model_pred
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
|
||||
"""Initialize models and optimizers"""
|
||||
model_dict = {
|
||||
'vae': None,
|
||||
'unet': None,
|
||||
'net': None,
|
||||
'wav2vec': None,
|
||||
'optimizer': None,
|
||||
'lr_scheduler': None,
|
||||
'scheduler_max_steps': None,
|
||||
'trainable_params': None
|
||||
}
|
||||
|
||||
model_dict['vae'] = AutoencoderKL.from_pretrained(
|
||||
cfg.pretrained_model_name_or_path,
|
||||
subfolder=cfg.vae_type,
|
||||
)
|
||||
|
||||
unet_config_file = os.path.join(
|
||||
cfg.pretrained_model_name_or_path,
|
||||
cfg.unet_sub_folder + "/musetalk.json"
|
||||
)
|
||||
|
||||
with open(unet_config_file, 'r') as f:
|
||||
unet_config = json.load(f)
|
||||
model_dict['unet'] = UNet2DConditionModel(**unet_config)
|
||||
|
||||
if not cfg.random_init_unet:
|
||||
pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
|
||||
print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
|
||||
checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
|
||||
model_dict['unet'].load_state_dict(checkpoint)
|
||||
|
||||
unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
|
||||
logger.info(f"unet {sum(unet_params) / 1e6}M-parameter")
|
||||
|
||||
model_dict['vae'].requires_grad_(False)
|
||||
model_dict['unet'].requires_grad_(True)
|
||||
|
||||
model_dict['vae'].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
model_dict['net'] = Net(model_dict['unet'])
|
||||
|
||||
model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
|
||||
device="cuda", dtype=weight_dtype).eval()
|
||||
model_dict['wav2vec'].requires_grad_(False)
|
||||
|
||||
if cfg.solver.gradient_checkpointing:
|
||||
model_dict['unet'].enable_gradient_checkpointing()
|
||||
|
||||
if cfg.solver.scale_lr:
|
||||
learning_rate = (
|
||||
cfg.solver.learning_rate
|
||||
* cfg.solver.gradient_accumulation_steps
|
||||
* cfg.data.train_bs
|
||||
* accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
learning_rate = cfg.solver.learning_rate
|
||||
|
||||
if cfg.solver.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
||||
)
|
||||
optimizer_cls = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
model_dict['trainable_params'] = list(filter(lambda p: p.requires_grad, model_dict['net'].parameters()))
|
||||
if accelerator.is_main_process:
|
||||
print('trainable params')
|
||||
for n, p in model_dict['net'].named_parameters():
|
||||
if p.requires_grad:
|
||||
print(n)
|
||||
|
||||
model_dict['optimizer'] = optimizer_cls(
|
||||
model_dict['trainable_params'],
|
||||
lr=learning_rate,
|
||||
betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
|
||||
weight_decay=cfg.solver.adam_weight_decay,
|
||||
eps=cfg.solver.adam_epsilon,
|
||||
)
|
||||
|
||||
model_dict['scheduler_max_steps'] = cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
|
||||
model_dict['lr_scheduler'] = get_scheduler(
|
||||
cfg.solver.lr_scheduler,
|
||||
optimizer=model_dict['optimizer'],
|
||||
num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
|
||||
num_training_steps=model_dict['scheduler_max_steps'],
|
||||
)
|
||||
|
||||
return model_dict
|
||||
|
||||
def initialize_dataloaders(cfg):
|
||||
"""Initialize training and validation dataloaders"""
|
||||
dataloader_dict = {
|
||||
'train_dataset': None,
|
||||
'val_dataset': None,
|
||||
'train_dataloader': None,
|
||||
'val_dataloader': None
|
||||
}
|
||||
|
||||
dataloader_dict['train_dataset'] = PortraitDataset(cfg={
|
||||
'image_size': cfg.data.image_size,
|
||||
'T': cfg.data.n_sample_frames,
|
||||
"sample_method": cfg.data.sample_method,
|
||||
'top_k_ratio': cfg.data.top_k_ratio,
|
||||
"contorl_face_min_size": cfg.data.contorl_face_min_size,
|
||||
"dataset_key": cfg.data.dataset_key,
|
||||
"padding_pixel_mouth": cfg.padding_pixel_mouth,
|
||||
"whisper_path": cfg.whisper_path,
|
||||
"min_face_size": cfg.data.min_face_size,
|
||||
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
|
||||
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
|
||||
"crop_type": cfg.crop_type,
|
||||
"random_margin_method": cfg.random_margin_method,
|
||||
})
|
||||
|
||||
dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
|
||||
dataloader_dict['train_dataset'],
|
||||
batch_size=cfg.data.train_bs,
|
||||
shuffle=True,
|
||||
num_workers=cfg.data.num_workers,
|
||||
)
|
||||
|
||||
dataloader_dict['val_dataset'] = PortraitDataset(cfg={
|
||||
'image_size': cfg.data.image_size,
|
||||
'T': cfg.data.n_sample_frames,
|
||||
"sample_method": cfg.data.sample_method,
|
||||
'top_k_ratio': cfg.data.top_k_ratio,
|
||||
"contorl_face_min_size": cfg.data.contorl_face_min_size,
|
||||
"dataset_key": cfg.data.dataset_key,
|
||||
"padding_pixel_mouth": cfg.padding_pixel_mouth,
|
||||
"whisper_path": cfg.whisper_path,
|
||||
"min_face_size": cfg.data.min_face_size,
|
||||
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
|
||||
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
|
||||
"crop_type": cfg.crop_type,
|
||||
"random_margin_method": cfg.random_margin_method,
|
||||
})
|
||||
|
||||
dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
|
||||
dataloader_dict['val_dataset'],
|
||||
batch_size=cfg.data.train_bs,
|
||||
shuffle=True,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
return dataloader_dict
|
||||
|
||||
def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
|
||||
"""Initialize loss functions and discriminators"""
|
||||
loss_dict = {
|
||||
'L1_loss': nn.L1Loss(reduction='mean'),
|
||||
'discriminator': None,
|
||||
'mouth_discriminator': None,
|
||||
'optimizer_D': None,
|
||||
'mouth_optimizer_D': None,
|
||||
'scheduler_D': None,
|
||||
'mouth_scheduler_D': None,
|
||||
'disc_scales': None,
|
||||
'discriminator_full': None,
|
||||
'mouth_discriminator_full': None
|
||||
}
|
||||
|
||||
if cfg.loss_params.gan_loss > 0:
|
||||
loss_dict['discriminator'] = MultiScaleDiscriminator(
|
||||
**cfg.model_params.discriminator_params).to(accelerator.device)
|
||||
loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
|
||||
loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
|
||||
loss_dict['optimizer_D'] = optim.AdamW(
|
||||
loss_dict['discriminator'].parameters(),
|
||||
lr=cfg.discriminator_train_params.lr,
|
||||
weight_decay=cfg.discriminator_train_params.weight_decay,
|
||||
betas=cfg.discriminator_train_params.betas,
|
||||
eps=cfg.discriminator_train_params.eps)
|
||||
loss_dict['scheduler_D'] = CosineAnnealingLR(
|
||||
loss_dict['optimizer_D'],
|
||||
T_max=scheduler_max_steps,
|
||||
eta_min=1e-6
|
||||
)
|
||||
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
|
||||
**cfg.model_params.discriminator_params).to(accelerator.device)
|
||||
loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
|
||||
loss_dict['mouth_optimizer_D'] = optim.AdamW(
|
||||
loss_dict['mouth_discriminator'].parameters(),
|
||||
lr=cfg.discriminator_train_params.lr,
|
||||
weight_decay=cfg.discriminator_train_params.weight_decay,
|
||||
betas=cfg.discriminator_train_params.betas,
|
||||
eps=cfg.discriminator_train_params.eps)
|
||||
loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
|
||||
loss_dict['mouth_optimizer_D'],
|
||||
T_max=scheduler_max_steps,
|
||||
eta_min=1e-6
|
||||
)
|
||||
|
||||
return loss_dict
|
||||
|
||||
def initialize_syncnet(cfg, accelerator, weight_dtype):
|
||||
"""Initialize SyncNet model"""
|
||||
if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
|
||||
if cfg.data.n_sample_frames != 16:
|
||||
raise ValueError(
|
||||
f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
|
||||
)
|
||||
syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
|
||||
syncnet = SyncNet(OmegaConf.to_container(
|
||||
syncnet_config.model)).to(accelerator.device)
|
||||
print(
|
||||
f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
|
||||
checkpoint = torch.load(
|
||||
syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
|
||||
syncnet.load_state_dict(checkpoint["state_dict"])
|
||||
syncnet.to(dtype=weight_dtype)
|
||||
syncnet.requires_grad_(False)
|
||||
syncnet.eval()
|
||||
return syncnet
|
||||
return None
|
||||
|
||||
def initialize_vgg(cfg, accelerator):
|
||||
"""Initialize VGG model"""
|
||||
if cfg.loss_params.vgg_loss > 0:
|
||||
vgg_IN = vgg_face.Vgg19().to(accelerator.device,)
|
||||
pyramid = vgg_face.ImagePyramide(
|
||||
cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
|
||||
vgg_IN.eval()
|
||||
downsampler = Interpolate(
|
||||
size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
|
||||
return vgg_IN, pyramid, downsampler
|
||||
return None, None, None
|
||||
|
||||
def validation(
|
||||
cfg,
|
||||
val_dataloader,
|
||||
net,
|
||||
vae,
|
||||
wav2vec,
|
||||
accelerator,
|
||||
save_dir,
|
||||
global_step,
|
||||
weight_dtype,
|
||||
syncnet_score=1,
|
||||
):
|
||||
"""Validation function for model evaluation"""
|
||||
net.eval() # Set the model to evaluation mode
|
||||
for batch in val_dataloader:
|
||||
# The same ref_latents
|
||||
ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
|
||||
accelerator.device, non_blocking=True
|
||||
)
|
||||
pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
|
||||
accelerator.device, non_blocking=True
|
||||
)
|
||||
bsz, num_frames, c, h, w = ref_pixel_values.shape
|
||||
|
||||
audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
|
||||
# audio feature for unet
|
||||
audio_prompts = rearrange(
|
||||
audio_prompts,
|
||||
'b f c h w-> (b f) c h w'
|
||||
)
|
||||
audio_prompts = rearrange(
|
||||
audio_prompts,
|
||||
'(b f) c h w -> (b f) (c h) w',
|
||||
b=bsz
|
||||
)
|
||||
# different masked_latents
|
||||
image_pred_train = get_image_pred(
|
||||
pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
|
||||
image_pred_infer = get_image_pred(
|
||||
ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
|
||||
|
||||
process_and_save_images(
|
||||
batch,
|
||||
image_pred_train,
|
||||
image_pred_infer,
|
||||
save_dir,
|
||||
global_step,
|
||||
accelerator,
|
||||
cfg.num_images_to_keep,
|
||||
syncnet_score
|
||||
)
|
||||
# only infer 1 image in validation
|
||||
break
|
||||
net.train() # Set the model back to training mode
|
||||
Executable → Regular
+20
-275
@@ -2,33 +2,26 @@ import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Union, List
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import shutil
|
||||
import os.path as osp
|
||||
|
||||
ffmpeg_path = os.getenv('FFMPEG_PATH')
|
||||
if ffmpeg_path is None:
|
||||
print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
|
||||
elif ffmpeg_path not in os.getenv('PATH'):
|
||||
print("add ffmpeg to path")
|
||||
os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
|
||||
|
||||
|
||||
from musetalk.whisper.audio2feature import Audio2Feature
|
||||
from musetalk.models.vae import VAE
|
||||
from musetalk.models.unet import UNet,PositionalEncoding
|
||||
|
||||
|
||||
def load_all_model(
|
||||
unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
|
||||
vae_type="sd-vae",
|
||||
unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
|
||||
device=None,
|
||||
):
|
||||
vae = VAE(
|
||||
model_path = os.path.join("models", vae_type),
|
||||
)
|
||||
print(f"load unet model from {unet_model_path}")
|
||||
unet = UNet(
|
||||
unet_config=unet_config,
|
||||
model_path=unet_model_path,
|
||||
device=device
|
||||
)
|
||||
def load_all_model():
|
||||
audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
|
||||
vae = VAE(model_path = "./models/sd-vae-ft-mse/")
|
||||
unet = UNet(unet_config="./models/musetalk/musetalk.json",
|
||||
model_path ="./models/musetalk/pytorch_model.bin")
|
||||
pe = PositionalEncoding(d_model=384)
|
||||
return vae, unet, pe
|
||||
return audio_processor,vae,unet,pe
|
||||
|
||||
def get_file_type(video_path):
|
||||
_, ext = os.path.splitext(video_path)
|
||||
@@ -46,13 +39,10 @@ def get_video_fps(video_path):
|
||||
video.release()
|
||||
return fps
|
||||
|
||||
def datagen(
|
||||
whisper_chunks,
|
||||
def datagen(whisper_chunks,
|
||||
vae_encode_latents,
|
||||
batch_size=8,
|
||||
delay_frame=0,
|
||||
device="cuda:0",
|
||||
):
|
||||
delay_frame=0):
|
||||
whisper_batch, latent_batch = [], []
|
||||
for i, w in enumerate(whisper_chunks):
|
||||
idx = (i+delay_frame)%len(vae_encode_latents)
|
||||
@@ -61,259 +51,14 @@ def datagen(
|
||||
latent_batch.append(latent)
|
||||
|
||||
if len(latent_batch) >= batch_size:
|
||||
whisper_batch = torch.stack(whisper_batch)
|
||||
whisper_batch = np.stack(whisper_batch)
|
||||
latent_batch = torch.cat(latent_batch, dim=0)
|
||||
yield whisper_batch, latent_batch
|
||||
whisper_batch, latent_batch = [], []
|
||||
|
||||
# the last batch may smaller than batch size
|
||||
if len(latent_batch) > 0:
|
||||
whisper_batch = torch.stack(whisper_batch)
|
||||
whisper_batch = np.stack(whisper_batch)
|
||||
latent_batch = torch.cat(latent_batch, dim=0)
|
||||
|
||||
yield whisper_batch.to(device), latent_batch.to(device)
|
||||
|
||||
def cast_training_params(
|
||||
model: Union[torch.nn.Module, List[torch.nn.Module]],
|
||||
dtype=torch.float32,
|
||||
):
|
||||
if not isinstance(model, list):
|
||||
model = [model]
|
||||
for m in model:
|
||||
for param in m.parameters():
|
||||
# only upcast trainable parameters into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(dtype)
|
||||
|
||||
def rand_log_normal(
|
||||
shape,
|
||||
loc=0.,
|
||||
scale=1.,
|
||||
device='cpu',
|
||||
dtype=torch.float32,
|
||||
generator=None
|
||||
):
|
||||
"""Draws samples from an lognormal distribution."""
|
||||
rnd_normal = torch.randn(
|
||||
shape, device=device, dtype=dtype, generator=generator) # N(0, I)
|
||||
sigma = (rnd_normal * scale + loc).exp()
|
||||
return sigma
|
||||
|
||||
def get_mouth_region(frames, image_pred, pixel_values_face_mask):
|
||||
# Initialize lists to store the results for each image in the batch
|
||||
mouth_real_list = []
|
||||
mouth_generated_list = []
|
||||
|
||||
# Process each image in the batch
|
||||
for b in range(frames.shape[0]):
|
||||
# Find the non-zero area in the face mask
|
||||
non_zero_indices = torch.nonzero(pixel_values_face_mask[b])
|
||||
# If there are no non-zero indices, skip this image
|
||||
if non_zero_indices.numel() == 0:
|
||||
continue
|
||||
|
||||
min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max(
|
||||
non_zero_indices[:, 1])
|
||||
min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max(
|
||||
non_zero_indices[:, 2])
|
||||
|
||||
# Crop the frames and image_pred according to the non-zero area
|
||||
frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
|
||||
image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]
|
||||
# Resize the cropped images to 256*256
|
||||
frames_resized = F.interpolate(frames_cropped.unsqueeze(
|
||||
0), size=(256, 256), mode='bilinear', align_corners=False)
|
||||
image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze(
|
||||
0), size=(256, 256), mode='bilinear', align_corners=False)
|
||||
|
||||
# Append the resized images to the result lists
|
||||
mouth_real_list.append(frames_resized)
|
||||
mouth_generated_list.append(image_pred_resized)
|
||||
|
||||
# Convert the lists to tensors if they are not empty
|
||||
mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
|
||||
mouth_generated = torch.cat(
|
||||
mouth_generated_list, dim=0) if mouth_generated_list else None
|
||||
|
||||
return mouth_real, mouth_generated
|
||||
|
||||
def get_image_pred(pixel_values,
|
||||
ref_pixel_values,
|
||||
audio_prompts,
|
||||
vae,
|
||||
net,
|
||||
weight_dtype):
|
||||
with torch.no_grad():
|
||||
bsz, num_frames, c, h, w = pixel_values.shape
|
||||
|
||||
masked_pixel_values = pixel_values.clone()
|
||||
masked_pixel_values[:, :, :, h//2:, :] = -1
|
||||
|
||||
masked_frames = rearrange(
|
||||
masked_pixel_values, 'b f c h w -> (b f) c h w')
|
||||
masked_latents = vae.encode(masked_frames).latent_dist.mode()
|
||||
masked_latents = masked_latents * vae.config.scaling_factor
|
||||
masked_latents = masked_latents.float()
|
||||
|
||||
ref_frames = rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w')
|
||||
ref_latents = vae.encode(ref_frames).latent_dist.mode()
|
||||
ref_latents = ref_latents * vae.config.scaling_factor
|
||||
ref_latents = ref_latents.float()
|
||||
|
||||
input_latents = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
input_latents = input_latents.to(weight_dtype)
|
||||
timesteps = torch.tensor([0], device=input_latents.device)
|
||||
latents_pred = net(
|
||||
input_latents,
|
||||
timesteps,
|
||||
audio_prompts,
|
||||
)
|
||||
latents_pred = (1 / vae.config.scaling_factor) * latents_pred
|
||||
image_pred = vae.decode(latents_pred).sample
|
||||
image_pred = image_pred.float()
|
||||
|
||||
return image_pred
|
||||
|
||||
def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
|
||||
with torch.no_grad():
|
||||
audio_feature_length_per_frame = 2 * \
|
||||
(cfg.data.audio_padding_length_left +
|
||||
cfg.data.audio_padding_length_right + 1)
|
||||
audio_feats = batch['audio_feature'].to(weight_dtype)
|
||||
audio_feats = wav2vec.encoder(
|
||||
audio_feats, output_hidden_states=True).hidden_states
|
||||
audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype) # [B, T, 10, 5, 384]
|
||||
|
||||
start_ts = batch['audio_offset']
|
||||
step_ts = batch['audio_step']
|
||||
audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]),
|
||||
audio_feats,
|
||||
torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1)
|
||||
audio_prompts = []
|
||||
for bb in range(bsz):
|
||||
audio_feats_list = []
|
||||
for f in range(num_frames):
|
||||
cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
|
||||
audio_clip = audio_feats[bb:bb+1,
|
||||
cur_t: cur_t+audio_feature_length_per_frame]
|
||||
|
||||
audio_feats_list.append(audio_clip)
|
||||
audio_feats_list = torch.stack(audio_feats_list, 1)
|
||||
audio_prompts.append(audio_feats_list)
|
||||
audio_prompts = torch.cat(audio_prompts) # B, T, 10, 5, 384
|
||||
return audio_prompts
|
||||
|
||||
def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
|
||||
save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")
|
||||
|
||||
if total_limit is not None:
|
||||
checkpoints = os.listdir(save_dir)
|
||||
checkpoints = [d for d in checkpoints if d.endswith(".pth")]
|
||||
checkpoints = [d for d in checkpoints if name in d]
|
||||
checkpoints = sorted(
|
||||
checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
|
||||
)
|
||||
|
||||
if len(checkpoints) >= total_limit:
|
||||
num_to_remove = len(checkpoints) - total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(
|
||||
f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(
|
||||
save_dir, removing_checkpoint)
|
||||
os.remove(removing_checkpoint)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
torch.save(state_dict, save_path)
|
||||
|
||||
def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
|
||||
unwarp_net = accelerator.unwrap_model(net)
|
||||
save_checkpoint(
|
||||
unwarp_net.unet,
|
||||
save_dir,
|
||||
global_step,
|
||||
name="unet",
|
||||
total_limit=cfg.total_limit,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
def delete_additional_ckpt(base_path, num_keep):
|
||||
dirs = []
|
||||
for d in os.listdir(base_path):
|
||||
if d.startswith("checkpoint-"):
|
||||
dirs.append(d)
|
||||
num_tot = len(dirs)
|
||||
if num_tot <= num_keep:
|
||||
return
|
||||
# ensure ckpt is sorted and delete the ealier!
|
||||
del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
|
||||
for d in del_dirs:
|
||||
path_to_dir = osp.join(base_path, d)
|
||||
if osp.exists(path_to_dir):
|
||||
shutil.rmtree(path_to_dir)
|
||||
|
||||
def seed_everything(seed):
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed % (2**32))
|
||||
random.seed(seed)
|
||||
|
||||
def process_and_save_images(
|
||||
batch,
|
||||
image_pred,
|
||||
image_pred_infer,
|
||||
save_dir,
|
||||
global_step,
|
||||
accelerator,
|
||||
num_images_to_keep=10,
|
||||
syncnet_score=1
|
||||
):
|
||||
# Rearrange the tensors
|
||||
print("image_pred.shape: ", image_pred.shape)
|
||||
pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
|
||||
pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')
|
||||
|
||||
# Create masked pixel values
|
||||
masked_pixel_values = batch["pixel_values_vid"].clone()
|
||||
_, _, _, h, _ = batch["pixel_values_vid"].shape
|
||||
masked_pixel_values[:, :, :, h//2:, :] = -1
|
||||
masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
|
||||
|
||||
# Keep only the specified number of images
|
||||
pixel_values = pixel_values[:num_images_to_keep, :, :, :]
|
||||
masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :]
|
||||
pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :]
|
||||
image_pred = image_pred.detach()[:num_images_to_keep, :, :, :]
|
||||
image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :]
|
||||
|
||||
# Concatenate images
|
||||
concat = torch.cat([
|
||||
masked_pixel_values * 0.5 + 0.5,
|
||||
pixel_values_ref_img * 0.5 + 0.5,
|
||||
image_pred * 0.5 + 0.5,
|
||||
pixel_values * 0.5 + 0.5,
|
||||
image_pred_infer * 0.5 + 0.5,
|
||||
], dim=2)
|
||||
print("concat.shape: ", concat.shape)
|
||||
|
||||
# Create the save directory if it doesn't exist
|
||||
os.makedirs(f'{save_dir}/samples/', exist_ok=True)
|
||||
|
||||
# Try to save the concatenated image
|
||||
try:
|
||||
# Concatenate images horizontally and convert to numpy array
|
||||
final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
|
||||
# Save the image
|
||||
cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image)
|
||||
print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg")
|
||||
except Exception as e:
|
||||
print(f"Failed to save image: {e}")
|
||||
yield whisper_batch, latent_batch
|
||||
|
||||
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
Executable → Regular
+7
-6
@@ -1,15 +1,14 @@
|
||||
diffusers==0.30.2
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
torch==2.0.1
|
||||
torchvision==0.15.2
|
||||
torchaudio==2.0.2
|
||||
diffusers==0.27.2
|
||||
accelerate==0.28.0
|
||||
numpy==1.23.5
|
||||
tensorflow==2.12.0
|
||||
tensorboard==2.12.0
|
||||
opencv-python==4.9.0.80
|
||||
soundfile==0.12.1
|
||||
transformers==4.39.2
|
||||
huggingface_hub==0.30.2
|
||||
librosa==0.11.0
|
||||
einops==0.8.1
|
||||
gradio==5.24.0
|
||||
|
||||
gdown
|
||||
requests
|
||||
@@ -17,4 +16,6 @@ imageio[ffmpeg]
|
||||
|
||||
omegaconf
|
||||
ffmpeg-python
|
||||
gradio
|
||||
spaces
|
||||
moviepy
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
+257
@@ -0,0 +1,257 @@
|
||||
import cv2
|
||||
import os
|
||||
# import dlib
|
||||
import argparse
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import glob
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
import uuid
|
||||
|
||||
from musetalk.utils.utils import get_file_type,get_video_fps
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
import gc
|
||||
|
||||
# load model weights
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
def get_largest_integer_filename(folder_path):
|
||||
# Check if the folder exists
|
||||
if not os.path.isdir(folder_path):
|
||||
return -1
|
||||
|
||||
# Get the list of files in the folder
|
||||
files = os.listdir(folder_path)
|
||||
|
||||
# Check if the folder is empty
|
||||
if not files:
|
||||
return -1
|
||||
|
||||
# Extract the integer part of filenames and find the largest
|
||||
largest_integer = -1
|
||||
for file in files:
|
||||
try:
|
||||
# Get the integer part of the filename
|
||||
file_int = int(os.path.splitext(file)[0])
|
||||
if file_int > largest_integer:
|
||||
largest_integer = file_int
|
||||
except ValueError:
|
||||
# Skip files that don't have an integer filename
|
||||
continue
|
||||
|
||||
return largest_integer
|
||||
|
||||
def datagen(whisper_chunks,
|
||||
crop_images,
|
||||
batch_size=8,
|
||||
delay_frame=0):
|
||||
whisper_batch, crop_batch = [], []
|
||||
for i, w in enumerate(whisper_chunks):
|
||||
idx = (i+delay_frame)%len(crop_images)
|
||||
crop_image = crop_images[idx]
|
||||
whisper_batch.append(w)
|
||||
crop_batch.append(crop_image)
|
||||
|
||||
if len(crop_batch) >= batch_size:
|
||||
whisper_batch = np.stack(whisper_batch)
|
||||
# latent_batch = torch.cat(latent_batch, dim=0)
|
||||
yield whisper_batch, crop_batch
|
||||
whisper_batch, crop_batch = [], []
|
||||
|
||||
# the last batch may smaller than batch size
|
||||
if len(crop_batch) > 0:
|
||||
whisper_batch = np.stack(whisper_batch)
|
||||
# latent_batch = torch.cat(latent_batch, dim=0)
|
||||
|
||||
yield whisper_batch, crop_batch
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
global pe
|
||||
if args.use_float16 is True:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
total_audio_index=get_largest_integer_filename(f"data/audios/{args.folder_name}")
|
||||
total_image_index=get_largest_integer_filename(f"data/images/{args.folder_name}")
|
||||
temp_audio_index=total_audio_index
|
||||
temp_image_index=total_image_index
|
||||
for task_id in inference_config:
|
||||
video_path = inference_config[task_id]["video_path"]
|
||||
audio_path = inference_config[task_id]["audio_path"]
|
||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
|
||||
folder_name = args.folder_name
|
||||
if not os.path.exists(f"data/images/{folder_name}/"):
|
||||
os.makedirs(f"data/images/{folder_name}")
|
||||
if not os.path.exists(f"data/audios/{folder_name}/"):
|
||||
os.makedirs(f"data/audios/{folder_name}")
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
|
||||
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
|
||||
os.makedirs(result_img_save_path,exist_ok =True)
|
||||
|
||||
if args.output_vid_name is None:
|
||||
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
|
||||
############################################## extract frames from source video ##############################################
|
||||
if get_file_type(video_path)=="video":
|
||||
save_dir_full = os.path.join(args.result_dir, input_basename)
|
||||
os.makedirs(save_dir_full,exist_ok = True)
|
||||
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
||||
os.system(cmd)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
elif get_file_type(video_path)=="image":
|
||||
input_img_list = [video_path, ]
|
||||
fps = args.fps
|
||||
elif os.path.isdir(video_path): # input img folder
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
############################################## extract audio feature ##############################################
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
for __ in range(0, len(whisper_feature) - 1, 2): # -1 to avoid index error if the list has an odd number of elements
|
||||
# Combine two consecutive chunks
|
||||
# pair_of_chunks = np.array([whisper_feature[__], whisper_feature[__+1]])
|
||||
concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__+1]], axis=0)
|
||||
# Save the pair to a .npy file
|
||||
np.save(f'data/audios/{folder_name}/{total_audio_index+(__//2)+1}.npy', concatenated_chunks)
|
||||
temp_audio_index=(__//2)+total_audio_index+1
|
||||
total_audio_index=temp_audio_index
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
|
||||
############################################## preprocess input image ##############################################
|
||||
gc.collect()
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("using extracted coordinates")
|
||||
with open(crop_coord_save_path,'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("extracting landmarks...time consuming")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
crop_i=0
|
||||
crop_data=[]
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
|
||||
x1=max(0,x1)
|
||||
y1=max(0,y1)
|
||||
x2=max(0,x2)
|
||||
y2=max(0,y2)
|
||||
|
||||
if ((y2-y1)<=0) or ((x2-x1)<=0):
|
||||
continue
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
crop_data.append(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
crop_i+=1
|
||||
|
||||
# to smooth the first and the last frame
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
crop_data = crop_data + crop_data[::-1]
|
||||
############################################## inference batch by batch ##############################################
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(whisper_chunks,crop_data,batch_size)
|
||||
crop_index = 0
|
||||
for i, (whisper_batch,crop_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
for image,audio in zip(crop_batch,whisper_batch):
|
||||
cv2.imwrite(f"data/images/{folder_name}/{str(crop_index+total_image_index+1)}.png",image)
|
||||
temp_image_index = crop_index + total_image_index + 1
|
||||
crop_index += 1
|
||||
|
||||
# np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio)
|
||||
total_image_index=temp_image_index
|
||||
gc.collect()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0)
|
||||
parser.add_argument("--result_dir", default='./results', help="path to output")
|
||||
parser.add_argument("--folder_name", default=f'{uuid.uuid4()}', help="path to output")
|
||||
|
||||
parser.add_argument("--fps", type=int, default=25)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--output_vid_name", type=str, default=None)
|
||||
parser.add_argument("--use_saved_coord",
|
||||
action="store_true",
|
||||
help='use saved coordinate to save time')
|
||||
parser.add_argument("--use_float16",
|
||||
action="store_true",
|
||||
help="Whether use float16 to speed up inference",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
|
||||
def process_audio(audio_path):
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
np.save('audio/your_filename.npy', whisper_feature)
|
||||
|
||||
def mask_face(image):
|
||||
# Load dlib's face detector and the landmark predictor
|
||||
detector = dlib.get_frontal_face_detector()
|
||||
predictor_path = "/content/shape_predictor_68_face_landmarks.dat" # Set path to your downloaded predictor file
|
||||
predictor = dlib.shape_predictor(predictor_path)
|
||||
|
||||
# Load your input image
|
||||
# image_path = "/content/ori_frame_00000077.png" # Replace with the path to your input image
|
||||
# image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise ValueError("Image not found or unable to load.")
|
||||
|
||||
# Convert to grayscale for detection
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Detect faces in the image
|
||||
faces = detector(gray)
|
||||
|
||||
# Process each detected face
|
||||
for face in faces:
|
||||
# Predict landmarks
|
||||
landmarks = predictor(gray, face)
|
||||
|
||||
# The indices of nose landmarks are 27 to 35
|
||||
nose_tip = landmarks.part(33).y
|
||||
|
||||
# Blacken the region below the nose tip
|
||||
blacken_area = image[nose_tip:, :]
|
||||
blacken_area[:] = (0, 0, 0)
|
||||
|
||||
# Save the final image or display it
|
||||
# cv2.imwrite("output_image.jpg", image)
|
||||
return image
|
||||
@@ -0,0 +1,182 @@
|
||||
import argparse
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import glob
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
|
||||
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
# load model weights
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
accelerator = Accelerator(
|
||||
mixed_precision="fp16",
|
||||
)
|
||||
unet = accelerator.prepare(
|
||||
unet,
|
||||
|
||||
)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
global pe
|
||||
if not (args.unet_checkpoint == None):
|
||||
print("unet ckpt loaded")
|
||||
accelerator.load_state(args.unet_checkpoint)
|
||||
|
||||
if args.use_float16 is True:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print(inference_config)
|
||||
for task_id in inference_config:
|
||||
video_path = inference_config[task_id]["video_path"]
|
||||
audio_path = inference_config[task_id]["audio_path"]
|
||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
|
||||
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
|
||||
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
|
||||
os.makedirs(result_img_save_path,exist_ok =True)
|
||||
|
||||
if args.output_vid_name is None:
|
||||
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
|
||||
############################################## extract frames from source video ##############################################
|
||||
if get_file_type(video_path)=="video":
|
||||
save_dir_full = os.path.join(args.result_dir, input_basename)
|
||||
os.makedirs(save_dir_full,exist_ok = True)
|
||||
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
||||
os.system(cmd)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
elif get_file_type(video_path)=="image":
|
||||
input_img_list = [video_path, ]
|
||||
fps = args.fps
|
||||
elif os.path.isdir(video_path): # input img folder
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
############################################## extract audio feature ##############################################
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
############################################## preprocess input image ##############################################
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("using extracted coordinates")
|
||||
with open(crop_coord_save_path,'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("extracting landmarks...time consuming")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
crop_i=0
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
crop_i+=1
|
||||
|
||||
# to smooth the first and the last frame
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
############################################## inference batch by batch ##############################################
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
||||
res_frame_list = []
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
||||
dtype=unet.model.dtype) # torch, B, 5*N,384
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
############################################## pad to full image ##############################################
|
||||
print("pad talking image to original video")
|
||||
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
except:
|
||||
continue
|
||||
|
||||
combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||
cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
|
||||
cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
|
||||
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
|
||||
os.system(cmd_img2video)
|
||||
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
os.system(cmd_combine_audio)
|
||||
|
||||
os.remove("temp.mp4")
|
||||
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
|
||||
os.system(cmd_img2video)
|
||||
|
||||
# cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
# print(cmd_combine_audio)
|
||||
# os.system(cmd_combine_audio)
|
||||
|
||||
# shutil.rmtree(result_img_save_path)
|
||||
print(f"result is save to {output_vid_name}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0)
|
||||
parser.add_argument("--result_dir", default='./results', help="path to output")
|
||||
|
||||
parser.add_argument("--fps", type=int, default=25)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--output_vid_name", type=str, default=None)
|
||||
parser.add_argument("--use_saved_coord",
|
||||
action="store_true",
|
||||
help='use saved coordinate to save time')
|
||||
parser.add_argument("--use_float16",
|
||||
action="store_true",
|
||||
help="Whether use float16 to speed up inference",
|
||||
)
|
||||
parser.add_argument("--unet_checkpoint", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
+66
-181
@@ -1,203 +1,111 @@
|
||||
import argparse
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
import numpy as np
|
||||
import cv2
|
||||
import math
|
||||
import copy
|
||||
import torch
|
||||
import glob
|
||||
import shutil
|
||||
import pickle
|
||||
import argparse
|
||||
import numpy as np
|
||||
import subprocess
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
from transformers import WhisperModel
|
||||
import sys
|
||||
import copy
|
||||
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
||||
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
# load model weights
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
# Configure ffmpeg path
|
||||
if not fast_check_ffmpeg():
|
||||
print("Adding ffmpeg to PATH")
|
||||
# Choose path separator based on operating system
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
# Set computing device
|
||||
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
|
||||
# Load model weights
|
||||
vae, unet, pe = load_all_model(
|
||||
unet_model_path=args.unet_model_path,
|
||||
vae_type=args.vae_type,
|
||||
unet_config=args.unet_config,
|
||||
device=device
|
||||
)
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
# Convert models to half precision if float16 is enabled
|
||||
if args.use_float16:
|
||||
global pe
|
||||
if args.use_float16 is True:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
# Move models to specified device
|
||||
pe = pe.to(device)
|
||||
vae.vae = vae.vae.to(device)
|
||||
unet.model = unet.model.to(device)
|
||||
|
||||
# Initialize audio processor and Whisper model
|
||||
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
||||
weight_dtype = unet.model.dtype
|
||||
whisper = WhisperModel.from_pretrained(args.whisper_dir)
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
# Initialize face parser with configurable parameters based on version
|
||||
if args.version == "v15":
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
else: # v1
|
||||
fp = FaceParsing()
|
||||
|
||||
# Load inference configuration
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print("Loaded inference config:", inference_config)
|
||||
|
||||
# Process each task
|
||||
print(inference_config)
|
||||
for task_id in inference_config:
|
||||
try:
|
||||
# Get task configuration
|
||||
video_path = inference_config[task_id]["video_path"]
|
||||
audio_path = inference_config[task_id]["audio_path"]
|
||||
if "result_name" in inference_config[task_id]:
|
||||
args.output_vid_name = inference_config[task_id]["result_name"]
|
||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
|
||||
|
||||
# Set bbox_shift based on version
|
||||
if args.version == "v15":
|
||||
bbox_shift = 0 # v15 uses fixed bbox_shift
|
||||
else:
|
||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) # v1 uses config or default
|
||||
|
||||
# Set output paths
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
|
||||
# Create temporary directories
|
||||
temp_dir = os.path.join(args.result_dir, f"{args.version}")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# Set result save paths
|
||||
result_img_save_path = os.path.join(temp_dir, output_basename)
|
||||
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
|
||||
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
|
||||
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
|
||||
os.makedirs(result_img_save_path,exist_ok =True)
|
||||
|
||||
# Set output video paths
|
||||
if args.output_vid_name is None:
|
||||
output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
|
||||
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
|
||||
output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
|
||||
|
||||
# Extract frames from source video
|
||||
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
|
||||
############################################## extract frames from source video ##############################################
|
||||
if get_file_type(video_path)=="video":
|
||||
save_dir_full = os.path.join(temp_dir, input_basename)
|
||||
save_dir_full = os.path.join(args.result_dir, input_basename)
|
||||
os.makedirs(save_dir_full,exist_ok = True)
|
||||
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
||||
os.system(cmd)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
elif get_file_type(video_path)=="image":
|
||||
input_img_list = [video_path]
|
||||
input_img_list = [video_path, ]
|
||||
fps = args.fps
|
||||
elif os.path.isdir(video_path):
|
||||
elif os.path.isdir(video_path): # input img folder
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
|
||||
# Extract audio features
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=args.audio_padding_length_left,
|
||||
audio_padding_length_right=args.audio_padding_length_right,
|
||||
)
|
||||
|
||||
# Preprocess input images
|
||||
#print(input_img_list)
|
||||
############################################## extract audio feature ##############################################
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
############################################## preprocess input image ##############################################
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("Using saved coordinates")
|
||||
print("using extracted coordinates")
|
||||
with open(crop_coord_save_path,'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("Extracting landmarks... time-consuming operation")
|
||||
print("extracting landmarks...time consuming")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
print(f"Number of frames: {len(frame_list)}")
|
||||
|
||||
# Process each frame
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
if args.version == "v15":
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
|
||||
# Smooth first and last frames
|
||||
# to smooth the first and the last frame
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
|
||||
# Batch inference
|
||||
print("Starting inference")
|
||||
############################################## inference batch by batch ##############################################
|
||||
print("start inference")
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=0,
|
||||
device=device,
|
||||
)
|
||||
|
||||
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
||||
res_frame_list = []
|
||||
total = int(np.ceil(float(video_num) / batch_size))
|
||||
|
||||
# Execute inference
|
||||
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
|
||||
audio_feature_batch = pe(whisper_batch)
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
||||
dtype=unet.model.dtype) # torch, B, 5*N,384
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
@@ -205,72 +113,49 @@ def main(args):
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
# Pad generated images to original video size
|
||||
print("Padding generated images to original video size")
|
||||
############################################## pad to full image ##############################################
|
||||
print("pad talking image to original video")
|
||||
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
if args.version == "v15":
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
except:
|
||||
# print(bbox)
|
||||
continue
|
||||
|
||||
# Merge results with version-specific parameters
|
||||
if args.version == "v15":
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
|
||||
else:
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
|
||||
combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
|
||||
|
||||
# Save prediction results
|
||||
temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
|
||||
print("Video generation command:", cmd_img2video)
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
|
||||
print(cmd_img2video)
|
||||
os.system(cmd_img2video)
|
||||
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
|
||||
print("Audio combination command:", cmd_combine_audio)
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
print(cmd_combine_audio)
|
||||
os.system(cmd_combine_audio)
|
||||
|
||||
# Clean up temporary files
|
||||
os.remove("temp.mp4")
|
||||
shutil.rmtree(result_img_save_path)
|
||||
os.remove(temp_vid_path)
|
||||
|
||||
shutil.rmtree(save_dir_full)
|
||||
if not args.saved_coord:
|
||||
os.remove(crop_coord_save_path)
|
||||
|
||||
print(f"Results saved to {output_vid_name}")
|
||||
except Exception as e:
|
||||
print("Error occurred during processing:", e)
|
||||
print(f"result is save to {output_vid_name}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
|
||||
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
|
||||
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
|
||||
parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
|
||||
parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
|
||||
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
|
||||
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
|
||||
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
|
||||
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
|
||||
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
|
||||
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
|
||||
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
|
||||
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
|
||||
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
|
||||
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
|
||||
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
|
||||
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
|
||||
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
|
||||
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
|
||||
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0)
|
||||
parser.add_argument("--result_dir", default='./results', help="path to output")
|
||||
|
||||
parser.add_argument("--fps", type=int, default=25)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--output_vid_name", type=str, default=None)
|
||||
parser.add_argument("--use_saved_coord",
|
||||
action="store_true",
|
||||
help='use saved coordinate to save time')
|
||||
parser.add_argument("--use_float16",
|
||||
action="store_true",
|
||||
help="Whether use float16 to speed up inference",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -1,334 +0,0 @@
|
||||
import os
|
||||
import argparse
|
||||
import subprocess
|
||||
from omegaconf import OmegaConf
|
||||
from typing import Tuple, List, Union
|
||||
import decord
|
||||
import json
|
||||
import cv2
|
||||
from musetalk.utils.face_detection import FaceAlignment,LandmarksType
|
||||
from mmpose.apis import inference_topdown, init_model
|
||||
from mmpose.structures import merge_data_samples
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
ffmpeg_path = "./ffmpeg-4.4-amd64-static/"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Adding ffmpeg to PATH")
|
||||
# Choose path separator based on operating system
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
class AnalyzeFace:
|
||||
def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str):
|
||||
"""
|
||||
Initialize the AnalyzeFace class with the given device, config file, and checkpoint file.
|
||||
|
||||
Parameters:
|
||||
device (Union[str, torch.device]): The device to run the models on ('cuda' or 'cpu').
|
||||
config_file (str): Path to the mmpose model configuration file.
|
||||
checkpoint_file (str): Path to the mmpose model checkpoint file.
|
||||
"""
|
||||
self.device = device
|
||||
self.dwpose = init_model(config_file, checkpoint_file, device=self.device)
|
||||
self.facedet = FaceAlignment(LandmarksType._2D, flip_input=False, device=self.device)
|
||||
|
||||
def __call__(self, im: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:
|
||||
"""
|
||||
Detect faces and keypoints in the given image.
|
||||
|
||||
Parameters:
|
||||
im (np.ndarray): The input image.
|
||||
maxface (bool): Whether to detect the maximum face. Default is True.
|
||||
|
||||
Returns:
|
||||
Tuple[List[np.ndarray], np.ndarray]: A tuple containing the bounding boxes and keypoints.
|
||||
"""
|
||||
try:
|
||||
# Ensure the input image has the correct shape
|
||||
if im.ndim == 3:
|
||||
im = np.expand_dims(im, axis=0)
|
||||
elif im.ndim != 4 or im.shape[0] != 1:
|
||||
raise ValueError("Input image must have shape (1, H, W, C)")
|
||||
|
||||
bbox = self.facedet.get_detections_for_batch(np.asarray(im))
|
||||
results = inference_topdown(self.dwpose, np.asarray(im)[0])
|
||||
results = merge_data_samples(results)
|
||||
keypoints = results.pred_instances.keypoints
|
||||
face_land_mark= keypoints[0][23:91]
|
||||
face_land_mark = face_land_mark.astype(np.int32)
|
||||
|
||||
return face_land_mark, bbox
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during face analysis: {e}")
|
||||
return np.array([]),[]
|
||||
|
||||
def convert_video(org_path: str, dst_path: str, vid_list: List[str]) -> None:
|
||||
|
||||
"""
|
||||
Convert video files to a specified format and save them to the destination path.
|
||||
|
||||
Parameters:
|
||||
org_path (str): The directory containing the original video files.
|
||||
dst_path (str): The directory where the converted video files will be saved.
|
||||
vid_list (List[str]): A list of video file names to process.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for idx, vid in enumerate(vid_list):
|
||||
if vid.endswith('.mp4'):
|
||||
org_vid_path = os.path.join(org_path, vid)
|
||||
dst_vid_path = os.path.join(dst_path, vid)
|
||||
|
||||
if org_vid_path != dst_vid_path:
|
||||
cmd = [
|
||||
"ffmpeg", "-hide_banner", "-y", "-i", org_vid_path,
|
||||
"-r", "25", "-crf", "15", "-c:v", "libx264",
|
||||
"-pix_fmt", "yuv420p", dst_vid_path
|
||||
]
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
if idx % 1000 == 0:
|
||||
print(f"### {idx} videos converted ###")
|
||||
|
||||
def segment_video(org_path: str, dst_path: str, vid_list: List[str], segment_duration: int = 30) -> None:
|
||||
"""
|
||||
Segment video files into smaller clips of specified duration.
|
||||
|
||||
Parameters:
|
||||
org_path (str): The directory containing the original video files.
|
||||
dst_path (str): The directory where the segmented video files will be saved.
|
||||
vid_list (List[str]): A list of video file names to process.
|
||||
segment_duration (int): The duration of each segment in seconds. Default is 30 seconds.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for idx, vid in enumerate(vid_list):
|
||||
if vid.endswith('.mp4'):
|
||||
input_file = os.path.join(org_path, vid)
|
||||
original_filename = os.path.basename(input_file)
|
||||
|
||||
command = [
|
||||
'ffmpeg', '-i', input_file, '-c', 'copy', '-map', '0',
|
||||
'-segment_time', str(segment_duration), '-f', 'segment',
|
||||
'-reset_timestamps', '1',
|
||||
os.path.join(dst_path, f'clip%03d_{original_filename}')
|
||||
]
|
||||
|
||||
subprocess.run(command, check=True)
|
||||
|
||||
def extract_audio(org_path: str, dst_path: str, vid_list: List[str]) -> None:
|
||||
"""
|
||||
Extract audio from video files and save as WAV format.
|
||||
|
||||
Parameters:
|
||||
org_path (str): The directory containing the original video files.
|
||||
dst_path (str): The directory where the extracted audio files will be saved.
|
||||
vid_list (List[str]): A list of video file names to process.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for idx, vid in enumerate(vid_list):
|
||||
if vid.endswith('.mp4'):
|
||||
video_path = os.path.join(org_path, vid)
|
||||
audio_output_path = os.path.join(dst_path, os.path.splitext(vid)[0] + ".wav")
|
||||
try:
|
||||
command = [
|
||||
'ffmpeg', '-hide_banner', '-y', '-i', video_path,
|
||||
'-vn', '-acodec', 'pcm_s16le', '-f', 'wav',
|
||||
'-ar', '16000', '-ac', '1', audio_output_path,
|
||||
]
|
||||
|
||||
subprocess.run(command, check=True)
|
||||
print(f"Audio saved to: {audio_output_path}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error extracting audio from {vid}: {e}")
|
||||
|
||||
def split_data(video_files: List[str], val_list_hdtf: List[str]) -> (List[str], List[str]):
|
||||
"""
|
||||
Split video files into training and validation sets based on val_list_hdtf.
|
||||
|
||||
Parameters:
|
||||
video_files (List[str]): A list of video file names.
|
||||
val_list_hdtf (List[str]): A list of validation file identifiers.
|
||||
|
||||
Returns:
|
||||
(List[str], List[str]): A tuple containing the training and validation file lists.
|
||||
"""
|
||||
val_files = [f for f in video_files if any(val_id in f for val_id in val_list_hdtf)]
|
||||
train_files = [f for f in video_files if f not in val_files]
|
||||
return train_files, val_files
|
||||
|
||||
def save_list_to_file(file_path: str, data_list: List[str]) -> None:
|
||||
"""
|
||||
Save a list of strings to a file, each string on a new line.
|
||||
|
||||
Parameters:
|
||||
file_path (str): The path to the file where the list will be saved.
|
||||
data_list (List[str]): The list of strings to save.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
with open(file_path, 'w') as file:
|
||||
for item in data_list:
|
||||
file.write(f"{item}\n")
|
||||
|
||||
def generate_train_list(cfg):
|
||||
train_file_path = cfg.video_clip_file_list_train
|
||||
val_file_path = cfg.video_clip_file_list_val
|
||||
val_list_hdtf = cfg.val_list_hdtf
|
||||
|
||||
meta_list = os.listdir(cfg.meta_root)
|
||||
|
||||
sorted_meta_list = sorted(meta_list)
|
||||
train_files, val_files = split_data(meta_list, val_list_hdtf)
|
||||
|
||||
save_list_to_file(train_file_path, train_files)
|
||||
save_list_to_file(val_file_path, val_files)
|
||||
|
||||
print(val_list_hdtf)
|
||||
|
||||
def analyze_video(org_path: str, dst_path: str, vid_list: List[str]) -> None:
|
||||
"""
|
||||
Convert video files to a specified format and save them to the destination path.
|
||||
|
||||
Parameters:
|
||||
org_path (str): The directory containing the original video files.
|
||||
dst_path (str): The directory where the meta json will be saved.
|
||||
vid_list (List[str]): A list of video file names to process.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
|
||||
checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
|
||||
|
||||
analyze_face = AnalyzeFace(device, config_file, checkpoint_file)
|
||||
|
||||
for vid in tqdm(vid_list, desc="Processing videos"):
|
||||
#vid = "clip005_WDA_BernieSanders_000.mp4"
|
||||
#print(vid)
|
||||
if vid.endswith('.mp4'):
|
||||
vid_path = os.path.join(org_path, vid)
|
||||
wav_path = vid_path.replace(".mp4",".wav")
|
||||
vid_meta = os.path.join(dst_path, os.path.splitext(vid)[0] + ".json")
|
||||
if os.path.exists(vid_meta):
|
||||
continue
|
||||
print('process video {}'.format(vid))
|
||||
|
||||
total_bbox_list = []
|
||||
total_pts_list = []
|
||||
isvalid = True
|
||||
|
||||
# process
|
||||
try:
|
||||
cap = decord.VideoReader(vid_path, fault_tol=1)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
total_frames = len(cap)
|
||||
for frame_idx in range(total_frames):
|
||||
frame = cap[frame_idx]
|
||||
if frame_idx==0:
|
||||
video_height,video_width,_ = frame.shape
|
||||
frame_bgr = cv2.cvtColor(frame.asnumpy(), cv2.COLOR_BGR2RGB)
|
||||
pts_list, bbox_list = analyze_face(frame_bgr)
|
||||
|
||||
if len(bbox_list)>0 and None not in bbox_list:
|
||||
bbox = bbox_list[0]
|
||||
else:
|
||||
isvalid = False
|
||||
bbox = []
|
||||
print(f"set isvalid to False as broken img in {frame_idx} of {vid}")
|
||||
break
|
||||
|
||||
#print(pts_list)
|
||||
if len(pts_list)>0 and pts_list is not None:
|
||||
pts = pts_list.tolist()
|
||||
else:
|
||||
isvalid = False
|
||||
pts = []
|
||||
break
|
||||
|
||||
if frame_idx==0:
|
||||
x1,y1,x2,y2 = bbox
|
||||
face_height, face_width = y2-y1,x2-x1
|
||||
|
||||
total_pts_list.append(pts)
|
||||
total_bbox_list.append(bbox)
|
||||
|
||||
meta_data = {
|
||||
"mp4_path": vid_path,
|
||||
"wav_path": wav_path,
|
||||
"video_size": [video_height, video_width],
|
||||
"face_size": [face_height, face_width],
|
||||
"frames": total_frames,
|
||||
"face_list": total_bbox_list,
|
||||
"landmark_list": total_pts_list,
|
||||
"isvalid":isvalid,
|
||||
}
|
||||
with open(vid_meta, 'w') as f:
|
||||
json.dump(meta_data, f, indent=4)
|
||||
|
||||
|
||||
|
||||
def main(cfg):
|
||||
# Ensure all necessary directories exist
|
||||
os.makedirs(cfg.video_root_25fps, exist_ok=True)
|
||||
os.makedirs(cfg.video_audio_clip_root, exist_ok=True)
|
||||
os.makedirs(cfg.meta_root, exist_ok=True)
|
||||
os.makedirs(os.path.dirname(cfg.video_file_list), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(cfg.video_clip_file_list_train), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(cfg.video_clip_file_list_val), exist_ok=True)
|
||||
|
||||
vid_list = os.listdir(cfg.video_root_raw)
|
||||
sorted_vid_list = sorted(vid_list)
|
||||
|
||||
# Save video file list
|
||||
with open(cfg.video_file_list, 'w') as file:
|
||||
for vid in sorted_vid_list:
|
||||
file.write(vid + '\n')
|
||||
|
||||
# 1. Convert videos to 25 FPS
|
||||
convert_video(cfg.video_root_raw, cfg.video_root_25fps, sorted_vid_list)
|
||||
|
||||
# 2. Segment videos into 30-second clips
|
||||
segment_video(cfg.video_root_25fps, cfg.video_audio_clip_root, vid_list, segment_duration=cfg.clip_len_second)
|
||||
|
||||
# 3. Extract audio
|
||||
clip_vid_list = os.listdir(cfg.video_audio_clip_root)
|
||||
extract_audio(cfg.video_audio_clip_root, cfg.video_audio_clip_root, clip_vid_list)
|
||||
|
||||
# 4. Generate video metadata
|
||||
analyze_video(cfg.video_audio_clip_root, cfg.meta_root, clip_vid_list)
|
||||
|
||||
# 5. Generate training and validation set lists
|
||||
generate_train_list(cfg)
|
||||
print("done")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="./configs/training/preprocess.yaml")
|
||||
args = parser.parse_args()
|
||||
config = OmegaConf.load(args.config)
|
||||
|
||||
main(config)
|
||||
|
||||
+49
-122
@@ -10,29 +10,24 @@ import sys
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
import json
|
||||
from transformers import WhisperModel
|
||||
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.utils import datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
|
||||
from musetalk.utils.blending import get_image_prepare_material, get_image_blending
|
||||
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
|
||||
from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
|
||||
from musetalk.utils.utils import load_all_model
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
|
||||
import shutil
|
||||
|
||||
import threading
|
||||
import queue
|
||||
|
||||
import time
|
||||
import subprocess
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
# load model weights
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
|
||||
cap = cv2.VideoCapture(vid_path)
|
||||
@@ -47,7 +42,6 @@ def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def osmakedirs(path_list):
|
||||
for path in path_list:
|
||||
os.makedirs(path) if not os.path.exists(path) else None
|
||||
@@ -59,13 +53,7 @@ class Avatar:
|
||||
self.avatar_id = avatar_id
|
||||
self.video_path = video_path
|
||||
self.bbox_shift = bbox_shift
|
||||
# 根据版本设置不同的基础路径
|
||||
if args.version == "v15":
|
||||
self.base_path = f"./results/{args.version}/avatars/{avatar_id}"
|
||||
else: # v1
|
||||
self.base_path = f"./results/avatars/{avatar_id}"
|
||||
|
||||
self.avatar_path = self.base_path
|
||||
self.avatar_path = f"./results/avatars/{avatar_id}"
|
||||
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
|
||||
self.coords_path = f"{self.avatar_path}/coords.pkl"
|
||||
self.latents_out_path= f"{self.avatar_path}/latents.pt"
|
||||
@@ -76,8 +64,7 @@ class Avatar:
|
||||
self.avatar_info = {
|
||||
"avatar_id":avatar_id,
|
||||
"video_path":video_path,
|
||||
"bbox_shift": bbox_shift,
|
||||
"version": args.version
|
||||
"bbox_shift":bbox_shift
|
||||
}
|
||||
self.preparation = preparation
|
||||
self.batch_size = batch_size
|
||||
@@ -172,10 +159,6 @@ class Avatar:
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
if args.version == "v15":
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
coord_list[idx] = [x1, y1, x2, y2] # 更新coord_list中的bbox
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
resized_crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(resized_crop_frame)
|
||||
@@ -190,13 +173,8 @@ class Avatar:
|
||||
for i,frame in enumerate(tqdm(self.frame_list_cycle)):
|
||||
cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png",frame)
|
||||
|
||||
x1, y1, x2, y2 = self.coord_list_cycle[i]
|
||||
if args.version == "v15":
|
||||
mode = args.parsing_mode
|
||||
else:
|
||||
mode = "raw"
|
||||
mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp, mode=mode)
|
||||
|
||||
face_box = self.coord_list_cycle[i]
|
||||
mask,crop_box = get_image_prepare_material(frame,face_box)
|
||||
cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png",mask)
|
||||
self.mask_coords_list_cycle += [crop_box]
|
||||
self.mask_list_cycle.append(mask)
|
||||
@@ -208,8 +186,12 @@ class Avatar:
|
||||
pickle.dump(self.coord_list_cycle, f)
|
||||
|
||||
torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
|
||||
#
|
||||
|
||||
def process_frames(self, res_frame_queue, video_len, skip_save_images):
|
||||
def process_frames(self,
|
||||
res_frame_queue,
|
||||
video_len,
|
||||
skip_save_images):
|
||||
print(video_len)
|
||||
while True:
|
||||
if self.idx>=video_len-1:
|
||||
@@ -229,35 +211,30 @@ class Avatar:
|
||||
continue
|
||||
mask = self.mask_list_cycle[self.idx%(len(self.mask_list_cycle))]
|
||||
mask_crop_box = self.mask_coords_list_cycle[self.idx%(len(self.mask_coords_list_cycle))]
|
||||
#combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
|
||||
|
||||
if skip_save_images is False:
|
||||
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
|
||||
self.idx = self.idx + 1
|
||||
|
||||
def inference(self, audio_path, out_vid_name, fps, skip_save_images):
|
||||
def inference(self,
|
||||
audio_path,
|
||||
out_vid_name,
|
||||
fps,
|
||||
skip_save_images):
|
||||
os.makedirs(self.avatar_path+'/tmp',exist_ok =True)
|
||||
print("start inference")
|
||||
############################################## extract audio feature ##############################################
|
||||
start_time = time.time()
|
||||
# Extract audio features
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path, weight_dtype=weight_dtype)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=args.audio_padding_length_left,
|
||||
audio_padding_length_right=args.audio_padding_length_right,
|
||||
)
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
|
||||
############################################## inference batch by batch ##############################################
|
||||
video_num = len(whisper_chunks)
|
||||
res_frame_queue = queue.Queue()
|
||||
self.idx = 0
|
||||
# Create a sub-thread and start it
|
||||
# # Create a sub-thread and start it
|
||||
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
|
||||
process_thread.start()
|
||||
|
||||
@@ -268,13 +245,15 @@ class Avatar:
|
||||
res_frame_list = []
|
||||
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))):
|
||||
audio_feature_batch = pe(whisper_batch.to(device))
|
||||
latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
|
||||
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
||||
dtype=unet.model.dtype)
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch,
|
||||
timesteps,
|
||||
encoder_hidden_states=audio_feature_batch).sample
|
||||
pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_queue.put(res_frame)
|
||||
@@ -292,7 +271,7 @@ class Avatar:
|
||||
|
||||
if out_vid_name is not None and args.skip_save_images is False:
|
||||
# optional
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
|
||||
print(cmd_img2video)
|
||||
os.system(cmd_img2video)
|
||||
|
||||
@@ -313,27 +292,18 @@ if __name__ == "__main__":
|
||||
'''
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Version of MuseTalk: v1 or v15")
|
||||
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
|
||||
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
|
||||
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
|
||||
parser.add_argument("--unet_config", type=str, default="./models/musetalk/musetalk.json", help="Path to UNet configuration file")
|
||||
parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
|
||||
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
|
||||
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
|
||||
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
|
||||
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
|
||||
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
|
||||
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
|
||||
parser.add_argument("--batch_size", type=int, default=20, help="Batch size for inference")
|
||||
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
|
||||
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
|
||||
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
|
||||
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
|
||||
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
|
||||
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
|
||||
parser.add_argument("--inference_config",
|
||||
type=str,
|
||||
default="configs/inference/realtime.yaml",
|
||||
)
|
||||
parser.add_argument("--fps",
|
||||
type=int,
|
||||
default=25,
|
||||
)
|
||||
parser.add_argument("--batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
parser.add_argument("--skip_save_images",
|
||||
action="store_true",
|
||||
help="Whether skip saving images for better generation speed calculation",
|
||||
@@ -341,56 +311,13 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure ffmpeg path
|
||||
if not fast_check_ffmpeg():
|
||||
print("Adding ffmpeg to PATH")
|
||||
# Choose path separator based on operating system
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
# Set computing device
|
||||
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Load model weights
|
||||
vae, unet, pe = load_all_model(
|
||||
unet_model_path=args.unet_model_path,
|
||||
vae_type=args.vae_type,
|
||||
unet_config=args.unet_config,
|
||||
device=device
|
||||
)
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
pe = pe.half().to(device)
|
||||
vae.vae = vae.vae.half().to(device)
|
||||
unet.model = unet.model.half().to(device)
|
||||
|
||||
# Initialize audio processor and Whisper model
|
||||
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
||||
weight_dtype = unet.model.dtype
|
||||
whisper = WhisperModel.from_pretrained(args.whisper_dir)
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
# Initialize face parser with configurable parameters based on version
|
||||
if args.version == "v15":
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
else: # v1
|
||||
fp = FaceParsing()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print(inference_config)
|
||||
|
||||
|
||||
for avatar_id in inference_config:
|
||||
data_preparation = inference_config[avatar_id]["preparation"]
|
||||
video_path = inference_config[avatar_id]["video_path"]
|
||||
if args.version == "v15":
|
||||
bbox_shift = 0
|
||||
else:
|
||||
bbox_shift = inference_config[avatar_id]["bbox_shift"]
|
||||
avatar = Avatar(
|
||||
avatar_id = avatar_id,
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
def test_ffmpeg(ffmpeg_path):
|
||||
print(f"Testing ffmpeg path: {ffmpeg_path}")
|
||||
|
||||
# Choose path separator based on operating system
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
|
||||
# Add ffmpeg path to environment variable
|
||||
os.environ["PATH"] = f"{ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
|
||||
try:
|
||||
# Try to run ffmpeg
|
||||
result = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True)
|
||||
print("FFmpeg test successful!")
|
||||
print("FFmpeg version information:")
|
||||
print(result.stdout)
|
||||
return True
|
||||
except Exception as e:
|
||||
print("FFmpeg test failed!")
|
||||
print(f"Error message: {str(e)}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Default ffmpeg path, can be modified as needed
|
||||
default_path = r"ffmpeg-master-latest-win64-gpl-shared\bin"
|
||||
|
||||
# Use command line argument if provided, otherwise use default path
|
||||
ffmpeg_path = sys.argv[1] if len(sys.argv) > 1 else default_path
|
||||
|
||||
test_ffmpeg(ffmpeg_path)
|
||||
@@ -1,580 +0,0 @@
|
||||
import argparse
|
||||
import diffusers
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
import warnings
|
||||
import random
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import LoggerType
|
||||
from accelerate import InitProcessGroupKwargs
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
from diffusers.utils import check_min_version
|
||||
from einops import rearrange
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from musetalk.utils.utils import (
|
||||
delete_additional_ckpt,
|
||||
seed_everything,
|
||||
get_mouth_region,
|
||||
process_audio_features,
|
||||
save_models
|
||||
)
|
||||
from musetalk.loss.basic_loss import set_requires_grad
|
||||
from musetalk.loss.syncnet import get_sync_loss
|
||||
from musetalk.utils.training_utils import (
|
||||
initialize_models_and_optimizers,
|
||||
initialize_dataloaders,
|
||||
initialize_loss_functions,
|
||||
initialize_syncnet,
|
||||
initialize_vgg,
|
||||
validation
|
||||
)
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
warnings.filterwarnings("ignore")
|
||||
check_min_version("0.10.0.dev0")
|
||||
|
||||
def main(cfg):
|
||||
exp_name = cfg.exp_name
|
||||
save_dir = f"{cfg.output_dir}/{exp_name}"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
kwargs = DistributedDataParallelKwargs()
|
||||
process_group_kwargs = InitProcessGroupKwargs(
|
||||
timeout=timedelta(seconds=5400))
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
|
||||
log_with=["tensorboard", LoggerType.TENSORBOARD],
|
||||
project_dir=os.path.join(save_dir, "./tensorboard"),
|
||||
kwargs_handlers=[kwargs, process_group_kwargs],
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if cfg.seed is not None:
|
||||
print('cfg.seed', cfg.seed, accelerator.process_index)
|
||||
seed_everything(cfg.seed + accelerator.process_index)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
|
||||
model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
|
||||
dataloader_dict = initialize_dataloaders(cfg)
|
||||
loss_dict = initialize_loss_functions(cfg, accelerator, model_dict['scheduler_max_steps'])
|
||||
syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
|
||||
vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader'] = accelerator.prepare(
|
||||
model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader']
|
||||
)
|
||||
print("length train/val", len(dataloader_dict['train_dataloader']), len(dataloader_dict['val_dataloader']))
|
||||
|
||||
# Calculate training steps and epochs
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(dataloader_dict['train_dataloader']) / cfg.solver.gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(
|
||||
cfg.solver.max_train_steps / num_update_steps_per_epoch
|
||||
)
|
||||
|
||||
# Initialize trackers on the main process
|
||||
if accelerator.is_main_process:
|
||||
run_time = datetime.now().strftime("%Y%m%d-%H%M")
|
||||
accelerator.init_trackers(
|
||||
cfg.exp_name,
|
||||
init_kwargs={"mlflow": {"run_name": run_time}},
|
||||
)
|
||||
|
||||
# Calculate total batch size
|
||||
total_batch_size = (
|
||||
cfg.data.train_bs
|
||||
* accelerator.num_processes
|
||||
* cfg.solver.gradient_accumulation_steps
|
||||
)
|
||||
|
||||
# Log training information
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f"Num Epochs = {num_train_epochs}")
|
||||
logger.info(f"Instantaneous batch size per device = {cfg.data.train_bs}")
|
||||
logger.info(
|
||||
f"Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f"Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}")
|
||||
logger.info(f"Total optimization steps = {cfg.solver.max_train_steps}")
|
||||
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
# Load checkpoint if resuming training
|
||||
if cfg.resume_from_checkpoint:
|
||||
resume_dir = save_dir
|
||||
dirs = os.listdir(resume_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
if len(dirs) > 0:
|
||||
path = dirs[-1]
|
||||
accelerator.load_state(os.path.join(resume_dir, path))
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
global_step = int(path.split("-")[1])
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = global_step % num_update_steps_per_epoch
|
||||
|
||||
# Initialize progress bar
|
||||
progress_bar = tqdm(
|
||||
range(global_step, cfg.solver.max_train_steps),
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
# Log model types
|
||||
print("log type of models")
|
||||
print("unet", model_dict['unet'].dtype)
|
||||
print("vae", model_dict['vae'].dtype)
|
||||
print("wav2vec", model_dict['wav2vec'].dtype)
|
||||
|
||||
def get_ganloss_weight(step):
|
||||
"""Calculate GAN loss weight based on training step"""
|
||||
if step < cfg.discriminator_train_params.start_gan:
|
||||
return 0.0
|
||||
else:
|
||||
return 1.0
|
||||
|
||||
# Training loop
|
||||
for epoch in range(first_epoch, num_train_epochs):
|
||||
# Set models to training mode
|
||||
model_dict['unet'].train()
|
||||
if cfg.loss_params.gan_loss > 0:
|
||||
loss_dict['discriminator'].train()
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
loss_dict['mouth_discriminator'].train()
|
||||
|
||||
# Initialize loss accumulators
|
||||
train_loss = 0.0
|
||||
train_loss_D = 0.0
|
||||
train_loss_D_mouth = 0.0
|
||||
l1_loss_accum = 0.0
|
||||
vgg_loss_accum = 0.0
|
||||
gan_loss_accum = 0.0
|
||||
gan_loss_accum_mouth = 0.0
|
||||
fm_loss_accum = 0.0
|
||||
sync_loss_accum = 0.0
|
||||
adapted_weight_accum = 0.0
|
||||
|
||||
t_data_start = time.time()
|
||||
for step, batch in enumerate(dataloader_dict['train_dataloader']):
|
||||
t_data = time.time() - t_data_start
|
||||
t_model_start = time.time()
|
||||
|
||||
with torch.no_grad():
|
||||
# Process input data
|
||||
pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
|
||||
accelerator.device,
|
||||
non_blocking=True
|
||||
)
|
||||
bsz, num_frames, c, h, w = pixel_values.shape
|
||||
|
||||
# Process reference images
|
||||
ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
|
||||
accelerator.device,
|
||||
non_blocking=True
|
||||
)
|
||||
|
||||
# Get face mask for GAN
|
||||
pixel_values_face_mask = batch['pixel_values_face_mask']
|
||||
|
||||
# Process audio features
|
||||
audio_prompts = process_audio_features(cfg, batch, model_dict['wav2vec'], bsz, num_frames, weight_dtype)
|
||||
|
||||
# Initialize adapted weight
|
||||
adapted_weight = 1
|
||||
|
||||
# Process sync loss if enabled
|
||||
if cfg.loss_params.sync_loss > 0:
|
||||
mels = batch['mel']
|
||||
# Prepare frames for latentsync (combine channels and frames)
|
||||
gt_frames = rearrange(pixel_values, 'b f c h w-> b (f c) h w')
|
||||
# Use lower half of face for latentsync
|
||||
height = gt_frames.shape[2]
|
||||
gt_frames = gt_frames[:, :, height // 2:, :]
|
||||
|
||||
# Get audio embeddings
|
||||
audio_embed = syncnet.get_audio_embed(mels)
|
||||
|
||||
# Calculate adapted weight based on audio-visual similarity
|
||||
if cfg.use_adapted_weight:
|
||||
vision_embed_gt = syncnet.get_vision_embed(gt_frames)
|
||||
image_audio_sim_gt = F.cosine_similarity(
|
||||
audio_embed,
|
||||
vision_embed_gt,
|
||||
dim=1
|
||||
)[0]
|
||||
|
||||
if image_audio_sim_gt < 0.05 or image_audio_sim_gt > 0.65:
|
||||
if cfg.adapted_weight_type == "cut_off":
|
||||
adapted_weight = 0.0 # Skip this batch
|
||||
print(
|
||||
f"\nThe i-a similarity in step {global_step} is {image_audio_sim_gt}, set adapted_weight to {adapted_weight}.")
|
||||
elif cfg.adapted_weight_type == "linear":
|
||||
adapted_weight = image_audio_sim_gt
|
||||
else:
|
||||
print(f"unknown adapted_weight_type: {cfg.adapted_weight_type}")
|
||||
adapted_weight = 1
|
||||
|
||||
# Random frame selection for memory efficiency
|
||||
max_start = 16 - cfg.num_backward_frames
|
||||
frames_left_index = random.randint(0, max_start) if max_start > 0 else 0
|
||||
frames_right_index = frames_left_index + cfg.num_backward_frames
|
||||
else:
|
||||
frames_left_index = 0
|
||||
frames_right_index = cfg.data.n_sample_frames
|
||||
|
||||
# Extract frames for backward pass
|
||||
pixel_values_backward = pixel_values[:, frames_left_index:frames_right_index, ...]
|
||||
ref_pixel_values_backward = ref_pixel_values[:, frames_left_index:frames_right_index, ...]
|
||||
pixel_values_face_mask_backward = pixel_values_face_mask[:, frames_left_index:frames_right_index, ...]
|
||||
audio_prompts_backward = audio_prompts[:, frames_left_index:frames_right_index, ...]
|
||||
|
||||
# Encode target images
|
||||
frames = rearrange(pixel_values_backward, 'b f c h w-> (b f) c h w')
|
||||
latents = model_dict['vae'].encode(frames).latent_dist.mode()
|
||||
latents = latents * model_dict['vae'].config.scaling_factor
|
||||
latents = latents.float()
|
||||
|
||||
# Create masked images
|
||||
masked_pixel_values = pixel_values_backward.clone()
|
||||
masked_pixel_values[:, :, :, h//2:, :] = -1
|
||||
masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
|
||||
masked_latents = model_dict['vae'].encode(masked_frames).latent_dist.mode()
|
||||
masked_latents = masked_latents * model_dict['vae'].config.scaling_factor
|
||||
masked_latents = masked_latents.float()
|
||||
|
||||
# Encode reference images
|
||||
ref_frames = rearrange(ref_pixel_values_backward, 'b f c h w-> (b f) c h w')
|
||||
ref_latents = model_dict['vae'].encode(ref_frames).latent_dist.mode()
|
||||
ref_latents = ref_latents * model_dict['vae'].config.scaling_factor
|
||||
ref_latents = ref_latents.float()
|
||||
|
||||
# Prepare face mask and audio features
|
||||
pixel_values_face_mask_backward = rearrange(
|
||||
pixel_values_face_mask_backward,
|
||||
"b f c h w -> (b f) c h w"
|
||||
)
|
||||
audio_prompts_backward = rearrange(
|
||||
audio_prompts_backward,
|
||||
'b f c h w-> (b f) c h w'
|
||||
)
|
||||
audio_prompts_backward = rearrange(
|
||||
audio_prompts_backward,
|
||||
'(b f) c h w -> (b f) (c h) w',
|
||||
b=bsz
|
||||
)
|
||||
|
||||
# Apply reference dropout (currently inactive)
|
||||
dropout = nn.Dropout(p=cfg.ref_dropout_rate)
|
||||
ref_latents = dropout(ref_latents)
|
||||
|
||||
# Prepare model inputs
|
||||
input_latents = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
input_latents = input_latents.to(weight_dtype)
|
||||
timesteps = torch.tensor([0], device=input_latents.device)
|
||||
|
||||
# Forward pass
|
||||
latents_pred = model_dict['net'](
|
||||
input_latents,
|
||||
timesteps,
|
||||
audio_prompts_backward,
|
||||
)
|
||||
latents_pred = (1 / model_dict['vae'].config.scaling_factor) * latents_pred
|
||||
image_pred = model_dict['vae'].decode(latents_pred).sample
|
||||
|
||||
# Convert to float
|
||||
image_pred = image_pred.float()
|
||||
frames = frames.float()
|
||||
|
||||
# Calculate L1 loss
|
||||
l1_loss = loss_dict['L1_loss'](frames, image_pred)
|
||||
l1_loss_accum += l1_loss.item()
|
||||
loss = cfg.loss_params.l1_loss * l1_loss * adapted_weight
|
||||
|
||||
# Process mouth GAN loss if enabled
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
frames_mouth, image_pred_mouth = get_mouth_region(
|
||||
frames,
|
||||
image_pred,
|
||||
pixel_values_face_mask_backward
|
||||
)
|
||||
pyramide_real_mouth = pyramid(downsampler(frames_mouth))
|
||||
pyramide_generated_mouth = pyramid(downsampler(image_pred_mouth))
|
||||
|
||||
# Process VGG loss if enabled
|
||||
if cfg.loss_params.vgg_loss > 0:
|
||||
pyramide_real = pyramid(downsampler(frames))
|
||||
pyramide_generated = pyramid(downsampler(image_pred))
|
||||
|
||||
loss_IN = 0
|
||||
for scale in cfg.loss_params.pyramid_scale:
|
||||
x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
|
||||
y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
|
||||
for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
|
||||
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
||||
loss_IN += weight * value
|
||||
loss_IN /= sum(cfg.loss_params.vgg_layer_weight)
|
||||
loss += loss_IN * cfg.loss_params.vgg_loss * adapted_weight
|
||||
vgg_loss_accum += loss_IN.item()
|
||||
|
||||
# Process GAN loss if enabled
|
||||
if cfg.loss_params.gan_loss > 0:
|
||||
set_requires_grad(loss_dict['discriminator'], False)
|
||||
loss_G = 0.
|
||||
discriminator_maps_generated = loss_dict['discriminator'](pyramide_generated)
|
||||
discriminator_maps_real = loss_dict['discriminator'](pyramide_real)
|
||||
|
||||
for scale in loss_dict['disc_scales']:
|
||||
key = 'prediction_map_%s' % scale
|
||||
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
|
||||
loss_G += value
|
||||
gan_loss_accum += loss_G.item()
|
||||
|
||||
loss += loss_G * cfg.loss_params.gan_loss * get_ganloss_weight(global_step) * adapted_weight
|
||||
|
||||
# Process feature matching loss if enabled
|
||||
if cfg.loss_params.fm_loss[0] > 0:
|
||||
L_feature_matching = 0.
|
||||
for scale in loss_dict['disc_scales']:
|
||||
key = 'feature_maps_%s' % scale
|
||||
for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
|
||||
value = torch.abs(a - b).mean()
|
||||
L_feature_matching += value * cfg.loss_params.fm_loss[i]
|
||||
loss += L_feature_matching * adapted_weight
|
||||
fm_loss_accum += L_feature_matching.item()
|
||||
|
||||
# Process mouth GAN loss if enabled
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
set_requires_grad(loss_dict['mouth_discriminator'], False)
|
||||
loss_G = 0.
|
||||
mouth_discriminator_maps_generated = loss_dict['mouth_discriminator'](pyramide_generated_mouth)
|
||||
mouth_discriminator_maps_real = loss_dict['mouth_discriminator'](pyramide_real_mouth)
|
||||
|
||||
for scale in loss_dict['disc_scales']:
|
||||
key = 'prediction_map_%s' % scale
|
||||
value = ((1 - mouth_discriminator_maps_generated[key]) ** 2).mean()
|
||||
loss_G += value
|
||||
gan_loss_accum_mouth += loss_G.item()
|
||||
|
||||
loss += loss_G * cfg.loss_params.mouth_gan_loss * get_ganloss_weight(global_step) * adapted_weight
|
||||
|
||||
# Process feature matching loss for mouth if enabled
|
||||
if cfg.loss_params.fm_loss[0] > 0:
|
||||
L_feature_matching = 0.
|
||||
for scale in loss_dict['disc_scales']:
|
||||
key = 'feature_maps_%s' % scale
|
||||
for i, (a, b) in enumerate(zip(mouth_discriminator_maps_real[key], mouth_discriminator_maps_generated[key])):
|
||||
value = torch.abs(a - b).mean()
|
||||
L_feature_matching += value * cfg.loss_params.fm_loss[i]
|
||||
loss += L_feature_matching * adapted_weight
|
||||
fm_loss_accum += L_feature_matching.item()
|
||||
|
||||
# Process sync loss if enabled
|
||||
if cfg.loss_params.sync_loss > 0:
|
||||
pred_frames = rearrange(
|
||||
image_pred, '(b f) c h w-> b (f c) h w', f=pixel_values_backward.shape[1])
|
||||
pred_frames = pred_frames[:, :, height // 2 :, :]
|
||||
sync_loss, image_audio_sim_pred = get_sync_loss(
|
||||
audio_embed,
|
||||
gt_frames,
|
||||
pred_frames,
|
||||
syncnet,
|
||||
adapted_weight,
|
||||
frames_left_index=frames_left_index,
|
||||
frames_right_index=frames_right_index,
|
||||
)
|
||||
sync_loss_accum += sync_loss.item()
|
||||
loss += sync_loss * cfg.loss_params.sync_loss * adapted_weight
|
||||
|
||||
# Backward pass
|
||||
avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
|
||||
train_loss += avg_loss.item()
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Train discriminator if GAN loss is enabled
|
||||
if cfg.loss_params.gan_loss > 0:
|
||||
set_requires_grad(loss_dict['discriminator'], True)
|
||||
loss_D = loss_dict['discriminator_full'](frames, image_pred.detach())
|
||||
avg_loss_D = accelerator.gather(loss_D.repeat(cfg.data.train_bs)).mean()
|
||||
train_loss_D += avg_loss_D.item() / 1
|
||||
loss_D = loss_D * get_ganloss_weight(global_step) * adapted_weight
|
||||
accelerator.backward(loss_D)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(
|
||||
loss_dict['discriminator'].parameters(), cfg.solver.max_grad_norm)
|
||||
if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
|
||||
loss_dict['optimizer_D'].step()
|
||||
loss_dict['scheduler_D'].step()
|
||||
loss_dict['optimizer_D'].zero_grad()
|
||||
|
||||
# Train mouth discriminator if mouth GAN loss is enabled
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
set_requires_grad(loss_dict['mouth_discriminator'], True)
|
||||
mouth_loss_D = loss_dict['mouth_discriminator_full'](
|
||||
frames_mouth, image_pred_mouth.detach())
|
||||
avg_mouth_loss_D = accelerator.gather(
|
||||
mouth_loss_D.repeat(cfg.data.train_bs)).mean()
|
||||
train_loss_D_mouth += avg_mouth_loss_D.item() / 1
|
||||
mouth_loss_D = mouth_loss_D * get_ganloss_weight(global_step) * adapted_weight
|
||||
accelerator.backward(mouth_loss_D)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(
|
||||
loss_dict['mouth_discriminator'].parameters(), cfg.solver.max_grad_norm)
|
||||
if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
|
||||
loss_dict['mouth_optimizer_D'].step()
|
||||
loss_dict['mouth_scheduler_D'].step()
|
||||
loss_dict['mouth_optimizer_D'].zero_grad()
|
||||
|
||||
# Update main model
|
||||
if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(
|
||||
model_dict['trainable_params'],
|
||||
cfg.solver.max_grad_norm,
|
||||
)
|
||||
model_dict['optimizer'].step()
|
||||
model_dict['lr_scheduler'].step()
|
||||
model_dict['optimizer'].zero_grad()
|
||||
|
||||
# Update progress and log metrics
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
accelerator.log({
|
||||
"train_loss": train_loss,
|
||||
"train_loss_D": train_loss_D,
|
||||
"train_loss_D_mouth": train_loss_D_mouth,
|
||||
"l1_loss": l1_loss_accum,
|
||||
"vgg_loss": vgg_loss_accum,
|
||||
"gan_loss": gan_loss_accum,
|
||||
"fm_loss": fm_loss_accum,
|
||||
"sync_loss": sync_loss_accum,
|
||||
"adapted_weight": adapted_weight_accum,
|
||||
"lr": model_dict['lr_scheduler'].get_last_lr()[0],
|
||||
}, step=global_step)
|
||||
|
||||
# Reset loss accumulators
|
||||
train_loss = 0.0
|
||||
l1_loss_accum = 0.0
|
||||
vgg_loss_accum = 0.0
|
||||
gan_loss_accum = 0.0
|
||||
fm_loss_accum = 0.0
|
||||
sync_loss_accum = 0.0
|
||||
adapted_weight_accum = 0.0
|
||||
train_loss_D = 0.0
|
||||
train_loss_D_mouth = 0.0
|
||||
|
||||
# Run validation if needed
|
||||
if global_step % cfg.val_freq == 0 or global_step == 10:
|
||||
try:
|
||||
validation(
|
||||
cfg,
|
||||
dataloader_dict['val_dataloader'],
|
||||
model_dict['net'],
|
||||
model_dict['vae'],
|
||||
model_dict['wav2vec'],
|
||||
accelerator,
|
||||
save_dir,
|
||||
global_step,
|
||||
weight_dtype,
|
||||
syncnet_score=adapted_weight,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"An error occurred during validation: {e}")
|
||||
|
||||
# Save checkpoint if needed
|
||||
if global_step % cfg.checkpointing_steps == 0:
|
||||
save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
|
||||
try:
|
||||
start_time = time.time()
|
||||
if accelerator.is_main_process:
|
||||
save_models(
|
||||
accelerator,
|
||||
model_dict['net'],
|
||||
save_dir,
|
||||
global_step,
|
||||
cfg,
|
||||
logger=logger
|
||||
)
|
||||
delete_additional_ckpt(save_dir, cfg.total_limit)
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time > 300:
|
||||
print(f"Skipping storage as it took too long in step {global_step}.")
|
||||
else:
|
||||
print(f"Resume states saved at {save_dir} successfully in {elapsed_time}s.")
|
||||
except Exception as e:
|
||||
print(f"Error when saving model in step {global_step}:", e)
|
||||
|
||||
# Update progress bar
|
||||
t_model = time.time() - t_model_start
|
||||
logs = {
|
||||
"step_loss": loss.detach().item(),
|
||||
"lr": model_dict['lr_scheduler'].get_last_lr()[0],
|
||||
"td": f"{t_data:.2f}s",
|
||||
"tm": f"{t_model:.2f}s",
|
||||
}
|
||||
t_data_start = time.time()
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= cfg.solver.max_train_steps:
|
||||
break
|
||||
|
||||
# Save model after each epoch
|
||||
if (epoch + 1) % cfg.save_model_epoch_interval == 0:
|
||||
try:
|
||||
start_time = time.time()
|
||||
if accelerator.is_main_process:
|
||||
save_models(accelerator, model_dict['net'], save_dir, global_step, cfg)
|
||||
accelerator.save_state(save_path)
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time > 120:
|
||||
print(f"Skipping storage as it took too long in step {global_step}.")
|
||||
else:
|
||||
print(f"Model saved successfully in {elapsed_time}s.")
|
||||
except Exception as e:
|
||||
print(f"Error when saving model in step {global_step}:", e)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# End training
|
||||
accelerator.end_training()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
|
||||
args = parser.parse_args()
|
||||
config = OmegaConf.load(args.config)
|
||||
main(config)
|
||||
@@ -1,34 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MuseTalk Training Script
|
||||
# This script combines both training stages for the MuseTalk model
|
||||
# Usage: sh train.sh [stage1|stage2]
|
||||
# Example: sh train.sh stage1 # To run stage 1 training
|
||||
# Example: sh train.sh stage2 # To run stage 2 training
|
||||
|
||||
# Check if stage argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Error: Please specify the training stage"
|
||||
echo "Usage: ./train.sh [stage1|stage2]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
STAGE=$1
|
||||
|
||||
# Validate stage argument
|
||||
if [ "$STAGE" != "stage1" ] && [ "$STAGE" != "stage2" ]; then
|
||||
echo "Error: Invalid stage. Must be either 'stage1' or 'stage2'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Launch distributed training using accelerate
|
||||
# --config_file: Path to the GPU configuration file
|
||||
# --main_process_port: Port number for the main process, used for distributed training communication
|
||||
# train.py: Training script
|
||||
# --config: Path to the training configuration file
|
||||
echo "Starting $STAGE training..."
|
||||
accelerate launch --config_file ./configs/training/gpu.yaml \
|
||||
--main_process_port 29502 \
|
||||
train.py --config ./configs/training/$STAGE.yaml
|
||||
|
||||
echo "Training completed for $STAGE"
|
||||
@@ -0,0 +1,235 @@
|
||||
import os, random, cv2, argparse
|
||||
import torch
|
||||
from torch.utils import data as data_utils
|
||||
from os.path import dirname, join, basename, isfile
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from utils.utils import prepare_mask_and_masked_image
|
||||
import torchvision.utils as vutils
|
||||
import torchvision.transforms as transforms
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import heapq
|
||||
|
||||
syncnet_T = 1
|
||||
RESIZED_IMG = 256
|
||||
|
||||
connections = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7),(7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13),(13,14),(14,15),(15,16), # 下颌线
|
||||
(17, 18), (18, 19), (19, 20), (20, 21), #左眉毛
|
||||
(22, 23), (23, 24), (24, 25), (25, 26), #右眉毛
|
||||
(27, 28),(28,29),(29,30),# 鼻梁
|
||||
(31,32),(32,33),(33,34),(34,35), #鼻子
|
||||
(36,37),(37,38),(38, 39), (39, 40), (40, 41), (41, 36), # 左眼
|
||||
(42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 42), # 右眼
|
||||
(48, 49),(49, 50), (50, 51),(51, 52),(52, 53), (53, 54), # 上嘴唇 外延
|
||||
(54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 48), # 下嘴唇 外延
|
||||
(60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 60) #嘴唇内圈
|
||||
]
|
||||
|
||||
|
||||
def get_image_list(data_root, split):
|
||||
filelist = []
|
||||
imgNumList = []
|
||||
with open('filelists/{}.txt'.format(split)) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if ' ' in line:
|
||||
filename = line.split()[0]
|
||||
imgNum = int(line.split()[1])
|
||||
filelist.append(os.path.join(data_root, filename))
|
||||
imgNumList.append(imgNum)
|
||||
return filelist, imgNumList
|
||||
|
||||
|
||||
|
||||
class Dataset(object):
|
||||
def __init__(self,
|
||||
data_root,
|
||||
json_path,
|
||||
use_audio_length_left=1,
|
||||
use_audio_length_right=1,
|
||||
whisper_model_type = "tiny"
|
||||
):
|
||||
# self.all_videos, self.all_imgNum = get_image_list(data_root, split)
|
||||
self.audio_feature = [use_audio_length_left,use_audio_length_right]
|
||||
self.all_img_names = []
|
||||
# self.split = split
|
||||
self.img_names_path = '../data'
|
||||
self.whisper_model_type = whisper_model_type
|
||||
self.use_audio_length_left = use_audio_length_left
|
||||
self.use_audio_length_right = use_audio_length_right
|
||||
|
||||
if self.whisper_model_type =="tiny":
|
||||
self.whisper_path = '../data/audios'
|
||||
self.whisper_feature_W = 5
|
||||
self.whisper_feature_H = 384
|
||||
elif self.whisper_model_type =="largeV2":
|
||||
self.whisper_path = '...'
|
||||
self.whisper_feature_W = 33
|
||||
self.whisper_feature_H = 1280
|
||||
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*(2+2+1)= 50
|
||||
with open(json_path, 'r') as file:
|
||||
self.all_videos = json.load(file)
|
||||
|
||||
for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
|
||||
json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
|
||||
if not os.path.exists(json_path_names):
|
||||
img_names = glob(join(vidname, '*.png'))
|
||||
img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
|
||||
with open(json_path_names, "w") as f:
|
||||
json.dump(img_names,f)
|
||||
else:
|
||||
with open(json_path_names, "r") as f:
|
||||
img_names = json.load(f)
|
||||
self.all_img_names.append(img_names)
|
||||
|
||||
def get_frame_id(self, frame):
|
||||
return int(basename(frame).split('.')[0])
|
||||
|
||||
def get_window(self, start_frame):
|
||||
start_id = self.get_frame_id(start_frame)
|
||||
vidname = dirname(start_frame)
|
||||
|
||||
window_fnames = []
|
||||
for frame_id in range(start_id, start_id + syncnet_T):
|
||||
frame = join(vidname, '{}.png'.format(frame_id))
|
||||
if not isfile(frame):
|
||||
return None
|
||||
window_fnames.append(frame)
|
||||
return window_fnames
|
||||
|
||||
def read_window(self, window_fnames):
|
||||
if window_fnames is None: return None
|
||||
window = []
|
||||
for fname in window_fnames:
|
||||
img = cv2.imread(fname)
|
||||
if img is None:
|
||||
return None
|
||||
try:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(img, (RESIZED_IMG, RESIZED_IMG))
|
||||
except Exception as e:
|
||||
print("read_window has error fname not exist:",fname)
|
||||
return None
|
||||
|
||||
window.append(img)
|
||||
|
||||
return window
|
||||
|
||||
def prepare_window(self, window):
|
||||
# 1 x H x W x 3
|
||||
x = np.asarray(window) / 255.
|
||||
x = np.transpose(x, (3, 0, 1, 2))
|
||||
|
||||
return x
|
||||
|
||||
def __len__(self):
|
||||
return len(self.all_videos)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
while 1:
|
||||
idx = random.randint(0, len(self.all_videos) - 1)
|
||||
#随机选择某个video里
|
||||
vidname = self.all_videos[idx].split('/')[-1]
|
||||
video_imgs = self.all_img_names[idx]
|
||||
if len(video_imgs) <= 11:
|
||||
continue
|
||||
img_name = random.choice(video_imgs)
|
||||
img_idx = int(basename(img_name).split(".")[0])
|
||||
random_element = random.randint(0,len(video_imgs)-1)
|
||||
while abs(random_element - img_idx) <= 5:
|
||||
random_element = random.randint(0,len(video_imgs)-1)
|
||||
img_dir = os.path.dirname(img_name)
|
||||
ref_image = os.path.join(img_dir, f"{str(random_element)}.png")
|
||||
target_window_fnames = self.get_window(img_name)
|
||||
ref_window_fnames = self.get_window(ref_image)
|
||||
|
||||
if target_window_fnames is None or ref_window_fnames is None:
|
||||
print("No such img",img_name, ref_image)
|
||||
continue
|
||||
|
||||
try:
|
||||
#构建目标img数据
|
||||
target_window = self.read_window(target_window_fnames)
|
||||
if target_window is None :
|
||||
print("No such target window,",target_window_fnames)
|
||||
continue
|
||||
#构建参考img数据
|
||||
ref_window = self.read_window(ref_window_fnames)
|
||||
|
||||
if ref_window is None:
|
||||
print("No such target ref window,",ref_window)
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"发生未知错误:{e}")
|
||||
continue
|
||||
|
||||
#构建target输入
|
||||
target_window = self.prepare_window(target_window)
|
||||
image = gt = target_window.copy().squeeze()
|
||||
target_window[:, :, target_window.shape[2]//2:] = 0. # upper half face, mask掉下半部分 V1:输入
|
||||
ref_image = self.prepare_window(ref_window).squeeze()
|
||||
|
||||
|
||||
|
||||
mask = torch.zeros((ref_image.shape[1], ref_image.shape[2]))
|
||||
mask[:ref_image.shape[2]//2,:] = 1
|
||||
image = torch.FloatTensor(image)
|
||||
mask, masked_image = prepare_mask_and_masked_image(image,mask)
|
||||
|
||||
|
||||
|
||||
#音频特征
|
||||
window_index = self.get_frame_id(img_name)
|
||||
sub_folder_name = vidname.split('/')[-1]
|
||||
|
||||
## 根据window_index加载相邻的音频
|
||||
audio_feature_all = []
|
||||
is_index_out_of_range = False
|
||||
if os.path.isdir(os.path.join(self.whisper_path, sub_folder_name)):
|
||||
for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
|
||||
# 判定是否越界
|
||||
audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
|
||||
if not os.path.exists(audio_feat_path):
|
||||
is_index_out_of_range = True
|
||||
break
|
||||
|
||||
try:
|
||||
audio_feature_all.append(np.load(audio_feat_path))
|
||||
except Exception as e:
|
||||
print(f"发生未知错误:{e}")
|
||||
print(f"npy load error {audio_feat_path}")
|
||||
if is_index_out_of_range:
|
||||
continue
|
||||
audio_feature = np.concatenate(audio_feature_all, axis=0)
|
||||
else:
|
||||
continue
|
||||
|
||||
audio_feature = audio_feature.reshape(1, -1, self.whisper_feature_H) #1, -1, 384
|
||||
if audio_feature.shape != (1,self.whisper_feature_concateW, self.whisper_feature_H): #1 50 384
|
||||
print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
|
||||
continue
|
||||
audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
|
||||
return ref_image, image, masked_image, mask, audio_feature
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_root = '...'
|
||||
val_data = Dataset(data_root,
|
||||
'val',
|
||||
use_audio_length_left = 2,
|
||||
use_audio_length_right = 2,
|
||||
whisper_model_type = "tiny"
|
||||
)
|
||||
val_data_loader = data_utils.DataLoader(
|
||||
val_data, batch_size=4, shuffle=True,
|
||||
num_workers=1)
|
||||
|
||||
for i, data in enumerate(val_data_loader):
|
||||
ref_image, image, masked_image, mask, audio_feature = data
|
||||
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
# Data preprocessing
|
||||
|
||||
Create two config yaml files, one for training and other for testing (both in same format as configs/inference/test.yaml)
|
||||
The train yaml file should contain the training video paths and corresponding audio paths
|
||||
The test yaml file should contain the validation video paths and corresponding audio paths
|
||||
|
||||
Run:
|
||||
```
|
||||
./data_new.sh train output train_video1.mp4 train_video2.mp4
|
||||
./data_new.sh test output test_video1.mp4 test_video2.mp4
|
||||
```
|
||||
This creates folders which contain the image frames and npy files. This also creates train.json and val.json which can be used during the training.
|
||||
|
||||
## Data organization
|
||||
```
|
||||
./data/
|
||||
├── images
|
||||
│ └──RD_Radio10_000
|
||||
│ └── 0.png
|
||||
│ └── 1.png
|
||||
│ └── xxx.png
|
||||
│ └──RD_Radio11_000
|
||||
│ └── 0.png
|
||||
│ └── 1.png
|
||||
│ └── xxx.png
|
||||
├── audios
|
||||
│ └──RD_Radio10_000
|
||||
│ └── 0.npy
|
||||
│ └── 1.npy
|
||||
│ └── xxx.npy
|
||||
│ └──RD_Radio11_000
|
||||
│ └── 0.npy
|
||||
│ └── 1.npy
|
||||
│ └── xxx.npy
|
||||
```
|
||||
|
||||
## Training
|
||||
Simply run after preparing the preprocessed data
|
||||
```
|
||||
cd train_codes
|
||||
sh train.sh #--train_json="../train.json" \(Generated in Data preprocessing step.)
|
||||
#--val_json="../val.json" \
|
||||
```
|
||||
## Inference with trained checkpoit
|
||||
Simply run after training the model, the model checkpoints are saved at train_codes/output usually
|
||||
```
|
||||
python -m scripts.finetuned_inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
|
||||
```
|
||||
|
||||
## TODO
|
||||
- [x] release data preprocessing codes
|
||||
- [ ] release some novel designs in training (after technical report)
|
||||
@@ -0,0 +1,322 @@
|
||||
RD_Radio10_000 3501
|
||||
RD_Radio11_000 752
|
||||
RD_Radio11_001 1502
|
||||
RD_Radio12_000 876
|
||||
RD_Radio13_000 877
|
||||
RD_Radio14_000 1251
|
||||
RD_Radio16_000 2377
|
||||
RD_Radio17_000 2176
|
||||
RD_Radio18_000 1827
|
||||
RD_Radio19_000 1876
|
||||
RD_Radio1_000 1876
|
||||
RD_Radio20_000 276
|
||||
RD_Radio21_000 1201
|
||||
RD_Radio22_000 752
|
||||
RD_Radio23_000 1602
|
||||
RD_Radio25_000 2126
|
||||
RD_Radio26_000 1052
|
||||
RD_Radio27_000 1376
|
||||
RD_Radio28_000 2252
|
||||
RD_Radio29_000 2252
|
||||
RD_Radio2_000 2076
|
||||
RD_Radio30_000 2877
|
||||
RD_Radio31_000 1377
|
||||
RD_Radio32_000 1502
|
||||
RD_Radio33_000 2252
|
||||
RD_Radio34_000 702
|
||||
RD_Radio34_001 252
|
||||
RD_Radio34_002 501
|
||||
RD_Radio34_003 326
|
||||
RD_Radio34_004 502
|
||||
RD_Radio34_005 502
|
||||
RD_Radio34_006 502
|
||||
RD_Radio34_007 252
|
||||
RD_Radio34_008 876
|
||||
RD_Radio34_009 752
|
||||
RD_Radio35_000 2127
|
||||
RD_Radio36_000 2377
|
||||
RD_Radio37_000 3127
|
||||
RD_Radio38_000 3752
|
||||
RD_Radio39_000 1502
|
||||
RD_Radio3_000 2377
|
||||
RD_Radio40_000 1252
|
||||
RD_Radio41_000 2127
|
||||
RD_Radio42_000 1752
|
||||
RD_Radio43_000 1502
|
||||
RD_Radio44_000 1127
|
||||
RD_Radio45_000 1628
|
||||
RD_Radio46_000 1877
|
||||
RD_Radio47_000 1502
|
||||
RD_Radio48_000 2127
|
||||
RD_Radio49_000 1502
|
||||
RD_Radio4_000 1002
|
||||
RD_Radio50_000 2252
|
||||
RD_Radio51_000 3253
|
||||
RD_Radio52_000 2752
|
||||
RD_Radio53_000 2627
|
||||
RD_Radio54_000 577
|
||||
RD_Radio56_000 1952
|
||||
RD_Radio57_000 3077
|
||||
RD_Radio59_000 1952
|
||||
RD_Radio5_000 1377
|
||||
RD_Radio7_000 2252
|
||||
RD_Radio8_000 2126
|
||||
RD_Radio9_000 1502
|
||||
WDA_AdamSchiff_000 6877
|
||||
WDA_AdamSmith_000 7327
|
||||
WDA_AlexandriaOcasioCortez_000 2252
|
||||
WDA_AmyKlobuchar0_000 6501
|
||||
WDA_AmyKlobuchar1_000 3253
|
||||
WDA_AmyKlobuchar1_001 877
|
||||
WDA_AmyKlobuchar1_002 1502
|
||||
WDA_AmyKlobuchar1_003 1377
|
||||
WDA_AndyKim_000 6952
|
||||
WDA_AndyLevin_000 4876
|
||||
WDA_AnnieKuster_000 4876
|
||||
WDA_BarackObama_000 3575
|
||||
WDA_BarackObama_001 5625
|
||||
WDA_BarbaraLee0_000 6502
|
||||
WDA_BarbaraLee1_000 4702
|
||||
WDA_BenCardin0_000 8127
|
||||
WDA_BenCardin1_000 7102
|
||||
WDA_BenRayLujn_000 7127
|
||||
WDA_BennieThompson1_000 6375
|
||||
WDA_BennieThompson_000 5375
|
||||
WDA_BernieSanders_000 8451
|
||||
WDA_BettyMcCollum_000 7002
|
||||
WDA_BillPascrell_000 7252
|
||||
WDA_BillRichardson_000 3127
|
||||
WDA_BobCasey0_000 10627
|
||||
WDA_BobCasey1_000 252
|
||||
WDA_BobMenendez_000 3002
|
||||
WDA_BobbyScott_000 2002
|
||||
WDA_BradSchneider_000 6627
|
||||
WDA_BrendaLawrence_000 5627
|
||||
WDA_BrianSchatz0_000 502
|
||||
WDA_BrianSchatz1_000 2627
|
||||
WDA_BrianSchatz2_000 1952
|
||||
WDA_ByronDorgan1_000 4277
|
||||
WDA_CarolynMaloney1_000 8377
|
||||
WDA_CatherineCortezMasto_000 2827
|
||||
WDA_CedricRichmond_000 7250
|
||||
WDA_ChrisCoons1_000 4452
|
||||
WDA_ChrisCoons_000 6827
|
||||
WDA_ChrisMurphy0_000 6877
|
||||
WDA_ChrisMurphy1_000 6377
|
||||
WDA_ChrisVanHollen0_000 8202
|
||||
WDA_ChrisVanHollen1_000 7327
|
||||
WDA_ChuckSchumer0_000 5577
|
||||
WDA_ChuckSchumer1_000 4827
|
||||
WDA_ColinAllred_000 4751
|
||||
WDA_DanKildee1_000 6252
|
||||
WDA_DanKildee_000 2502
|
||||
WDA_DavidCicilline_000 5875
|
||||
WDA_DebHaaland_000 5876
|
||||
WDA_DebbieDingell0_000 7752
|
||||
WDA_DebbieDingell1_000 7251
|
||||
WDA_DebbieStabenow0_000 4577
|
||||
WDA_DebbieStabenow1_000 5201
|
||||
WDA_DebbieWassermanSchultz_000 7625
|
||||
WDA_DianaDeGette0_000 6002
|
||||
WDA_DianaDeGette1_000 2202
|
||||
WDA_DianneFeinstein_000 7453
|
||||
WDA_DickDurbin_000 6326
|
||||
WDA_DonaldMcEachin_000 7627
|
||||
WDA_DonnaShalala1_000 7500
|
||||
WDA_DougJones_000 6077
|
||||
WDA_EdMarkey0_000 4752
|
||||
WDA_EdMarkey1_000 6501
|
||||
WDA_ElijahCummings_000 6377
|
||||
WDA_EliotEngel_000 6876
|
||||
WDA_EmanuelCleaver_000 7001
|
||||
WDA_EricSwalwell_000 6377
|
||||
WDA_FrankPallone0_000 5625
|
||||
WDA_FrankPallone1_000 6502
|
||||
WDA_GerryConnolly_000 6752
|
||||
WDA_HakeemJeffries_000 6002
|
||||
WDA_HaleyStevens_000 5001
|
||||
WDA_HenryWaxman_000 1125
|
||||
WDA_HillaryClinton_000 2500
|
||||
WDA_JackReed0_000 2377
|
||||
WDA_JackReed1_000 4877
|
||||
WDA_JackieSpeier_000 4625
|
||||
WDA_JackyRosen1_000 10077
|
||||
WDA_JackyRosen_000 5502
|
||||
WDA_JamesClyburn1_000 6876
|
||||
WDA_JamesClyburn_000 6875
|
||||
WDA_JanSchakowsky0_000 4128
|
||||
WDA_JanSchakowsky1_000 6251
|
||||
WDA_JeanneShaheen0_000 6702
|
||||
WDA_JeanneShaheen1_000 6577
|
||||
WDA_JeffMerkley1_000 6952
|
||||
WDA_JerryNadler_000 8377
|
||||
WDA_JimHimes_000 7000
|
||||
WDA_JimmyGomez_000 4627
|
||||
WDA_JoaquinCastro_000 5126
|
||||
WDA_JoeCrowley0_000 4877
|
||||
WDA_JoeCrowley1_000 1127
|
||||
WDA_JoeCrowley1_001 627
|
||||
WDA_JoeCrowley1_002 502
|
||||
WDA_JoeCrowley1_003 752
|
||||
WDA_JoeDonnelly_000 1377
|
||||
WDA_JoeKennedy_000 1327
|
||||
WDA_JoeManchin_000 3377
|
||||
WDA_JoeNeguse_000 1628
|
||||
WDA_JoeNeguse_001 1126
|
||||
WDA_JoeNeguse_002 1251
|
||||
WDA_JohnLewis0_000 6252
|
||||
WDA_JohnLewis1_000 7252
|
||||
WDA_JohnSarbanes0_000 6377
|
||||
WDA_JohnSarbanes1_000 1877
|
||||
WDA_JohnYarmuth1_000 5377
|
||||
WDA_JonTester0_000 3252
|
||||
WDA_JonTester1_000 5327
|
||||
WDA_KarenBass_000 7126
|
||||
WDA_KatherineClark_000 1552
|
||||
WDA_KathyCastor1_000 6001
|
||||
WDA_KathyCastor_000 2252
|
||||
WDA_KimSchrier_000 6877
|
||||
WDA_KirstenGillibrand_000 9627
|
||||
WDA_LaurenUnderwood_000 9877
|
||||
WDA_LisaBluntRochester_000 4500
|
||||
WDA_LloydDoggett0_000 7377
|
||||
WDA_LloydDoggett1_000 2252
|
||||
WDA_LucilleRoybal-Allard_000 2877
|
||||
WDA_LucyMcBath_000 4000
|
||||
WDA_MarciaFudge_000 7377
|
||||
WDA_MarkWarner1_000 377
|
||||
WDA_MarkWarner1_001 1327
|
||||
WDA_MarkWarner2_000 377
|
||||
WDA_MarkWarner_000 3127
|
||||
WDA_MartinHeinrich_000 5202
|
||||
WDA_MattCartwright_000 5377
|
||||
WDA_MazieHirono0_000 3752
|
||||
WDA_MichelleLujanGrisham_000 6875
|
||||
WDA_MichelleObama_000 2000
|
||||
WDA_MikeDoyle_000 8750
|
||||
WDA_MikeThompson0_000 4625
|
||||
WDA_MikeThompson1_000 1827
|
||||
WDA_NancyPelosi0_000 10251
|
||||
WDA_NancyPelosi1_000 1377
|
||||
WDA_NancyPelosi3_000 1127
|
||||
WDA_NitaLowey_000 5876
|
||||
WDA_NydiaVelzquez_000 5500
|
||||
WDA_PatrickLeahy0_000 6951
|
||||
WDA_PatrickLeahy1_000 9953
|
||||
WDA_PattyMurray0_000 3627
|
||||
WDA_PattyMurray1_000 5702
|
||||
WDA_PeterDeFazio_000 6628
|
||||
WDA_RaulRuiz_000 5752
|
||||
WDA_RichardBlumenthal_000 7827
|
||||
WDA_RichardNeal0_000 6252
|
||||
WDA_RichardNeal1_000 6877
|
||||
WDA_RobinKelly_000 4377
|
||||
WDA_RonWyden0_000 5152
|
||||
WDA_RonWyden1_000 8327
|
||||
WDA_ScottPeters0_000 7002
|
||||
WDA_ScottPeters1_000 3952
|
||||
WDA_SeanCasten_000 6126
|
||||
WDA_SeanPatrickMaloney_000 6502
|
||||
WDA_SheldonWhitehouse0_000 6327
|
||||
WDA_SheldonWhitehouse1_000 5702
|
||||
WDA_SherrodBrown0_000 7452
|
||||
WDA_SherrodBrown1_000 7327
|
||||
WDA_StenyHoyer_000 3502
|
||||
WDA_StephanieMurphy_000 4000
|
||||
WDA_SuzanDelBene_000 6875
|
||||
WDA_TammyBaldwin0_000 5327
|
||||
WDA_TammyBaldwin1_000 2503
|
||||
WDA_TammyDuckworth_000 5702
|
||||
WDA_TedLieu_000 6877
|
||||
WDA_TerriSewell0_000 2127
|
||||
WDA_TerriSewell1_000 10952
|
||||
WDA_TerriSewell_000 6752
|
||||
WDA_TimWalz_000 6628
|
||||
WDA_TinaSmith_000 5576
|
||||
WDA_TomCarper_000 4701
|
||||
WDA_TomPerez_000 5500
|
||||
WDA_TomUdall_000 4077
|
||||
WDA_VeronicaEscobar0_000 4377
|
||||
WDA_VeronicaEscobar1_000 2453
|
||||
WDA_WhipJimClyburn_000 6876
|
||||
WDA_XavierBecerra_000 877
|
||||
WDA_XavierBecerra_001 627
|
||||
WDA_XavierBecerra_002 1377
|
||||
WDA_ZoeLofgren_000 4625
|
||||
WRA_AdamKinzinger0_000 4251
|
||||
WRA_AdamKinzinger1_000 2751
|
||||
WRA_AdamKinzinger2_000 1002
|
||||
WRA_AdamKinzinger2_001 1001
|
||||
WRA_AdamKinzinger2_002 502
|
||||
WRA_AllenWest_000 6876
|
||||
WRA_AnnMarieBuerkle_000 452
|
||||
WRA_AnnWagner_000 3127
|
||||
WRA_AustinScott0_000 1177
|
||||
WRA_BillCassidy0_000 427
|
||||
WRA_BobCorker_000 1127
|
||||
WRA_BobGoodlatte0_000 1001
|
||||
WRA_BobGoodlatte0_001 752
|
||||
WRA_BobGoodlatte0_002 876
|
||||
WRA_BobGoodlatte0_003 1002
|
||||
WRA_BobbySchilling_001 1000
|
||||
WRA_BobbySchilling_002 1001
|
||||
WRA_CandiceMiller0_000 3001
|
||||
WRA_CarlyFiorina0_000 876
|
||||
WRA_CarlyFiorina_000 1902
|
||||
WRA_CathyMcMorrisRodgers0_000 6077
|
||||
WRA_CathyMcMorrisRodgers1_000 7375
|
||||
WRA_CathyMcMorrisRodgers1_001 7500
|
||||
WRA_CathyMcMorrisRodgers1_002 7500
|
||||
WRA_CathyMcMorrisRodgers2_000 2751
|
||||
WRA_ChuckGrassley_000 4502
|
||||
WRA_CoryGardner0_000 8577
|
||||
WRA_CoryGardner1_000 2002
|
||||
WRA_CoryGardner_000 1252
|
||||
WRA_DanNewhouse_000 3627
|
||||
WRA_DanSullivan_000 6877
|
||||
WRA_DaveCamp_000 752
|
||||
WRA_DaveCamp_001 876
|
||||
WRA_DaveCamp_002 627
|
||||
WRA_DavidVitter_000 2877
|
||||
WRA_DeanHeller_000 377
|
||||
WRA_DebFischer0_000 7577
|
||||
WRA_DebFischer1_000 1577
|
||||
WRA_DebFischer2_000 5202
|
||||
WRA_DianeBlack0_000 451
|
||||
WRA_DianeBlack0_001 301
|
||||
WRA_DianeBlack1_000 250
|
||||
WRA_DuncanHunter_000 727
|
||||
WRA_EricCantor_000 2952
|
||||
WRA_ErikPaulsen_000 876
|
||||
WRA_ErikPaulsen_001 377
|
||||
WRA_ErikPaulsen_002 627
|
||||
WRA_ErikPaulsen_003 752
|
||||
WRA_FredUpton_000 1701
|
||||
WRA_GeoffDavis_000 250
|
||||
WRA_GeorgeLeMieux_000 752
|
||||
WRA_GeorgeLeMieux_001 1001
|
||||
WRA_GregWalden1_000 752
|
||||
WRA_GregWalden1_001 1127
|
||||
WRA_GregWalden1_002 1327
|
||||
WRA_GregWalden_000 377
|
||||
WRA_JaimeHerreraBeutler0_000 952
|
||||
WRA_JebHensarling0_001 1176
|
||||
WRA_JebHensarling1_000 2327
|
||||
WRA_JebHensarling1_001 727
|
||||
WRA_JebHensarling2_000 2302
|
||||
WRA_JebHensarling2_001 877
|
||||
WRA_JebHensarling2_002 2127
|
||||
WRA_JebHensarling2_003 1877
|
||||
WRA_JeffFlake_000 7202
|
||||
WRA_JeffFlake_001 7502
|
||||
WRA_JeffFlake_002 7377
|
||||
WRA_JimInhofe_000 3127
|
||||
WRA_JimRisch_000 1377
|
||||
WRA_JoeHeck1_000 250
|
||||
WRA_JoePitts_000 1077
|
||||
WRA_JohnBarrasso0_000 5077
|
||||
WRA_JohnBarrasso1_000 3452
|
||||
WRA_JohnBoehner0_000 1626
|
||||
WRA_JohnBoehner1_000 2951
|
||||
WRA_JohnBoozman_000 1877
|
||||
WRA_JohnHoeven_000 2577
|
||||
@@ -0,0 +1,30 @@
|
||||
WRA_JohnKasich0_000 1302
|
||||
WRA_JohnKasich1_000 1301
|
||||
WRA_JohnKasich1_001 1127
|
||||
WRA_JohnKasich3_000 1052
|
||||
WRA_JohnThune_000 1452
|
||||
WRA_JohnnyIsakson_000 5951
|
||||
WRA_JohnnyIsakson_001 5000
|
||||
WRA_JonKyl_000 626
|
||||
WRA_JoniErnst0_000 2077
|
||||
WRA_JoniErnst1_000 1452
|
||||
WRA_JuddGregg_000 1252
|
||||
WRA_JuddGregg_001 952
|
||||
WRA_JuddGregg_002 753
|
||||
WRA_KayBaileyHutchison_000 6325
|
||||
WRA_KellyAyotte_000 8077
|
||||
WRA_KevinBrady2_000 1276
|
||||
WRA_KevinBrady3_000 752
|
||||
WRA_KevinBrady_000 2503
|
||||
WRA_KevinMcCarthy0_000 1002
|
||||
WRA_KevinMcCarthy0_001 1127
|
||||
WRA_KristiNoem0_000 727
|
||||
WRA_KristiNoem1_000 452
|
||||
WRA_KristiNoem2_000 5952
|
||||
WRA_KristiNoem2_001 7277
|
||||
WRA_LamarAlexander0_000 1527
|
||||
WRA_LamarAlexander_000 1552
|
||||
WRA_LisaMurkowski0_000 1752
|
||||
WRA_LynnJenkins_000 877
|
||||
WRA_LynnJenkins_001 1002
|
||||
WRA_MarcoRubio_000 526
|
||||
@@ -0,0 +1,36 @@
|
||||
{
|
||||
"_class_name": "UNet2DConditionModel",
|
||||
"_diffusers_version": "0.6.0.dev0",
|
||||
"act_fn": "silu",
|
||||
"attention_head_dim": 8,
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280,
|
||||
1280
|
||||
],
|
||||
"center_input_sample": false,
|
||||
"cross_attention_dim": 384,
|
||||
"down_block_types": [
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"in_channels": 8,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"out_channels": 4,
|
||||
"sample_size": 64,
|
||||
"up_block_types": [
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D"
|
||||
]
|
||||
}
|
||||
Executable
+670
@@ -0,0 +1,670 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from PIL import Image, ImageDraw
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
|
||||
import sys
|
||||
sys.path.append("./")
|
||||
|
||||
from DataLoader import Dataset
|
||||
from utils.utils import preprocess_img_tensor
|
||||
from torch.utils import data as data_utils
|
||||
from utils.model_utils import validation,PositionalEncoding
|
||||
import time
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--unet_config_file",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="the configuration of unet file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reconstruction",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Flag to add prior preservation loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_root",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A folder containing the training data of instance images.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument("--testing_speed", action="store_true", help="Whether to caculate the running time")
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-6,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders")
|
||||
parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
||||
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
|
||||
" using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=(
|
||||
"Conduct validation every X updates."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_out_dir",
|
||||
type=str,
|
||||
default = '',
|
||||
help=(
|
||||
"Conduct validation every X updates."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
||||
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
||||
" for more docs"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_audio_length_left",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of audio length (left).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_audio_length_right",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of audio length (right)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whisper_model_type",
|
||||
type=str,
|
||||
default="landmark_nearest",
|
||||
choices=["tiny","largeV2"],
|
||||
help="Determine whisper feature type",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
return args
|
||||
|
||||
def print_model_dtypes(model, model_name):
|
||||
for name, param in model.named_parameters():
|
||||
if(param.dtype!=torch.float32):
|
||||
print(f"{name}: {param.dtype}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
args.output_dir = f"output/{args.output_dir}"
|
||||
args.val_out_dir = f"val/{args.val_out_dir}"
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
os.makedirs(args.val_out_dir, exist_ok=True)
|
||||
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
|
||||
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
project_config = ProjectConfiguration(
|
||||
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with="tensorboard",
|
||||
project_config=project_config,
|
||||
)
|
||||
|
||||
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
||||
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
||||
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
if args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
|
||||
raise ValueError(
|
||||
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
||||
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
||||
)
|
||||
|
||||
if args.seed is not None:
|
||||
# set_seed(args.seed)
|
||||
set_seed(seed + accelerator.process_index)
|
||||
|
||||
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
with open(args.unet_config_file, 'r') as f:
|
||||
unet_config = json.load(f)
|
||||
|
||||
#text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
|
||||
# Todo:
|
||||
print("Loading AutoencoderKL")
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
|
||||
vae_fp32 = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
print("Loading UNet2DConditionModel")
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
||||
if args.whisper_model_type == "tiny":
|
||||
pe = PositionalEncoding(d_model=384)
|
||||
elif args.whisper_model_type == "largeV2":
|
||||
pe = PositionalEncoding(d_model=1280)
|
||||
else:
|
||||
print(f"not support whisper_model_type {args.whisper_model_type}")
|
||||
|
||||
print("Loading models done...")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
params_to_optimize = (
|
||||
itertools.chain(unet.parameters()))
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
print("loading train_dataset ...")
|
||||
train_dataset = Dataset(args.data_root,
|
||||
args.train_json,
|
||||
use_audio_length_left=args.use_audio_length_left,
|
||||
use_audio_length_right=args.use_audio_length_right,
|
||||
whisper_model_type=args.whisper_model_type
|
||||
)
|
||||
train_data_loader = data_utils.DataLoader(
|
||||
train_dataset, batch_size=args.train_batch_size, shuffle=True,
|
||||
num_workers=8)
|
||||
print("loading val_dataset ...")
|
||||
val_dataset = Dataset(args.data_root,
|
||||
args.val_json,
|
||||
use_audio_length_left=args.use_audio_length_left,
|
||||
use_audio_length_right=args.use_audio_length_right,
|
||||
whisper_model_type=args.whisper_model_type
|
||||
)
|
||||
val_data_loader = data_utils.DataLoader(
|
||||
val_dataset, batch_size=1, shuffle=False,
|
||||
num_workers=8)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
)
|
||||
|
||||
|
||||
|
||||
unet, optimizer, train_data_loader, val_data_loader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_data_loader, val_data_loader,lr_scheduler
|
||||
)
|
||||
|
||||
vae.requires_grad_(False)
|
||||
vae_fp32.requires_grad_(False)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
# weight_dtype = torch.float16
|
||||
vae_fp32.to(accelerator.device, dtype=weight_dtype)
|
||||
vae_fp32.encoder = None
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.decoder = None
|
||||
pe.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth", config=vars(args))
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
print(f" Num batches each epoch = {len(train_data_loader)}")
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num batches each epoch = {len(train_data_loader)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
# path="../models/pytorch_model.bin"
|
||||
#TODO change path
|
||||
# path=None
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
# caluate the elapsed time
|
||||
elapsed_time = []
|
||||
start = time.time()
|
||||
|
||||
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
dataloader_time = time.time() - start
|
||||
start = time.time()
|
||||
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
|
||||
# """
|
||||
# print("=============epoch:{0}=step:{1}=====".format(epoch,step))
|
||||
# print("ref_image: ",ref_image.shape)
|
||||
# print("masks: ", masks.shape)
|
||||
# print("masked_image: ", masked_image.shape)
|
||||
# print("audio feature: ", audio_feature.shape)
|
||||
# print("image: ", image.shape)
|
||||
# """
|
||||
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
|
||||
image = preprocess_img_tensor(image).to(vae.device)
|
||||
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
|
||||
|
||||
img_process_time = time.time() - start
|
||||
start = time.time()
|
||||
with accelerator.accumulate(unet):
|
||||
vae = vae.half()
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Convert masked images to latent space
|
||||
masked_latents = vae.encode(
|
||||
masked_image.reshape(image.shape).to(dtype=weight_dtype) # masked image
|
||||
).latent_dist.sample()
|
||||
masked_latents = masked_latents * vae.config.scaling_factor
|
||||
|
||||
# Convert ref images to latent space
|
||||
ref_latents = vae.encode(
|
||||
ref_image.reshape(image.shape).to(dtype=weight_dtype) # ref image
|
||||
).latent_dist.sample()
|
||||
ref_latents = ref_latents * vae.config.scaling_factor
|
||||
|
||||
vae_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
mask = torch.stack(
|
||||
[
|
||||
torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
|
||||
for mask in masks
|
||||
]
|
||||
)
|
||||
mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
|
||||
|
||||
|
||||
bsz = latents.shape[0]
|
||||
# fix timestep for each image
|
||||
timesteps = torch.tensor([0], device=latents.device)
|
||||
# concatenate the latents with the mask and the masked latents
|
||||
"""
|
||||
print("=============vae latents=====".format(epoch,step))
|
||||
print("ref_latents: ",ref_latents.shape)
|
||||
print("mask: ", mask.shape)
|
||||
print("masked_latents: ", masked_latents.shape)
|
||||
"""
|
||||
|
||||
if unet_config['in_channels'] == 9:
|
||||
latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
|
||||
else:
|
||||
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
|
||||
audio_feature = audio_feature.to(dtype=weight_dtype)
|
||||
audio_feature = pe(audio_feature)
|
||||
|
||||
# Predict the noise residual
|
||||
image_pred = unet(latent_model_input,
|
||||
timesteps,
|
||||
encoder_hidden_states=audio_feature).sample
|
||||
|
||||
if args.reconstruction: # decode the image from the predicted latents
|
||||
image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
|
||||
image_pred_img = vae_fp32.decode(image_pred_img).sample
|
||||
|
||||
# Mask the top half of the image and calculate the loss only for the lower half of the image.
|
||||
image_pred_img = image_pred_img[:, :, image_pred_img.shape[2]//2:, :]
|
||||
image = image[:, :, image.shape[2]//2:, :]
|
||||
loss_lip = F.l1_loss(image_pred_img.float(), image.float(), reduction="mean") # the loss of the decoded images
|
||||
loss_latents = F.l1_loss(image_pred.float(), latents.float(), reduction="mean") # the loss of the latents
|
||||
|
||||
loss = 2.0*loss_lip + loss_latents # add some weight to balance the loss
|
||||
|
||||
else:
|
||||
loss = F.mse_loss(image_pred.float(), latents.float(), reduction="mean")
|
||||
#
|
||||
|
||||
unet_elapsed_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
itertools.chain(unet.parameters())
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
backward_elapsed_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
if args.testing_speed is True and accelerator.is_main_process:
|
||||
elapsed_time.append(
|
||||
[dataloader_time, unet_elapsed_time, vae_time, backward_elapsed_time,img_process_time]
|
||||
)
|
||||
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
|
||||
if global_step % args.validation_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
logger.info(
|
||||
f"Running validation... epoch={epoch}, global_step={global_step}"
|
||||
)
|
||||
print("===========start validation==========")
|
||||
# Use the helper function to check the data types for each model
|
||||
vae_new = vae.float()
|
||||
print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
|
||||
print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
|
||||
print_model_dtypes(accelerator.unwrap_model(unet), "UNET")
|
||||
|
||||
print(f"weight_dtype: {weight_dtype}")
|
||||
print(f"epoch type: {type(epoch)}")
|
||||
print(f"global_step type: {type(global_step)}")
|
||||
validation(
|
||||
# vae=accelerator.unwrap_model(vae),
|
||||
vae=accelerator.unwrap_model(vae_new),
|
||||
vae_fp32=accelerator.unwrap_model(vae_fp32),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
unet_config=unet_config,
|
||||
# weight_dtype=weight_dtype,
|
||||
weight_dtype=torch.float32,
|
||||
epoch=epoch,
|
||||
global_step=global_step,
|
||||
val_data_loader=val_data_loader,
|
||||
output_dir = args.val_out_dir,
|
||||
whisper_model_type = args.whisper_model_type
|
||||
)
|
||||
logger.info(f"Saved samples to images/val")
|
||||
start = time.time()
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0],
|
||||
"unet": unet_elapsed_time,
|
||||
"backward": backward_elapsed_time,
|
||||
"data": dataloader_time,
|
||||
"img_process":img_process_time,
|
||||
"vae":vae_time
|
||||
}
|
||||
progress_bar.set_postfix(**logs)
|
||||
# accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.log(
|
||||
{
|
||||
"loss/step_loss": logs["loss"],
|
||||
"parameter/lr": logs["lr"],
|
||||
"time/unet_forward_time": unet_elapsed_time,
|
||||
"time/unet_backward_time": backward_elapsed_time,
|
||||
"time/data_time": dataloader_time,
|
||||
"time/img_process_time":img_process_time,
|
||||
"time/vae_time": vae_time
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,29 @@
|
||||
export VAE_MODEL="../models/sd-vae-ft-mse/"
|
||||
export DATASET="../data"
|
||||
export UNET_CONFIG="../models/musetalk/musetalk.json"
|
||||
|
||||
accelerate launch train.py \
|
||||
--mixed_precision="fp16" \
|
||||
--unet_config_file=$UNET_CONFIG \
|
||||
--pretrained_model_name_or_path=$VAE_MODEL \
|
||||
--data_root=$DATASET \
|
||||
--train_batch_size=256 \
|
||||
--gradient_accumulation_steps=16 \
|
||||
--gradient_checkpointing \
|
||||
--max_train_steps=100000 \
|
||||
--learning_rate=5e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="output" \
|
||||
--val_out_dir='val' \
|
||||
--testing_speed \
|
||||
--checkpointing_steps=2000 \
|
||||
--validation_steps=2000 \
|
||||
--reconstruction \
|
||||
--resume_from_checkpoint="latest" \
|
||||
--use_audio_length_left=2 \
|
||||
--use_audio_length_right=2 \
|
||||
--whisper_model_type="tiny" \
|
||||
--train_json="../train.json" \
|
||||
--val_json="../val.json" \
|
||||
--lr_scheduler="cosine" \
|
||||
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
import math
|
||||
from utils.utils import decode_latents, preprocess_img_tensor
|
||||
import os
|
||||
from PIL import Image
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
import logging
|
||||
import json
|
||||
|
||||
RESIZED_IMG = 256
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
"""
|
||||
Transformer 中的位置编码(positional encoding)
|
||||
"""
|
||||
def __init__(self, d_model=384, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
b, seq_len, d_model = x.size()
|
||||
pe = self.pe[:, :seq_len, :]
|
||||
#print(b, seq_len, d_model)
|
||||
x = x + pe.to(x.device)
|
||||
return x
|
||||
|
||||
def validation(vae: torch.nn.Module,
|
||||
vae_fp32: torch.nn.Module,
|
||||
unet:torch.nn.Module,
|
||||
unet_config,
|
||||
weight_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
val_data_loader,
|
||||
output_dir,
|
||||
whisper_model_type,
|
||||
UNet2DConditionModel=UNet2DConditionModel
|
||||
):
|
||||
|
||||
# Get the validation pipeline
|
||||
unet_copy = UNet2DConditionModel(**unet_config)
|
||||
|
||||
unet_copy.load_state_dict(unet.state_dict())
|
||||
unet_copy.to(vae.device).to(dtype=weight_dtype)
|
||||
unet_copy.eval()
|
||||
|
||||
if whisper_model_type == "tiny":
|
||||
pe = PositionalEncoding(d_model=384)
|
||||
elif whisper_model_type == "largeV2":
|
||||
pe = PositionalEncoding(d_model=1280)
|
||||
elif whisper_model_type == "tiny-conv":
|
||||
pe = PositionalEncoding(d_model=384)
|
||||
print(f" whisper_model_type: {whisper_model_type} Validation does not need PE")
|
||||
else:
|
||||
print(f"not support whisper_model_type {whisper_model_type}")
|
||||
pe.to(vae.device, dtype=weight_dtype)
|
||||
|
||||
start = time.time()
|
||||
with torch.no_grad():
|
||||
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(val_data_loader):
|
||||
|
||||
|
||||
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
|
||||
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
|
||||
image = preprocess_img_tensor(image).to(vae.device)
|
||||
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
|
||||
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
|
||||
latents = latents * vae.config.scaling_factor
|
||||
# Convert masked images to latent space
|
||||
masked_latents = vae.encode(
|
||||
masked_image.reshape(image.shape).to(dtype=weight_dtype) # masked image
|
||||
).latent_dist.sample()
|
||||
masked_latents = masked_latents * vae.config.scaling_factor
|
||||
# Convert ref images to latent space
|
||||
ref_latents = vae.encode(
|
||||
ref_image.reshape(image.shape).to(dtype=weight_dtype) # ref image
|
||||
).latent_dist.sample()
|
||||
ref_latents = ref_latents * vae.config.scaling_factor
|
||||
|
||||
mask = torch.stack(
|
||||
[
|
||||
torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
|
||||
for mask in masks
|
||||
]
|
||||
)
|
||||
mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
|
||||
bsz = latents.shape[0]
|
||||
timesteps = torch.tensor([0], device=latents.device)
|
||||
|
||||
if unet_config['in_channels'] == 9:
|
||||
latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
|
||||
else:
|
||||
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
|
||||
image_pred = unet_copy(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample
|
||||
|
||||
image = Image.new('RGB', (RESIZED_IMG*4, RESIZED_IMG))
|
||||
image.paste(decode_latents(vae_fp32,masked_latents), (0, 0))
|
||||
image.paste(decode_latents(vae_fp32, ref_latents), (RESIZED_IMG, 0))
|
||||
image.paste(decode_latents(vae_fp32, latents), (RESIZED_IMG*2, 0))
|
||||
image.paste(decode_latents(vae_fp32, image_pred), (RESIZED_IMG*3, 0))
|
||||
|
||||
val_img_dir = f"images/{output_dir}/{global_step}"
|
||||
if not os.path.exists(val_img_dir):
|
||||
os.makedirs(val_img_dir)
|
||||
image.save('{0}/val_epoch_{1}_{2}_image.png'.format(val_img_dir, global_step,step))
|
||||
|
||||
print("valtion in step:{0}, time:{1}".format(step,time.time()-start))
|
||||
|
||||
print("valtion_done in epoch:{0}, time:{1}".format(epoch,time.time()-start))
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import PIL
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as TF
|
||||
from einops import rearrange
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
import matplotlib.pyplot as plt
|
||||
import PIL
|
||||
import os
|
||||
import cv2
|
||||
from glob import glob
|
||||
|
||||
|
||||
def preprocess_img_tensor(image_tensor):
|
||||
# 假设输入是一个形状为 (N, C, H, W) 的 PyTorch 张量
|
||||
N, C, H, W = image_tensor.shape
|
||||
# 计算新的宽度和高度,使其为 32 的整数倍
|
||||
new_w = W - W % 32
|
||||
new_h = H - H % 32
|
||||
# 使用 torchvision.transforms 库中的方法进行缩放和重采样
|
||||
transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
# 对每个图像应用变换,并将结果存储在一个新的张量中
|
||||
preprocessed_images = torch.empty((N, C, new_h, new_w), dtype=torch.float32)
|
||||
for i in range(N):
|
||||
# 使用 F.interpolate 替换 transforms.Resize
|
||||
resized_image = F.interpolate(image_tensor[i].unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False)
|
||||
preprocessed_images[i] = transform(resized_image.squeeze(0))
|
||||
|
||||
return preprocessed_images
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image_tensor, mask_tensor):
|
||||
# 假设输入 image_tensor 的形状为 [C, H, W],输入 mask_tensor 的形状为 [H, W]
|
||||
# # 对图像张量进行归一化
|
||||
image_tensor_ori = (image_tensor.to(dtype=torch.float32) / 127.5) - 1.0
|
||||
# # 对遮罩张量进行归一化和二值化
|
||||
# mask_tensor = (mask_tensor.to(dtype=torch.float32) / 255.0).unsqueeze(0)
|
||||
mask_tensor[mask_tensor < 0.5] = 0
|
||||
mask_tensor[mask_tensor >= 0.5] = 1
|
||||
# 创建遮罩后的图像
|
||||
masked_image_tensor = image_tensor * (mask_tensor > 0.5)
|
||||
|
||||
return mask_tensor, masked_image_tensor
|
||||
|
||||
|
||||
def encode_latents(vae, image):
|
||||
# init_image = preprocess_image(image)
|
||||
init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
|
||||
init_latents = 0.18215 * init_latent_dist.sample()
|
||||
return init_latents
|
||||
|
||||
def decode_latents(vae, latents, ref_images=None):
|
||||
latents = (1/ 0.18215) * latents
|
||||
image = vae.decode(latents.to(vae.dtype)).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
image = (image * 255).round().astype("uint8")
|
||||
if ref_images is not None:
|
||||
ref_images = ref_images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
ref_images = (ref_images * 255).round().astype("uint8")
|
||||
h = image.shape[1]
|
||||
image[:, :h//2] = ref_images[:, :h//2]
|
||||
image = [Image.fromarray(im) for im in image]
|
||||
|
||||
return image[0]
|
||||
|
||||
Reference in New Issue
Block a user