对抗生成网络GAN系列——WGAN原理及实战演练-灵析社区

我不是魔法师

本篇文章将为大家带来WGAN(Wasserstein Generative Adversarial Networks ),旨在解决GAN难train的问题。这篇论文中有大量的理论推导,我不会带着大家一个个推,当然很多我也不会,但是我尽可能的把一些关键部分给大家叙述清楚,让大家从心底认可WGAN,觉得WGAN是合理的,是美妙的。

​  准备好了吗,就让我们一起来学学WGAN吧!🚖🚖🚖

WGAN原理详解

GAN为什么训练困难

​  上文我们谈到GAN普遍存在训练困难的问题,而WGAN是用来解决此问题的。所谓对症下药,我们第一步应该要知道GAN为什么会存在训练难的问题,这样我们才能从本质上对其进行改进。下面就跟随我的步伐一起来看看吧!!!🌼🌼🌼

​  还记得我们在[1]中给出的GAN网络的损失函数吗?如下图所示:

关于KL散度我就介绍这么多,它还有一些性质,像非对称性等等,我这里就不过多介绍了,感兴趣的可以自己去阅读阅读相关资料。

  这时候肯定很多人会想了,为什么我们需要把式3化成式4这个JS散度的形式,这是因为我们的GAN训练困难很大原因就是这个JS散度捣的鬼。🍵🍵🍵

​  为什么说训练困难是JS散度捣的鬼呢?通过上文我们知道,我们将判别器训练的越好,即判别器越接近最优判别器,此时生成器的loss会等价为式4中的JS散度的形式,当我们训练生成器来最小化生成器损失时也就是最小化式4中的JS散度。这这过程看似非常的合理,只要我们不断的训练,真实分布和生成分布越来越接近,JS散度应该越来越小,直到两个分布完全一致,此时JS散度为0。

往重叠部分都是可以忽略的。

 解释了这么多,大家应该知道为什么GAN会训练困难了吧。最后总结一下,这是因为当判别器训练最优时,生成器的损失函数等价于JS散度,其梯度往往一直为0,得不到更新,所以很难train。

种损失函数带来的缺陷我这里也叙述了。

EM distance(推土机距离)的引入

​  我们上文详细分析了普通GAN存在的缺陷,主要是由于和JS散度相关的损失函数导致的。大佬们就在思考能否有一种损失能够代替JS散度呢?于是,WGAN应运而生,其提出了一种新的度量两个分布距离的标准——Wasserstein Metric,也叫推土机距离(Earth-Mover distance)。下面就让我们来看看什么是推土机距离吧!!!🌶🌶🌶

​  下图左侧有6个盒子,我们期望将它们都移动到右侧的虚线框内。比如将盒子1从位置1移动到位置7,移动了6步,我们就将此步骤的代价设为6;同理,将盒子5从位置3移动到位置10,移动了7步,那么此步骤的代价为7,依此类推。

 图1 EM distance示例

​  很显然,我们有很多种不同的方案,下图给出了两种不同的方案:

 在这个例子中,上图展示的两种方案的代价是不同的,一个为2,一个为6。,而推土机距离就是穷举所有的移动方案,最小的移动代价对应的就是推土机距离。对应本列来说,推土机距离等于2。

​  相信通过上文的表述,你已经对推土机距离有了一定了了解。现给出推土机距离是数学定义,如下:


现在我们已经知道了推土机距离是什么,但是我们还没解释清楚我们为什么要用推土机距离,即推土机距离为什么可以代替JS散度成为更优的损失函数?我们来看这样的一个例子,如下图所示:

WGAN的实现

很显然,lipschitz连续就限制了f的斜率的绝对值小于等于K,这个K称为Libschitz常数。我们来举个例子帮助大家理解,如下图所示:

上图中log(x)的斜率无界,故log(x)不满足lipschitz连续条件;而sin(x)斜率的绝对值都小于1,故sin(x)满足lipschitz连续条件。

​  这样,我们只需要找到一个lipschitz函数,就可以计算推土机距离了。至于怎么找这个lipschitz函数呢,就是我们搞深度学习的那一套啦,只需要建立一个深度学习网络来进行学习就好啦。实际上,我们新建立的判别器网络和之前的的基本是一致的,只是最后没有使用sigmoid函数,而是直接输出一个分数,这个分数可以反应输入图像的真实程度。

