使用百度paddlepaddle框架 对应论文UGATIT的复现
selfie2anime数据集
args类定义了训练过程和网络相关的一些超参数,主程序则实例化了一个UGATIT类,调用其训练过程
创建了一个UGATIT类
UGATIT.build_model 定义了两个生成器和四个判别器以及生成器和判别器的优化器
UGATIT.train中定义了网络的损失函数,训练方法,并通过不断的迭代训练网络
UGATIT的主网络定义,其中包含了生成器 ResnetGenerator, 判别器 Discriminator
提供了将图片集转化为reader的接口,其中包含了图片的随机裁剪和随机翻转从而进行了数据增强
将paddle之中的mse_loss,bce_loss封装好为更易调用的函数
将读取数据转化为tensor的函数 以及将tensor转化为图片的函数
为了给rho 层编号而实现的全局函数,可以获得一个全局变量
解压函数
评估函数,将验证集A中的图片全部转为B,并将生成的图片存放到fake文件夹,可用GAN_Metrics-Tensorflow进行评估
存放网络参数的文件夹
其中包含了训练了大约30w轮的权重文件
存放训练过程之中输出的关于验证集的效果图片,在这里展示了大约30w轮时候的训练效果
通过在终端执行
python main.py
即可开始网络的训练
其中main.py 中的 args.start 设置训练的开始轮数, args.pretrain 设置是否需要加载预训练模型
执行
python eval.py
可以进行模型genA2B的验证,其中可以设置网络需要加载的权重文件