对抗生成网络GAN系列——Spectral Normalization原理详解及源码解析-灵析社区

我不是魔法师

Spectral Normalization原理详解

  由于原始GAN网络存在训练不稳定的现象,究其本质,是因为它的损失函数实际上是JS散度,而JS散度不会随着两个分布的距离改变而改变,这就会导致生成器的梯度会一直不变,从而导致模型训练效果很差。WGAN为了解决原始GAN网络训练不稳定的现象,引入了EM distance代替原有的JS散度,这样的改变会使生成器梯度一直变化,从而使模型得到充分训练。但是WGAN的提出伴随着一个难点,即如何让判别器的参数矩阵满足Lipschitz连续条件。

 如何解决上述所说的难点呢?在WGAN中,我们采用了一种简单粗暴的方式来满足这一条件,即直接对判别器的权重参数进行剪裁,强制将权重限制在[-c,c]范围内。大家可以动动我们的小脑瓜想想这种权重剪裁的方式有什么样的问题——(滴,揭晓答案)如果权重剪裁的参数c很大,那么任何权重可能都需要很长时间才能达到极限,从而使训练判别器达到最优变得更加困难;如果权重剪裁的参数c很小,这又容易导致梯度消失。因此,如何确定权重剪裁参数c是重要的,同时这也是困难的。WGAN提出之后,又提出了WGAN-GP来实现Lipschitz 连续条件,其主要通过添加一个惩罚项来实现。【关于WGAN-GP我没有做相关教程,如果不明白的可以评论区留言】那么本文提出了一种归一化的手段Spectral Normalization来实现Lipschitz连续条件,这种归一化具体是怎么实现的呢,下面听我慢慢道来。

 这样,其实我们的Spectral Normalization原理就讲的差不多了,最后我们要做的就是求得每层参数矩阵的谱范数,然后再进行归一化操作。要想求矩阵的谱范数,首先得求矩阵的奇异值,具体求法我放在附录部分。

  但是按照正常求奇异值的方法会消耗大量的计算资源,因此论文中使用了一种近似求解谱范数的方法,伪代码如下图所示:

  在代码的实战中我们就是按照上图的伪代码求解谱范数的,届时我们会为大家介绍。

注:大家阅读这部分有没有什么难度呢,我觉得可能还是挺难的,你需要一些矩阵分析的知识,我已经尽可能把这个问题描述的简单了,有的文章写的很好,公式推导的也很详尽,我会在参考链接中给出。但是会涉及到最优化的一些理论,估计这就让大家更头疼了,所以大家慢慢消化吧!!!在最后的附录中,我会给出本节内容相关的矩阵分析知识,是我上课时的一些笔记,笔记包含本节的知识点,但针对性可能不是很强,也就是说可能包含一些其它内容,大家可以选择忽略,当然了,你也可以细细的研究研究每个知识点,说不定后面就用到了呢!!!

Spectral Normalization源码解析

源码下载地址:Spectral Normalization

  这个代码使用的是CIFAR10数据集,实现的是一般生成对抗网络的图像生成任务。我不打算再对每一句代码进行详细的解释,有不明白的可以先去看看我专栏中的其它GAN网络的文章,都有源码解析,弄明白后再看这篇你会发现非常简单。那么这篇文章我主要来介绍一下Spectral Normalization部分的内容,其相关内容在spectral_normalization.py文件中,我们理论部分提到Spectral Normalization关键的一步是求解每个参数矩阵的谱范数,相关代码如下:

def _update_u_v(self):
    u = getattr(self.module, self.name + "_u")
    v = getattr(self.module, self.name + "_v")
    w = getattr(self.module, self.name + "_bar")
    height = w.data.shape[0]
    for _ in range(self.power_iterations):
        u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))  
        v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))

    sigma = u.dot(w.view(height, -1).mv(v))
    setattr(self.module, self.name, w / sigma.expand_as(w))
    
    
    
def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)

  对上述代码做一定的解释,6,7,8,9,10行做的就是理论部分伪代码的工作,最后会得到谱范数sigma。11行为使用参数矩阵除以谱范数sigma,以此实现归一化的作用。【torch.mv实现的是矩阵乘法的操作,里面可能还有些函数你没见过,大家百度一下用法就知道了,非常简单】

其实关键的代码就这些,是不是发现特别简单呢🍸🍸🍸每次介绍代码时我都会强调自己动手调试的重要性,很多时候写文章介绍源码都觉得有些力不从心,一些想表达的点总是很难表述,总之,大家要是有什么不明白的就尽情调试叭,或者评论区留言,我天天在线摸鱼滴喔。后期我也打算出一些视频教学了,这样的话就可以带着大家一起调试,我想这样介绍源码彼此都会轻松很多。🛩🛩🛩

小结

  Spectral Normalization确实是有一定难度的,我也有许多地方理解的也不是很清楚,对于这种难啃的问题我是这样认为的。我们可以先对其有一个大致的了解,知道整个过程,知道代码怎么实现,能使用代码跑通一些模型,然后考虑能否将其用在自己可能需要使用的地方,如果加入的效果不好,我们就没必要深究原理了,如果发现效果好,这时候我们再回来慢慢细嚼原理也不迟。最后,希望各位都能获取新知识,能够学有所成叭!!!

附录

  这部分是我学习矩阵分析这门课程时的笔记,截取一些包含此部分的内容,有需求的感兴趣的可以看一看。


阅读量:2017

点赞量:0

收藏量:0