​  呼呼呼~~,WGAN的原理就为大家介绍到这里了,大家掌握了多少呢?其实我认为只看一篇文章是很难把WGAN的所有细节都理解的,大家可以看看本文的参考文献,结合多篇文章看看能不能帮助大家解决一些困惑。🧅🧅🧅

WGAN代码实战

​  WGAN的代码实战我不打算贴出一堆代码了,只说明一下WGAN相较于普通GAN做了哪些改变。首先我们给出论文中训练WGAN的流程图,如下:

其实WGAN相较于原始GAN只做了4点改变,分别如下:

  1. 判别器最后不使用sigmoid函数
  2. 生成器和判别器的loss不取log
  3. 每次更新判别器参数后将判别器的权重截断
  4. 不适应基于动量的优化算法,推荐使用RMSProp

现在代码中分别对上述的4点做相关解释:【很简单,所以大家想要将原始GAN修改为WGAN就按照下列的几点来修改就好了】

1.判别器最后不使用sigmoid函数

这个我们一般只需要删除判别器网络中的最后一个sigmoid层就可以了,非常简单。但是我还想提醒大家一下,有时候你在看别人的原始GAN时,他的判别器网络中并没有sigmoid函数,而是在定义损失函数时使用了BCEWithLogitsLoss函数,这个函数会先对数据做sigmoid,相关代码如下:

# 定义损失函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')

如果这时你想删除sigmoid函数,只需将BCEWithLogitsLoss函数修改成BCELoss即可。关于BCEWithLogitsLossBCELoss的区别我这篇文章有相关讲解,感兴趣的可以看看。🌽🌽🌽

2.生成器和判别器的loss不取log

我们先来看原始GAN判别器的loss是怎么定义的,如下:

d_loss_real = criterion(d_out_real.view(-1), label_real)
d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
d_loss = d_loss_real + d_loss_fake

原始GAN使用了criterion函数,这就是我们上文定义的BCEWithLogitsLossBCELoss,其内部是一个log函数。

d_loss = -(torch.mean(d_out_real.view(-1))-torch.mean(d_out_fake.view(-1)))

看完判别器的损失,我们再来看生成器的损失:

g_loss = criterion(d_out_fake.view(-1), label_real)  #原始GAN损失
---------------------------------------------------
g_loss = -torch.mean(d_out_fake.view(-1))            #WGAN损失

3.每次更新判别器参数后将判别器的权重截断

首先来说说没什么要进行权重截断,这是因为lipschitz连续条件不好确定,作者为了方便直接简单粗暴的限制了权重参数在[-c,c]这个范围了,这样就一定会存在一个常数K使得函数f满足lipschitz连续条件,具体的实现代码如下:

# clip D weights between -0.01, 0.01  权重剪裁
  for p in D.parameters():
      p.data.clamp_(-0.01, 0.01)


  • 注意这步是在判别器每次反向传播更新梯度结束后进行的。还需要提醒大家一点,在训练WGAN时往往会多训练几次判别器然后再训练一次生成器,论文中是训练5次判别器后训练一次生成器,关于这一点在上文WGAN的流程图中也有所体现。为什么要这样做呢?我思考这样做可能是想把判别器训练的更好后再来训练生成器,因为在前面的理论部分我们的推导都是建立在最优判别器的前提下的。🍵🍵🍵
  • 4.不适应基于动量的优化算法,推荐使用RMSProp
  • 这部分就属于玄学部分了,作者做实验发现像Adam这类基于动量的优化算法效果不好,然后使用了RMSProp优化算法。我决定这部分大家也不要纠结,直接用就好。同样我们来看看代码是怎么写的,如下:
  •  这样,WGAN的代码实战我就为大家介绍到这里,不管你前文的WGAN原理听明白了否,但是WGAN代码相信你是一定会修改的,改动的非常之少,大家快去试试吧。🍄🍄🍄

    小结

    ​  这部分的理论确实是有一定难度的,我也看了非常非常多的视频和博客,写了很多笔记。我觉得大家不一定要弄懂每一个细节,只要对里面的一些关键公式,关键思想有清晰的把握即可;而这部分的实验较原始GAN需要修改的仅有四点,非常简单,大家都可以试试。最后希望大家都能够有所收获,就像WGAN一样稳定的进步,一起加油吧!!!🥂🥂🥂


    阅读量:1524

    点赞量:0

    收藏量:0