1ab53a626b
* docs: update readme * docs: update readme * feat: training codes * feat: data preprocess * docs: release training
82 lines
3.3 KiB
Python
Executable File
82 lines
3.3 KiB
Python
Executable File
import torch
|
|
import torch.nn as nn
|
|
from omegaconf import OmegaConf
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn, optim
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
from musetalk.loss.discriminator import MultiScaleDiscriminator,DiscriminatorFullModel
|
|
import musetalk.loss.vgg_face as vgg_face
|
|
|
|
class Interpolate(nn.Module):
|
|
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
|
|
super(Interpolate, self).__init__()
|
|
self.size = size
|
|
self.scale_factor = scale_factor
|
|
self.mode = mode
|
|
self.align_corners = align_corners
|
|
|
|
def forward(self, input):
|
|
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
|
|
|
|
def set_requires_grad(net, requires_grad=False):
|
|
if net is not None:
|
|
for param in net.parameters():
|
|
param.requires_grad = requires_grad
|
|
|
|
if __name__ == "__main__":
|
|
cfg = OmegaConf.load("config/audio_adapter/E7.yaml")
|
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
pyramid_scale = [1, 0.5, 0.25, 0.125]
|
|
vgg_IN = vgg_face.Vgg19().to(device)
|
|
pyramid = vgg_face.ImagePyramide(cfg.loss_params.pyramid_scale, 3).to(device)
|
|
vgg_IN.eval()
|
|
downsampler = Interpolate(size=(224, 224), mode='bilinear', align_corners=False)
|
|
|
|
image = torch.rand(8, 3, 256, 256).to(device)
|
|
image_pred = torch.rand(8, 3, 256, 256).to(device)
|
|
pyramide_real = pyramid(downsampler(image))
|
|
pyramide_generated = pyramid(downsampler(image_pred))
|
|
|
|
|
|
loss_IN = 0
|
|
for scale in cfg.loss_params.pyramid_scale:
|
|
x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
|
|
y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
|
|
for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
|
|
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
|
loss_IN += weight * value
|
|
loss_IN /= sum(cfg.loss_params.vgg_layer_weight) # 对vgg不同层取均值,金字塔loss是每层叠
|
|
print(loss_IN)
|
|
|
|
#print(cfg.model_params.discriminator_params)
|
|
|
|
discriminator = MultiScaleDiscriminator(**cfg.model_params.discriminator_params).to(device)
|
|
discriminator_full = DiscriminatorFullModel(discriminator)
|
|
disc_scales = cfg.model_params.discriminator_params.scales
|
|
# Prepare optimizer and loss function
|
|
optimizer_D = optim.AdamW(discriminator.parameters(),
|
|
lr=cfg.discriminator_train_params.lr,
|
|
weight_decay=cfg.discriminator_train_params.weight_decay,
|
|
betas=cfg.discriminator_train_params.betas,
|
|
eps=cfg.discriminator_train_params.eps)
|
|
scheduler_D = CosineAnnealingLR(optimizer_D,
|
|
T_max=cfg.discriminator_train_params.epochs,
|
|
eta_min=1e-6)
|
|
|
|
discriminator.train()
|
|
|
|
set_requires_grad(discriminator, False)
|
|
|
|
loss_G = 0.
|
|
discriminator_maps_generated = discriminator(pyramide_generated)
|
|
discriminator_maps_real = discriminator(pyramide_real)
|
|
|
|
for scale in disc_scales:
|
|
key = 'prediction_map_%s' % scale
|
|
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
|
|
loss_G += value
|
|
|
|
print(loss_G)
|