本篇文章将为大家带来WGAN(Wasserstein Generative Adversarial Networks ),旨在解决GAN难train的问题。这篇论文中有大量的理论推导,我不会带着大家一个个推,当然很多我也不会,但是我尽可能的把一些关键部分给大家叙述清楚,让大家从心底认可WGAN,觉得WGAN是合理的,是美妙的。
准备好了吗,就让我们一起来学学WGAN吧!🚖🚖🚖
上文我们谈到GAN普遍存在训练困难的问题,而WGAN是用来解决此问题的。所谓对症下药,我们第一步应该要知道GAN为什么会存在训练难的问题,这样我们才能从本质上对其进行改进。下面就跟随我的步伐一起来看看吧!!!🌼🌼🌼
还记得我们在[1]中给出的GAN网络的损失函数吗?如下图所示:





关于KL散度我就介绍这么多,它还有一些性质,像非对称性等等,我这里就不过多介绍了,感兴趣的可以自己去阅读阅读相关资料。


这时候肯定很多人会想了,为什么我们需要把式3化成式4这个JS散度的形式,这是因为我们的GAN训练困难很大原因就是这个JS散度捣的鬼。🍵🍵🍵
为什么说训练困难是JS散度捣的鬼呢?通过上文我们知道,我们将判别器训练的越好,即判别器越接近最优判别器,此时生成器的loss会等价为式4中的JS散度的形式,当我们训练生成器来最小化生成器损失时也就是最小化式4中的JS散度。这这过程看似非常的合理,只要我们不断的训练,真实分布和生成分布越来越接近,JS散度应该越来越小,直到两个分布完全一致,此时JS散度为0。






往重叠部分都是可以忽略的。



解释了这么多,大家应该知道为什么GAN会训练困难了吧。最后总结一下,这是因为当判别器训练最优时,生成器的损失函数等价于JS散度,其梯度往往一直为0,得不到更新,所以很难train。

种损失函数带来的缺陷我这里也叙述了。
我们上文详细分析了普通GAN存在的缺陷,主要是由于和JS散度相关的损失函数导致的。大佬们就在思考能否有一种损失能够代替JS散度呢?于是,WGAN应运而生,其提出了一种新的度量两个分布距离的标准——Wasserstein Metric,也叫推土机距离(Earth-Mover distance)。下面就让我们来看看什么是推土机距离吧!!!🌶🌶🌶
下图左侧有6个盒子,我们期望将它们都移动到右侧的虚线框内。比如将盒子1从位置1移动到位置7,移动了6步,我们就将此步骤的代价设为6;同理,将盒子5从位置3移动到位置10,移动了7步,那么此步骤的代价为7,依此类推。
图1 EM distance示例
很显然,我们有很多种不同的方案,下图给出了两种不同的方案:



在这个例子中,上图展示的两种方案的代价是不同的,一个为2,一个为6。,而推土机距离就是穷举所有的移动方案,最小的移动代价对应的就是推土机距离。对应本列来说,推土机距离等于2。
相信通过上文的表述,你已经对推土机距离有了一定了了解。现给出推土机距离是数学定义,如下:



现在我们已经知道了推土机距离是什么,但是我们还没解释清楚我们为什么要用推土机距离,即推土机距离为什么可以代替JS散度成为更优的损失函数?我们来看这样的一个例子,如下图所示:



很显然,lipschitz连续就限制了f的斜率的绝对值小于等于K,这个K称为Libschitz常数。我们来举个例子帮助大家理解,如下图所示:

上图中log(x)的斜率无界,故log(x)不满足lipschitz连续条件;而sin(x)斜率的绝对值都小于1,故sin(x)满足lipschitz连续条件。
这样,我们只需要找到一个lipschitz函数,就可以计算推土机距离了。至于怎么找这个lipschitz函数呢,就是我们搞深度学习的那一套啦,只需要建立一个深度学习网络来进行学习就好啦。实际上,我们新建立的判别器网络和之前的的基本是一致的,只是最后没有使用sigmoid函数,而是直接输出一个分数,这个分数可以反应输入图像的真实程度。
呼呼呼~~,WGAN的原理就为大家介绍到这里了,大家掌握了多少呢?其实我认为只看一篇文章是很难把WGAN的所有细节都理解的,大家可以看看本文的参考文献,结合多篇文章看看能不能帮助大家解决一些困惑。🧅🧅🧅
WGAN的代码实战我不打算贴出一堆代码了,只说明一下WGAN相较于普通GAN做了哪些改变。首先我们给出论文中训练WGAN的流程图,如下:

其实WGAN相较于原始GAN只做了4点改变,分别如下:
现在代码中分别对上述的4点做相关解释:【很简单,所以大家想要将原始GAN修改为WGAN就按照下列的几点来修改就好了】
1.判别器最后不使用sigmoid函数
这个我们一般只需要删除判别器网络中的最后一个sigmoid层就可以了,非常简单。但是我还想提醒大家一下,有时候你在看别人的原始GAN时,他的判别器网络中并没有sigmoid函数,而是在定义损失函数时使用了BCEWithLogitsLoss函数,这个函数会先对数据做sigmoid,相关代码如下:
# 定义损失函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')如果这时你想删除sigmoid函数,只需将BCEWithLogitsLoss函数修改成BCELoss即可。关于BCEWithLogitsLoss和BCELoss的区别我这篇文章有相关讲解,感兴趣的可以看看。🌽🌽🌽
2.生成器和判别器的loss不取log
我们先来看原始GAN判别器的loss是怎么定义的,如下:
d_loss_real = criterion(d_out_real.view(-1), label_real)
d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
d_loss = d_loss_real + d_loss_fake原始GAN使用了criterion函数,这就是我们上文定义的BCEWithLogitsLoss或BCELoss,其内部是一个log函数。

d_loss = -(torch.mean(d_out_real.view(-1))-torch.mean(d_out_fake.view(-1)))
看完判别器的损失,我们再来看生成器的损失:
g_loss = criterion(d_out_fake.view(-1), label_real) #原始GAN损失
---------------------------------------------------
g_loss = -torch.mean(d_out_fake.view(-1)) #WGAN损失3.每次更新判别器参数后将判别器的权重截断
首先来说说没什么要进行权重截断,这是因为lipschitz连续条件不好确定,作者为了方便直接简单粗暴的限制了权重参数在[-c,c]这个范围了,这样就一定会存在一个常数K使得函数f满足lipschitz连续条件,具体的实现代码如下:
# clip D weights between -0.01, 0.01 权重剪裁
for p in D.parameters():
p.data.clamp_(-0.01, 0.01)
这样,WGAN的代码实战我就为大家介绍到这里,不管你前文的WGAN原理听明白了否,但是WGAN代码相信你是一定会修改的,改动的非常之少,大家快去试试吧。🍄🍄🍄
这部分的理论确实是有一定难度的,我也看了非常非常多的视频和博客,写了很多笔记。我觉得大家不一定要弄懂每一个细节,只要对里面的一些关键公式,关键思想有清晰的把握即可;而这部分的实验较原始GAN需要修改的仅有四点,非常简单,大家都可以试试。最后希望大家都能够有所收获,就像WGAN一样稳定的进步,一起加油吧!!!🥂🥂🥂
阅读量:1524
点赞量:0
收藏量:0