对抗生成网络GAN系列——GAN原理及手写数字生成小案例-灵析社区

我不是魔法师

GAN简介

  这里先来简单的介绍一下GAN,其完整的名称为Generative Adversarial Nets (生成对抗网络) 。其实这个起名还有个小故事,我简要的说一下,大家随便听听,就当放松了。当时作者Goodfellow  对于这篇文章其实是有好几个备选名字的,后来一个中国人说GAN(干)在中国有一种对抗的意思,作者一听,直接拍案选择了这个名称。

​   接下来让我们看看论文中对GAN的解释,如下图所示:


  我简单的来翻译一下,其大致意思是说:在我们提出的对抗生成网络中,有一个生成模型,也有一个对抗模型,它们互相对抗,互相促进。文中也举了个小例子,生成模型可以被认为是一个假币伪造团队,试图生产假币并使用,而判别器类似于警察,试图发现假币。这就是一个互相博弈的过程,生成模型不断的产生伪造水平高的假币,而判别器不断提高警察识别假币水平,直至两者达到一个平衡。这个平衡是指什么呢?即判别器对于生成模型产生的假币辨别的成功率大致为50%,即很难辨别真假。

生成对抗网络

GAN损失函数

​  这部分我们主要结合生成对抗网络的损失函数来介绍网络的整个流程,首先呢,我们需要对一些字母做一些解释。如下:

对上述字母有一定的了解后,下面就可以给出生成对抗网络的损失函数了,如下图所示:

      乍一看这个公式你应该是懵逼的,下面就跟着我的思路来分解分解上述公式。首先这个公式应该有两部分,一部分为给定G,找到使V最大化的D;另一部分为给定D,找到使V最小化的G。

​  我们先来看第一部分,即给定G,找到使V最大化的D。如下图所示:【注:我们为什么想要找到使V最大化的D,是因为使V最大化的D会使判别器的效果最好】

接着我们来看第二部分,即给定D,找到使V最小化的G。如下图所示:【注:我们为什么想要找到使V最小化的G,是因为使V最小化的G会使生成器的效果最好】

GAN流程

​  论文中在给出损失函数后,又给了一个图例来解释GAN的过程,用原文的话来说就是一个不怎么正式,却更具教学意义的解释。(See Figure 1 for a less formal, more pedagogical explanation of the approach  )

接下来论文中给出了训练GAN网络的伪代码,如下图所示:


上面四个图中,注意黄框框住的并不是GAN生成的图片,它们表示与GAN生成图片最相似的原始真实图片。而GAN生成的图片为黄框左侧第一张图片,可以看出,GAN生成的效果还是挺好的。

使用GAN生成手写数字小demo

  上文算是把原理讲述清楚了,若你还不明白,慢慢的阅读每句话,加入自己的思考,或许会有不一样的收获。那么这节我讲来讲讲通过GAN网络生成手写数字的小demo,通过这部分你会了解搭建GAN网络的基本流程。下面就让我们一起来学学吧!!!

首先训练一个模型肯定少不了数据集,我们通过一下代码获取torch自带的MNIST数据集,代码如下:

#MNIST数据集获取
dataset = torchvision.datasets.MNIST("mnist_data", train=True, download=True,
                                     transform=torchvision.transforms.Compose(
                                         [
                                             torchvision.transforms.Resize(28),
                                             torchvision.transforms.ToTensor(),
                                             torchvision.transforms.Normalize([0.5], [0.5]),
                                         ]
                                                                             )
                                     )

 之后我们通过DataLoader方法加载数据集,代码如下:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

​   这样数据就准备好了,下面就来构建我们的模型,分为生成器(Generator)和判别器(Discriminator)。【注:由于这期算是入门GAN,所以模型搭建只采用了全连接层】

生成器模型搭建:

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.GELU(),

            nn.Linear(128, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.GELU(),
            nn.Linear(256, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.GELU(),
            nn.Linear(512, 1024),
            torch.nn.BatchNorm1d(1024),
            torch.nn.GELU(),
            nn.Linear(1024, np.prod(image_size, dtype=np.int32)),
            nn.Sigmoid(),
        )

    def forward(self, z):
        # shape of z: [batchsize, latent_dim]

        output = self.model(z)
        image = output.reshape(z.shape[0], *image_size)

        return image

判别器模型搭建:


class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(np.prod(image_size, dtype=np.int32), 512),
            torch.nn.GELU(),
            nn.Linear(512, 256),
            torch.nn.GELU(),
            nn.Linear(256, 128),
            torch.nn.GELU(),
            nn.Linear(128, 64),
            torch.nn.GELU(),
            nn.Linear(64, 32),
            torch.nn.GELU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, image):
        # shape of image: [batchsize, 1, 28, 28]

        prob = self.model(image.reshape(image.shape[0], -1))

        return prob

​模型搭建好后,我们会对损失函数、优化器等参数进行设置:

g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
loss_fn = nn.BCELoss()

​需要注意,这里采用的是BCELOSS损失函数,这个函数其实就对应着我们GAN理论部分的损失函数

​这些设置好后,我们就来训练我们的GAN网络了,相关代码如下:

num_epoch = 200
for epoch in range(num_epoch):
    for i, mini_batch in enumerate(dataloader):
        gt_images, _ = mini_batch


        z = torch.randn(batch_size, latent_dim)


        pred_images = generator(z)
        g_optimizer.zero_grad()

      
        g_loss = loss_fn(discriminator(pred_images), labels_one)

        g_loss.backward()
        g_optimizer.step()

        d_optimizer.zero_grad()

        real_loss = loss_fn(discriminator(gt_images), labels_one)
        fake_loss = loss_fn(discriminator(pred_images.detach()), labels_zero)
        d_loss = (real_loss + fake_loss)

        # 观察real_loss与fake_loss,同时下降同时达到最小值,并且差不多大,说明D已经稳定了

        d_loss.backward()
        d_optimizer.step()

​   最后,我来展示一下训练结果吧!!!我是在服务器上进行训练的,所以还是比较快的。先来看一下初始的图,都是一些随机的噪声,如下图所示:


再来看训练一段时间的结果,发现效果还是蛮不错滴

阅读量:2039

点赞量:0

收藏量:0