You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

263 lines
15 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import jittor as jt
from jittor import init
import argparse
import os
import numpy as np
import math
from jittor import nn
# 检查是否启用了CUDA
if jt.has_cuda:
jt.flags.use_cuda = 1
# 解析命令行参数
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes for dataset')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=1000, help='interval between image sampling')
opt = parser.parse_args()
print(opt)
img_shape = (opt.channels, opt.img_size, opt.img_size)
# 生成器模型
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# 标签的嵌入向量维度与标签的数量相同10*10的embedding
self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
# 定义生成器模型结构,包含多个线性层和激活函数
# nn.Linear(in_dim, out_dim)表示全连接层
# in_dim输入向量维度
# out_dim输出向量维度
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2))
return layers
# 生成器接收两个输入噪声向量latent vector和类别标签class label。opt.latent_dim 表示噪声向量的维度opt.n_classes 表示类别标签的数量。
# img_shape为(1, 32, 32)则np.prod(img_shape)为32x32=1024将最后的线性层的输出维度设置为图像形状的元素个数以确保生成的向量能够正确地表示生成的图像。
# 逐步增加块的好处是可以逐渐引入更多的非线性变换,从而更好地捕捉数据中的复杂模式
# nn.Sequential是一个用于构建神经网络模型的容器。它按照顺序将一系列的层组合在一起形成一个网络模型。在代码中self.model = nn.Sequential(...)创建了一个名为model的nn.Sequential实例并将多个层按照顺序添加到这个实例中。
self.model = nn.Sequential(*block((opt.latent_dim + opt.n_classes), 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh())
def execute(self, noise, labels):
gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
# img 是一个形状为 (batch_size, 1024) 的向量,其中 batch_size
img = self.model(gen_input)
# 将img从1024维向量变为32*32矩阵
# img.view() 是 PyTorch 中的一个函数用于调整张量的形状。在这里img 是一个张量img.shape 返回该张量的形状。img.shape[0] 表示张量的第一个维度,也就是 batch_size即样本数量。
# *img_shape 是一个展开操作符,它将 img_shape 这个元组展开为多个独立的参数。假设 img_shape = (channels, img_size, img_size),那么 *img_shape 就等同于 (channels, img_size, img_size)。
img = img.view((img.shape[0], *img_shape))
return img
# 判别器模型
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
# 定义判别器模型结构,包含多个线性层和激活函数
# 图像数据展平后的大小加上类别标签的大小
self.model = nn.Sequential(nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
# TODO: 添加最后一个线性层,最终输出为一个实数
nn.Linear(512,1)
)
def execute(self, img, labels):
# batch_size 表示批量大小。(-1) 的作用是自动计算该维度的大小,以保持数据总量不变。图像数据的维度是 (batch_size, channels * height * width)
d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
# TODO: 将d_in输入到模型中并返回计算结果
img = self.model(d_in)
return img
# 损失函数:平方误差
# 调用方法adversarial_loss(网络输出A, 分类标签B)
# 计算结果:(A-B)^2
adversarial_loss = nn.MSELoss()
# 创建生成器和判别器实例
generator = Generator()
discriminator = Discriminator()
# 导入MNIST数据集
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
# 定义了一个图像数据的预处理变换 transform。这个变换包括将图像大小调整为 opt.img_size、转换为灰度图像、并对图像进行归一化处理。这些变换可以帮助数据在输入模型之前进行预处理和标准化以提高训练效果。
transform = transform.Compose([
transform.Resize(opt.img_size),
transform.Gray(),
transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
# 具体来说dataloader是一个用于数据批次处理的工具它从MNIST数据集中加载训练数据并进行一些预处理操作如图像转换。下面是对该代码的解释
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)
# 生成器和判别器的优化器,生成器的优化器 optimizer_G使用了 Adam 优化算法。optimizer_G 的参数是生成器的参数 generator.parameters()lr 参数指定了学习率betas 参数是 Adam 优化算法中的两个指数衰减因子。
# Adam算法中betas 参数中的第一个元素 opt.b1 控制了一阶矩估计的衰减率,第二个元素 opt.b2 控制了二阶矩估计的衰减率。
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
from PIL import Image
# 保存生成的图片
def save_image(img, path, nrow=10, padding=5):
# 行代码从输入的图像数据img中获取图像的维度信息N表示图像数量C表示通道数W表示图像的宽度H表示图像的高度。
N,C,W,H = img.shape
# 这个条件判断语句检查图像数量N是否能被nrow整除如果不能整除则打印一条错误信息并返回。
if (N%nrow!=0):
print("N%nrow!=0")
return
# 这行代码计算每行显示的图像数量ncol将图像数量N除以每行显示的图像数量nrow然后取整数部分。
ncol=int(N/nrow)
# 这行代码创建一个空列表img_all用于存储拼接后的图像块。
img_all = []
# 在内层循环中将图像数据img中的图像按行分组并通过img_.append(img[i*nrow+j])将每个图像添加到临时列表img_中。同时通过img_.append(np.zeros((C,W,padding)))在每个图像之后添加一个零值数组作为间距。
for i in range(ncol):
img_ = []
for j in range(nrow):
img_.append(img[i*nrow+j])
img_.append(np.zeros((C,W,padding)))
img_all.append(np.concatenate(img_, 2))
img_all.append(np.zeros((C,padding,img_all[0].shape[2])))
img = np.concatenate(img_all, 1)
img = np.concatenate([np.zeros((C,padding,img.shape[2])), img], 1)
img = np.concatenate([np.zeros((C,img.shape[1],padding)), img], 2)
# 通过(img-min_)/(max_-min_)*255的计算公式将图像数据归一化到0-255的范围以便在保存图像时使用。
min_=img.min()
max_=img.max()
img=(img-min_)/(max_-min_)*255
# 这行代码对图像数据进行了维度转换,将维度顺序从(通道,宽度,高度)转换为(宽度,高度,通道)。
img=img.transpose((1,2,0))
'''
这是一个条件判断语句根据通道数C的不同进行处理。
如果通道数为3表示图像是RGB格式通过img[:,:,::-1]将图像的通道顺序从BGR转换为RGB。
如果通道数为1表示图像是灰度格式通过img[:,:,0]将图像从三维数组转换为二维数组。
'''
if C==3:
img = img[:,:,::-1]
elif C==1:
img = img[:,:,0]
# 使用Image.fromarray将NumPy数组转换为PIL图像对象然后使用save方法保存图像到指定路径。
Image.fromarray(np.uint8(img)).save(path)
# 生成样本图片
def sample_image(n_row, batches_done):
# 随机采样输入并保存生成的图片
# 生成了一个形状为 (n_row ** 2, opt.latent_dim) 的随机向量 z。它使用 np.random.normal 函数从正态分布中采样随机数,并通过 jt.array 转换为 Jittor 的张量类型。然后将其转换为浮点数类型,并停止梯度传播
# z噪声是用于生成模型如生成对抗网络中的输入噪声。它是一个随机向量通常服从某种特定的概率分布如均匀分布或正态分布。通过将噪声向量输入生成器模型可以生成具有多样性的样本
z = jt.array(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))).float32().stop_grad()
# 创建了一个标签向量 labels用于条件生成。它是一个包含重复数字的数组这样可以按照一定的规律变化来控制生成的图像特征。标签向量的形状为 (n_row ** 2,)。同样地,它也通过 jt.array 转换为 Jittor 的张量类型,并转换为浮点数类型,并停止梯度传播。
# labels 的维度实际上是 (n_row ** 2, 1),即一个二维的向量。它的长度为 n_row 的平方,表示生成图像的数量。每个元素是一个长度为 1 的向量,用于标识生成的图像的特定属性或类别。
labels = jt.array(np.array([num for _ in range(n_row) for num in range(n_row)])).float32().stop_grad()
# 使用生成器模型 generator 来生成图像数据。它将随机向量 z 和标签向量 labels 作为输入,通过调用生成器模型的前向传播方法得到生成的图像数据。生成的图像数据被赋值给变量 gen_imgs。
gen_imgs = generator(z, labels)
save_image(gen_imgs.numpy(), "%d.png" % batches_done, nrow=n_row)
# ----------
# 模型训练
# ----------
# 模型训练
for epoch in range(opt.n_epochs):
for i, (imgs, labels) in enumerate(dataloader):
batch_size = imgs.shape[0]
# 数据标签valid=1表示真实的图片fake=0表示生成的图片
# valid是一个形状为[batch_size, 1]的Jittor张量其中batch_size表示每个批次的图像数量。它的值被设置为1表示这些图像是真实的图片。使用.float32()方法将数据类型设置为浮点型,并使用.stop_grad()方法停止梯度计算。
# fake是一个形状为[batch_size, 1]的Jittor张量也表示数据标签。它的值被设置为0表示这些图像是生成的图片。同样使用.float32()方法将数据类型设置为浮点型,并使用.stop_grad()方法停止梯度计算。
valid = jt.ones([batch_size, 1]).float32().stop_grad()
fake = jt.zeros([batch_size, 1]).float32().stop_grad()
# 真实图片及其类别
real_imgs = jt.array(imgs)
labels = jt.array(labels)
# -----------------
# 训练生成器
# -----------------
# 采样随机噪声和数字类别作为生成器输入
# z 是一个形状为 (batch_size, opt.latent_dim) 的张量,其中 batch_size 表示批次的大小opt.latent_dim 表示噪声向量的维度
z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
# gen_labels 是一个张量,其维度为 (batch_size, ),表示一个一维向量。
gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()
# 生成一组图片
gen_imgs = generator(z, gen_labels)
# 损失函数衡量生成器欺骗判别器的能力即希望判别器将生成图片分类为valid
validity = discriminator(gen_imgs, gen_labels)
g_loss = adversarial_loss(validity, valid)
g_loss.sync()
optimizer_G.step(g_loss)
# ---------------------
# 训练判别器
# ---------------------
validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real,valid)
validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake,fake)
# 总的判别器损失
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.sync()
# 优化判别器参数
optimizer_D.step(d_loss)
# 打印损失值
if i % 50 == 0:
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data)
)
# 保存生成的图片
batches_done = epoch * len(dataloader) + i
# 每完成 opt.sample_interval 个批次时生成一张图片,每隔 1000 个批次opt.sample_interval 的值为 1000生成一张图片用于可视化生成器模型的训练效果。这有助于观察生成图像的质量和多样性随着训练的进行而如何改变。
if batches_done % opt.sample_interval == 0:
sample_image(n_row=10, batches_done=batches_done)
if epoch % 10 == 0:
generator.save("generator_last.pkl")
discriminator.save("discriminator_last.pkl")
# 加载生成器和判别器模型
generator.eval()
discriminator.eval()
generator.load('generator_last.pkl')
discriminator.load('discriminator_last.pkl')
# 生成指定数字序列的图片
number = "14034541402257"
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z,labels)
img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1))
min_=img_array.min()
max_=img_array.max()
img_array=(img_array-min_)/(max_-min_)*255
Image.fromarray(np.uint8(img_array)).save("result.png")