可以做c语言任务的网站,网页设计与制作网站教程,天津网站建设设计开发公司,建设门户网站多少钱240930_CycleGAN循环生成对抗网络 CycleGAN#xff0c;也算是笔者记录GAN生成对抗网络的第四篇#xff0c;前三篇可以跳转
240925-GAN生成对抗网络-CSDN博客
240929-DCGAN生成漫画头像-CSDN博客
240929-CGAN条件生成对抗网络-CSDN博客
在第三篇中#xff0c;我们采用了p…240930_CycleGAN循环生成对抗网络 CycleGAN也算是笔者记录GAN生成对抗网络的第四篇前三篇可以跳转
240925-GAN生成对抗网络-CSDN博客
240929-DCGAN生成漫画头像-CSDN博客
240929-CGAN条件生成对抗网络-CSDN博客
在第三篇中我们采用了pix2pix进行图像风格的转移但在pix2pix上训练往往需要在像素级上一一对应的数据就造成了很多方面任务无法完成有一定局限性。比如在绘画领域我们无法得到画家当时所画的那个场景的照片同样我们此刻拍的照片也不能请那些画家来给咱们对照着画一幅画。这就造成了数据集无法一一对应无法进行训练的问题。CycleGAN就是为了解决这样的问题上面的图片就是CycleGAN所实现的效果。简单来说就是网络上前段时间爆火的图像风格转移比如把你女朋友的照片传进去后变成一个公主。
传统GAN
在传统GAN中我们有一组生成对抗网络也就是两个网络生成器根据随机噪声生成图像传给判别器进行判断。 CycleGAN
而在CycleGAN中我们有两组生成对抗网络如下图所示。
加入X和Y是两个文件夹X中放了莫奈一个有名的画家所画的所有作品Y中放了你手机相册里的一些风景照。此时我们需要把X域中一张图通过G生成器生成一张符合Y域的图就是用一张油画生成一张照片风格转移Dy努力判别到底是真实的Y还是G生成器生成的假Y。G和Dy构成一组生成对抗网络其结果就是Dy再也判别不出到底是真Y还是假Y。
而第二组生成对抗网络就是把Y域中的一张图通过F生成器生成一张符合X域的图像照片转油画Dx努力判别是真的X还是F生成的假X这就构成了第二组生成对抗网络其结果是Dx再也分辨不出真的X和生成的X。
通过两组生成对抗网络就实现了莫奈风格画作和照片的互相转移也就构成了Cycle循环。 但这样仍然存在于一个问题像我们在CGAN中说的那样在CGAN中我们除了判断其是真图像还是假图像之外还要判断其是否符合我们提供的标签。
在这里我们就要判断其到底是不是和原图所描述的场景一致。即要做到“风格转变内容不变”。比如我们提供的油画是一幅森林的画作通过G生成器生成后确实生成了照片但是生成的照片却变成了城市这不是我们想要的我们想要的是转变为照片的森林。
也有一种可能是不管你输入森林还是城市的油画生成器总是给你生成一份草原的照片这也确实符合照片的风格但是也不是我们想要得到的这是一种模式崩溃现象。
循环一致性损失cycle-consistency loss
为了解决这个问题我们需要加入一个循环一致性损失cycle-consistency loss。具体该如何实现呢。我们就需要构建一个循环一致性损失在森林的油画转成照片之后我们再把这张照片通过F生成器转回油画然后与原图做L1范式逐元素做差取绝对值再求和。用来确定和原图尽可能相似。 以下是该损失的公式 简单作以公式剖析 F ( G ( x ) ) F(G(x)) F(G(x))就是“x通过G生成的图像再传给F生成得到的图像”然后减去x就是逐元素做差然后外面套了两个看着像绝对值的东西内层的两个竖线确实是取绝对值外层的两个竖线就不是了右下角还跟着一个1这就是取L1范式简单说就是上面说的逐元素做差取绝对值再求和。这个损失是越小越好。
Identity Loss(可选)
在CycleGAN中生成图不在意颜色的差别只要能骗过判别器就行生成出来的画作可能颜色就不太对少了点灵魂论文中提到可以加入Identity Loss来解决这个问题。 整体损失
整个CycleGAN的损失就是两个GAN的损失加上这个循环一致性损失 其中单独的GAN损失在之前讲GAN时就已经讲清楚了复习请跳转博客开头那个GAN的链接。
项目实战
接下来我们通过一个实战项目进行讲解具体参考代码在最后引出了代码部分就简单过一下注释都写得比较清楚。
数据集预处理
使用的数据集里面的图片来源于ImageNet该数据集共有17个数据包本文只使用了其中的苹果橘子部分。图像被统一缩放为256×256像素大小其中用于训练的苹果图片996张、橘子图片1020张用于测试的苹果图片266张、橘子图片248张。
这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理为了将重点聚焦到模型此处将数据预处理后的结果转换为 MindRecord 格式的数据以省略大部分数据预处理的代码。
from download import downloadurl https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zipdownload(url, ., kindzip, replaceTrue)此处我们用MindDataset接口读取和处理数据集
from mindspore.dataset import MindDataset# 读取MindRecord格式数据
name_mr ./CycleGAN_apple2orange/apple2orange_train.mindrecord
data MindDataset(dataset_filesname_mr)
print(Datasize: , data.get_dataset_size())batch_size 1
dataset data.batch(batch_size)
datasize dataset.get_dataset_size()可视化
通过 create_dict_iterator 函数将数据转换成字典迭代器然后使用 matplotlib 模块可视化部分训练数据。
这部分都是常用的绘图代码所以注释没有写太多。
import numpy as np
import matplotlib.pyplot as pltmean 0.5 * 255
std 0.5 * 255plt.figure(figsize(12, 5), dpi60)
for i, data in enumerate(dataset.create_dict_iterator()):if i 5:show_images_a data[image_A].asnumpy()show_images_b data[image_B].asnumpy()plt.subplot(2, 5, i1)show_images_a (show_images_a[0] * std mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_a)plt.axis(off)plt.subplot(2, 5, i6)show_images_b (show_images_b[0] * std mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_b)plt.axis(off)else:break
plt.show()构建生成器
本案例生成器的模型结构参考的 ResNet 模型的结构参考原论文对于128×128大小的输入图片采用6个残差块相连图片大小为256×256以上的需要采用9个残差块相连所以本文网络有9个残差块相连超参数 n_layers 参数控制残差块数。
生成器的结构如下所示 import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal# 初始化权重的标准差为0.02的正态分布
weight_init Normal(sigma0.02)class ConvNormReLU(nn.Cell):包含卷积、归一化及ReLU激活的模块。参数:input_channel (int): 输入通道数。out_planes (int): 输出通道数。kernel_size (int, 可选): 卷积核大小默认为4。stride (int, 可选): 步长默认为2。alpha (float, 可选): LeakyReLU的负斜率默认为0.2。norm_mode (str, 可选): 归一化模式可选instance或batch默认为instance。pad_mode (str, 可选): 填充模式可选CONSTANT或其他模式默认为CONSTANT。use_relu (bool, 可选): 是否使用ReLU默认为True。padding (int, 可选): 填充大小默认根据kernel_size计算。transpose (bool, 可选): 是否使用转置卷积默认为False。返回:Tensor: 经过卷积、归一化及ReLU后的输出张量。def __init__(self, input_channel, out_planes, kernel_size4, stride2, alpha0.2, norm_modeinstance,pad_modeCONSTANT, use_reluTrue, paddingNone, transposeFalse):super(ConvNormReLU, self).__init__()# 根据norm_mode选择不同的归一化层norm nn.BatchNorm2d(out_planes, affine(norm_mode ! instance))# 根据是否使用实例归一化来设置是否有偏置项has_bias (norm_mode instance)# 设置填充大小if padding is None:padding (kernel_size - 1) // 2# 根据pad_mode和transpose标志构建卷积层if pad_mode CONSTANT:conv nn.Conv2dTranspose if transpose else nn.Conv2dconv conv(input_channel, out_planes, kernel_size, stride, pad_modesame if transpose else pad,has_biashas_bias, weight_initweight_init)layers [conv, norm]else:paddings ((0, 0), (0, 0), (padding, padding), (padding, padding))pad nn.Pad(paddingspaddings, modepad_mode)conv nn.Conv2dTranspose if transpose else nn.Conv2dconv conv(input_channel, out_planes, kernel_size, stride, pad_modepad,has_biashas_bias, weight_initweight_init)layers [pad, conv, norm]# 添加ReLU层if use_relu:relu nn.ReLU() if alpha 0 else nn.LeakyReLU(alpha)layers.append(relu)self.features nn.SequentialCell(layers)def construct(self, x):构建并返回经过卷积、归一化及ReLU处理后的输出。参数:x (Tensor): 输入张量。返回:Tensor: 处理后的输出张量。output self.features(x)return outputclass ResidualBlock(nn.Cell):残差块包含两个ConvNormReLU模块和一个残差连接。参数:dim (int): 输入和输出的通道数。norm_mode (str, 可选): 归一化模式可选instance或batch默认为instance。dropout (bool, 可选): 是否使用Dropout默认为False。pad_mode (str, 可选): 填充模式可选CONSTANT或其他模式默认为CONSTANT。返回:Tensor: 经过残差连接后的输出张量。def __init__(self, dim, norm_modeinstance, dropoutFalse, pad_modeCONSTANT):super(ResidualBlock, self).__init__()self.conv1 ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)self.conv2 ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_reluFalse)self.dropout nn.Dropout(p0.5) if dropout else Nonedef construct(self, x):构建并返回经过残差块处理后的输出。参数:x (Tensor): 输入张量。返回:Tensor: 处理后的输出张量。out self.conv1(x)if self.dropout:out self.dropout(out)out self.conv2(out)return x outclass ResNetGenerator(nn.Cell):基于ResNet架构的生成器网络。参数:input_channel (int, 可选): 输入通道数默认为3。output_channel (int, 可选): 初始输出通道数默认为64。n_layers (int, 可选): 残差块的数量默认为9。alpha (float, 可选): LeakyReLU的负斜率默认为0.2。norm_mode (str, 可选): 归一化模式可选instance或batch默认为instance。dropout (bool, 可选): 是否使用Dropout默认为False。pad_mode (str, 可选): 填充模式可选CONSTANT或其他模式默认为CONSTANT。返回:Tensor: 经过生成器处理后的输出张量。def __init__(self, input_channel3, output_channel64, n_layers9, alpha0.2, norm_modeinstance, dropoutFalse,pad_modeCONSTANT):super(ResNetGenerator, self).__init__()self.conv_in ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_modepad_mode)self.down_1 ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)self.down_2 ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)layers [ResidualBlock(output_channel * 4, norm_mode, dropoutdropout, pad_modepad_mode) for _ in range(n_layers)]self.residuals nn.SequentialCell(layers)self.up_2 ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transposeTrue)self.up_1 ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transposeTrue)if pad_mode CONSTANT:self.conv_out nn.Conv2d(output_channel, 3, kernel_size7, stride1, pad_modepad,padding3, weight_initweight_init)else:pad nn.Pad(paddings((0, 0), (0, 0), (3, 3), (3, 3)), modepad_mode)conv nn.Conv2d(output_channel, 3, kernel_size7, stride1, pad_modepad, weight_initweight_init)self.conv_out nn.SequentialCell([pad, conv])def construct(self, x):构建并返回经过生成器处理后的输出。参数:x (Tensor): 输入张量。返回:Tensor: 处理后的输出张量。x self.conv_in(x)x self.down_1(x)x self.down_2(x)x self.residuals(x)x self.up_2(x)x self.up_1(x)output self.conv_out(x)return ops.tanh(output)# 实例化生成器
net_rg_a ResNetGenerator()
net_rg_a.update_parameters_name(net_rg_a.)net_rg_b ResNetGenerator()
net_rg_b.update_parameters_name(net_rg_b.)这个结构搭建的还是比较清晰的没有昨天看CGAN痛苦。这段执行完了之后我们可以直接把网络结构打印出来对照查看。
print(net_rg_a)打出来网络结构可能会很多其中ResidualBlock有好几层注意看ResNetGenerator方法 构建判别器
判别器其实是一个二分类网络模型输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2d 、 BatchNorm2d 和 LeakyReLU 层对其进行处理最后通过 Sigmoid 激活函数得到最终概率。
# 定义判别器类用于判断输入的图像是否真实
class Discriminator(nn.Cell):def __init__(self, input_channel3, output_channel64, n_layers3, alpha0.2, norm_modeinstance):初始化判别器。参数:input_channel (int): 输入图像的通道数默认为3。output_channel (int): 第一个卷积层的输出通道数默认为64。n_layers (int): 卷积层的数量默认为3。alpha (float): LeakyReLU激活函数的负斜率默认为0.2。norm_mode (str): 归一化模式默认为instance。判别器由多个卷积层、归一化层和LeakyReLU激活层组成。super(Discriminator, self).__init__()kernel_size 4# 第一层卷积和激活layers [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_modepad, padding1, weight_initweight_init),nn.LeakyReLU(alpha)]nf_mult output_channel# 中间层卷积、归一化和激活for i in range(1, n_layers):nf_mult_prev nf_multnf_mult min(2 ** i, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding1))# 最后一层卷积、归一化和激活注意步长为1nf_mult_prev nf_multnf_mult min(2 ** n_layers, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding1))# 输出层卷积输出通道数为1步长为1layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_modepad, padding1, weight_initweight_init))# 将所有层连接成一个序列模型self.features nn.SequentialCell(layers)def construct(self, x):前向传播函数。参数:x (Tensor): 输入的图像数据。返回:Tensor: 判别器的输出表示输入图像的真实性。output self.features(x)return output# 初始化两个判别器实例分别用于判别A域和B域的图像
net_d_a Discriminator()
net_d_a.update_parameters_name(net_d_a.)net_d_b Discriminator()
net_d_b.update_parameters_name(net_d_b.)优化器和损失函数
这里刚才也进行了讲解要注意的是每个网络的优化器都得单独定义。 # 构建生成器判别器优化器
optimizer_rg_a nn.Adam(net_rg_a.trainable_params(), learning_rate0.0002, beta10.5)
optimizer_rg_b nn.Adam(net_rg_b.trainable_params(), learning_rate0.0002, beta10.5)optimizer_d_a nn.Adam(net_d_a.trainable_params(), learning_rate0.0002, beta10.5)
optimizer_d_b nn.Adam(net_d_b.trainable_params(), learning_rate0.0002, beta10.5)# GAN网络损失函数这里最后一层不使用sigmoid函数
loss_fn nn.MSELoss(reductionmean)
l1_loss nn.L1Loss(mean)def gan_loss(predict, target):target ops.ones_like(predict) * targetloss loss_fn(predict, target)return loss
前向计算
为了减少模型振荡[1]这里遵循 Shrivastava 等人的策略[2]使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数保留了一个图像缓冲区用于存储生成器生成前的50个图像。
import mindspore as ms# 前向计算def generator(img_a, img_b):生成器函数用于生成假图像并对图像进行重建和身份转换测试。参数:img_a: Tensor, 输入图像A。img_b: Tensor, 输入图像B。返回:fake_a: Tensor, 生成的假图像A。fake_b: Tensor, 生成的假图像B。rec_a: Tensor, 重建后的图像A。rec_b: Tensor, 重建后的图像B。identity_a: Tensor, 图像A的身份转换结果。identity_b: Tensor, 图像B的身份转换结果。fake_a net_rg_b(img_b)fake_b net_rg_a(img_a)rec_a net_rg_b(fake_b)rec_b net_rg_a(fake_a)identity_a net_rg_b(img_a)identity_b net_rg_a(img_b)return fake_a, fake_b, rec_a, rec_b, identity_a, identity_blambda_a 10.0
lambda_b 10.0
lambda_idt 0.5def generator_forward(img_a, img_b):生成器的前向传播函数计算生成器的损失。参数:img_a: Tensor, 输入图像A。img_b: Tensor, 输入图像B。返回:fake_a: Tensor, 生成的假图像A。fake_b: Tensor, 生成的假图像B。loss_g: Tensor, 总生成器损失。loss_g_a: Tensor, 生成器A的对抗损失。loss_g_b: Tensor, 生成器B的对抗损失。loss_c_a: Tensor, 生成器A的循环一致性损失。loss_c_b: Tensor, 生成器B的循环一致性损失。loss_idt_a: Tensor, 生成器A的身份损失。loss_idt_b: Tensor, 生成器B的身份损失。true Tensor(True, dtype.bool_)fake_a, fake_b, rec_a, rec_b, identity_a, identity_b generator(img_a, img_b)loss_g_a gan_loss(net_d_b(fake_b), true)loss_g_b gan_loss(net_d_a(fake_a), true)loss_c_a l1_loss(rec_a, img_a) * lambda_aloss_c_b l1_loss(rec_b, img_b) * lambda_bloss_idt_a l1_loss(identity_a, img_a) * lambda_a * lambda_idtloss_idt_b l1_loss(identity_b, img_b) * lambda_b * lambda_idtloss_g loss_g_a loss_g_b loss_c_a loss_c_b loss_idt_a loss_idt_breturn fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_bdef generator_forward_grad(img_a, img_b):生成器前向传播的梯度计算函数。参数:img_a: Tensor, 输入图像A。img_b: Tensor, 输入图像B。返回:loss_g: Tensor, 总生成器损失的梯度。_, _, loss_g, _, _, _, _, _, _ generator_forward(img_a, img_b)return loss_gdef discriminator_forward(img_a, img_b, fake_a, fake_b):判别器的前向传播函数计算判别器的损失。参数:img_a: Tensor, 真实图像A。img_b: Tensor, 真实图像B。fake_a: Tensor, 生成的假图像A。fake_b: Tensor, 生成的假图像B。返回:loss_d: Tensor, 总判别器损失。false Tensor(False, dtype.bool_)true Tensor(True, dtype.bool_)d_fake_a net_d_a(fake_a)d_img_a net_d_a(img_a)d_fake_b net_d_b(fake_b)d_img_b net_d_b(img_b)loss_d_a gan_loss(d_fake_a, false) gan_loss(d_img_a, true)loss_d_b gan_loss(d_fake_b, false) gan_loss(d_img_b, true)loss_d (loss_d_a loss_d_b) * 0.5return loss_ddef discriminator_forward_a(img_a, fake_a):判别器A的前向传播函数计算判别器A的损失。参数:img_a: Tensor, 真实图像A。fake_a: Tensor, 生成的假图像A。返回:loss_d_a: Tensor, 判别器A的损失。false Tensor(False, dtype.bool_)true Tensor(True, dtype.bool_)d_fake_a net_d_a(fake_a)d_img_a net_d_a(img_a)loss_d_a gan_loss(d_fake_a, false) gan_loss(d_img_a, true)return loss_d_adef discriminator_forward_b(img_b, fake_b):判别器B的前向传播函数计算判别器B的损失。参数:img_b: Tensor, 真实图像B。fake_b: Tensor, 生成的假图像B。返回:loss_d_b: Tensor, 判别器B的损失。false Tensor(False, dtype.bool_)true Tensor(True, dtype.bool_)d_fake_b net_d_b(fake_b)d_img_b net_d_b(img_b)loss_d_b gan_loss(d_fake_b, false) gan_loss(d_img_b, true)return loss_d_b# 保留了一个图像缓冲区用来存储之前创建的50个图像
pool_size 50
def image_pool(images):图像缓冲池函数用于保存和随机返回假图像。参数:images: list of Tensor, 新生成的图像列表。返回:output: Tensor, 从缓冲池中选出的图像集合。num_imgs 0image1 []if isinstance(images, Tensor):images images.asnumpy()return_images []for image in images:if num_imgs pool_size:num_imgs num_imgs 1image1.append(image)return_images.append(image)else:if random.uniform(0, 1) 0.5:random_id random.randint(0, pool_size - 1)tmp image1[random_id].copy()image1[random_id] imagereturn_images.append(tmp)else:return_images.append(image)output Tensor(return_images, ms.float32)if output.ndim ! 4:raise ValueError(img should be 4d, but get shape {}.format(output.shape))return output计算梯度和反向传播
from mindspore import value_and_grad# 实例化求梯度的方法
grad_g_a value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())grad_d_a value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())# 计算生成器的梯度反向传播更新参数
def train_step_g(img_a, img_b):# 在生成器训练步骤中冻结判别器的梯度计算net_d_a.set_grad(False)net_d_b.set_grad(False)# 生成器前向计算并获取损失fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib generator_forward(img_a, img_b)# 计算生成器A和B的梯度_, grads_g_a grad_g_a(img_a, img_b)_, grads_g_b grad_g_b(img_a, img_b)# 使用优化器更新生成器A和B的参数optimizer_rg_a(grads_g_a)optimizer_rg_b(grads_g_b)return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib# 计算判别器的梯度反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):# 在判别器训练步骤中开启判别器的梯度计算net_d_a.set_grad(True)net_d_b.set_grad(True)# 计算判别器A和B的损失和梯度loss_d_a, grads_d_a grad_d_a(img_a, fake_a)loss_d_b, grads_d_b grad_d_b(img_b, fake_b)# 计算判别器的平均损失loss_d (loss_d_a loss_d_b) * 0.5# 使用优化器更新判别器A和B的参数optimizer_d_a(grads_d_a)optimizer_d_b(grads_d_b)return loss_d模型训练
训练分为两个主要部分训练判别器和训练生成器在前文的判别器损失函数中论文采用了最小二乘损失代替负对数似然目标。
训练判别器训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 −()[(()−1)2]Ey−pdata(y)[(D(y)−1)2] 训练生成器如 CycleGAN 论文所述我们希望通过最小化 −()[((()−1)2]Ex−pdata(x)[(D(G(x)−1)2] 来训练生成器以产生更好的虚假图像。
%%time
import os
import time
import random
import numpy as np
from PIL import Image
from mindspore import Tensor, save_checkpoint
from mindspore import dtype# 由于时间原因epochs设置为1可根据需求进行调整
epochs 1
save_step_num 80
save_checkpoint_epochs 1
save_ckpt_dir ./train_ckpt_outputs/print(Start training!)# 开始训练过程
for epoch in range(epochs):g_loss []d_loss []start_time_e time.time()# 遍历数据集中的每个样本for step, data in enumerate(dataset.create_dict_iterator()):start_time_s time.time()# 从数据中提取图像A和Bimg_a data[image_A]img_b data[image_B]# 训练生成器并得到生成的图像及损失res_g train_step_g(img_a, img_b)fake_a res_g[0]fake_b res_g[1]# 训练判别器并得到损失res_d train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))loss_d float(res_d.asnumpy())step_time time.time() - start_time_s# 将生成器和判别器的损失分别记录res []for item in res_g[2:]:res.append(float(item.asnumpy()))g_loss.append(res[0])d_loss.append(loss_d)# 每隔一定步数打印训练信息if step % save_step_num 0:print(fEpoch:[{int(epoch 1):3d}/{int(epochs):3d}], fstep:[{int(step):4d}/{int(datasize):4d}], ftime:{step_time:3f}s,\nfloss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, floss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, floss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, floss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f})# 计算并打印每个epoch的平均损失和时间信息epoch_cost time.time() - start_time_eper_step_time epoch_cost / datasizemean_loss_d, mean_loss_g sum(d_loss) / datasize, sum(g_loss) / datasizeprint(fEpoch:[{int(epoch 1):3d}/{int(epochs):3d}], fepoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, fmean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f})# 每隔一定epoch数保存检查点if epoch % save_checkpoint_epochs 0:os.makedirs(save_ckpt_dir, exist_okTrue)save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, fg_a_{epoch}.ckpt))save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, fg_b_{epoch}.ckpt))save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, fd_a_{epoch}.ckpt))save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, fd_b_{epoch}.ckpt))print(End of training!)模型推理
下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移结果中第一行为原图第二行为对应生成的结果图。
import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net
import matplotlib.pyplot as plt
import numpy as np# 加载权重文件
# 参数 net网络模型
# 参数 ckpt_dir权重文件目录
# 无返回值
def load_ckpt(net, ckpt_dir):param_GA load_checkpoint(ckpt_dir)load_param_into_net(net, param_GA)g_a_ckpt ./CycleGAN_apple2orange/ckpt/g_a.ckpt
g_b_ckpt ./CycleGAN_apple2orange/ckpt/g_b.ckptload_ckpt(net_rg_a, g_a_ckpt)
load_ckpt(net_rg_b, g_b_ckpt)# 图片推理
fig plt.figure(figsize(11, 2.5), dpi100)# 推理函数
# 参数 dir_path图片目录路径
# 参数 net网络模型
# 参数 a subplot起始位置偏移量
# 无返回值
def eval_data(dir_path, net, a):# 读取图片生成器def read_img():for dir in os.listdir(dir_path):path os.path.join(dir_path, dir)img Image.open(path).convert(RGB)yield img, dirdataset ds.GeneratorDataset(read_img, column_names[image, image_name])trans [vision.Resize((256, 256)), vision.Normalize(mean[0.5 * 255] * 3, std[0.5 * 255] * 3), vision.HWC2CHW()]dataset dataset.map(operationstrans, input_columns[image])dataset dataset.batch(1)for i, data in enumerate(dataset.create_dict_iterator()):img data[image]fake net(img)fake (fake[0] * 0.5 * 255 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))img (img[0] * 0.5 * 255 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))fig.add_subplot(2, 8, i1a)plt.axis(off)plt.imshow(img.asnumpy())fig.add_subplot(2, 8, i9a)plt.axis(off)plt.imshow(fake.asnumpy())eval_data(./CycleGAN_apple2orange/predict/apple, net_rg_a, 0)
eval_data(./CycleGAN_apple2orange/predict/orange, net_rg_b, 4)
plt.show()原论文1703.10593 (arxiv.org)
参考代码lab - JupyterLab (mindspore.cn)
参考资料
精读CycleGAN论文-拍案叫绝的非配对图像风格迁移_哔哩哔哩_bilibili