前言

最近在看diffusion policy 的论文,感觉自己对于扩散模型的看法有些偏差。
从我现在的角度来看,扩散已经不仅仅是生成式模型的实现框架,扩散可以用于任何需要建模复杂分布输出的任务。
思路打开,扩散模型在我这里的重要性再上一个台阶。
diffusion policy的代码实现很难读,其中有很多为了方便多任务所做的兼容性设计。
所以为了完全弄懂 diffusion policy,我选择往回看,找一篇经典的里程碑式的扩散模型熟悉结构。
这就是DiT。
对于我来说,DiT有什么新颖的信息增益呢?
我总结为下面几条:
1.transformer 架构实现
2.cfg 无分类器的条件引导

所以我下面就会围绕这几点来讲


transformer 架构实现

扩散模型的框架我其实比较熟悉了,简单的介绍一下就是
让模型能够根据噪声等级(t),加噪的图像(xt)预测 epsilont(xt-x0),也就是纯净的图像变成加噪图像所加的噪声。
但是这个噪声的精确度肯定是不行的,于是我们不能直接用xt减去epsilont一步到位求x0,必须通过贝叶斯公式求出xt-1在给定x0,xt下的条件概率公式,应用概率公式以及xt以及 根据xt求出的x0,就能计算出xt-1,从而将噪声去掉一小部分。不断重复,就能逐渐求出x0.

其实我觉得上面的解释不是很好,我更喜欢用图像分布的对数的分数来理解扩散模型,因为这让我可以很轻易的将加噪这个过程看作是分数场的延申(将分数场强行理解为二维的图像,加噪就是将图像在第三维度串联,每一次加噪产生的新的分数场都会更加的均匀)。
这里就不多说了,有时间写一篇理论分析的博客。

说了这么多,讲回正题。
扩散模型预测的是噪声(xt相对于x0)或者说分数,这很明显就是一个稠密预测的任务,所以大家普遍运用的backbone都是unet。
关于稠密预测任务我在计算机视觉算法 DPT(Vision Transformers for Dense Prediction) 论文解读 | abstcol’s blogs中探讨过,transformer是不适合做这一类任务的,这是因为图像的一开始的embedding造成的信息损失太大了,很难上采样恢复。

当时dpt的做法是强行在transformer后加上一个cnn的头,并且利用cnn对transformer各层的输出tokens进行一系列上下采样的操作,让它们变得和unet的残差一样有不同的分辨率,之后就和unet操作一样,将各层不同分辨率的tokens逐个送到cnn的头里,慢慢上采样恢复分辨率。

但是在DiT中却不是这样的。
它利用了隐空间扩散模型这一思想,用现有的vae模型提前将图片压缩成小尺寸的特征图,因此embedding的时候完全可以一个像素当作一个token(我不知道为什么作者没有进行尝试,作者最小是2个像素一个token),那么transformer相对于unet的劣势自然就不存在了。

(注意:这里的vae是冻结权重的,也就是不进行训练,这有好处也有坏处,好处就是收敛肯定加快了,坏处是上限受vae影响,并且很难通过单独强化DiT提升)
(但是利用的vae难道不是unet框架吗?不得不说在稠密预测任务中完全抛弃cnn太难了)

说了这么多,我们还是来看一下图吧。

我们先来看一下整体的DiT架构

先看输入
noised latent 就是不同的噪声等级的特征图,patchify其实就是embedding,不是很懂为什么要造一个新词。(当然别忘了位置编码)

label y和噪声程度t都是被嵌入为token。其中噪声程度t的嵌入是不用学的,直接就是位置嵌入。label y则是需要学习。
两者计算后相加,变成了一个token,参与接下来的transformer 运算。

再看输出
因为embedding并不是一个像素一个token,所以肯定还是要上采样的。
DiT是先将每个token内部的维度通过公用的mlp变成patch乘patch乘inchannel2 ,再重新排列回原来大小。
有些人可能好奇,为什么要乘2?
这是因为这个网络比较奇特,连噪声/分数的方差都想预测。
我是有点好奇的,预测方差的必要性体现在哪里呢?

这一部分其实没什么好说的,要说的话我有一点好奇,为什么上采样要使用mlp呢,mlp如何复原位
置信息?但是话又说回来了,压缩的时候损失的信息,你用cnn也恢复不了,所以倒是无所谓了。

ok,接下来我们讨论DiT的内部attention block,这个部分可以说是比较重要了。
为什么呢?
因为DiT不论是scale的实验,还是融入条件信息的实验,都是通过修改DiT来实现的。
(当然在patchify也有设计embedding下采样率的影响的实验)

我们先来看图吧。

其实看图的话第二幅图和第三幅图都是很直白的了。
第二幅图就是把conditioning也就是条件当作key和value进行交叉注意力,第三幅图就是把conditioning当作普通的token,正常的进行自注意力。

我们来仔细地看下第一幅图,其实也很明了。mlp将t的维度乘六,分别进行线性变换(scale表示元素乘,shift表示元素加)。
这里有一个小细节需要注意一下。
因为layernorm自然的是有缩放、平移参数可以学习的,这其实和条件变换的功能重合了,所以在实现的时候作者将layernorm的可学习参数取消了。

所以这就是三种cfg的结构了,作者对于这三种结构也进行了消融实验。

可以看到的是adaLN(也就是第一幅图)的效果最好,而且这种效果差距是不会随着训练消失的。
而蓝色的线是什么呢,其实就是在adaLN的基础上要求参数进行特殊初始化使得输出初始为0,也就是假设噪声为0,这是前人实验证明有效的。
我们可以从直觉上理解这一点,随机初始化纯纯相当于加噪,大概率是拉大和学习目标的距离的,并且随着维度的上升,这个概率会越大。
反过来想,如果加噪可以拉近和学习目标的距离,那么我们不断加噪就可以去噪了,所以至少加噪是没有好处的。

我觉得是可以用数学证明高斯分布取值与学习目标的距离的期望是大于0和学习目标的距离的,我让gpt证明了一维的情况,有理由相信高维也可以证明。

这里我觉得最值得关注的是初始化的影响竟然不能通过训练抵消!
这代表一个好的初始化所能带来的影响是永久的,至少在有限的训练时间的情况下。

关于这一点其实读者可以看一篇论文:Torch.manual_seed(3407) is all you need: On the influence of random seeds in deep learning architectures for computer vision | Abstract


cfg 无分类器的条件引导

cfg的引导其实前面在将transformer的结构的时候已经讲过了,但是cfg的引导并不只是将类别信息融入进去就可以了,我们完全可以通过后处理,来改变引导的强度,甚至实现多标签引导,反向引导,当然了,代价就是计算量的大大增加。

简单说一下cfg的后处理:对于一次带有类别引导的采样,采样的时候会将样本复制一份,对应的类别标签设为null(这个null的embedding其实是学习到的,也就是说我们会在训练的时候偶尔用null对应的embedding来计算损失函数,因为学习时对应的图片的类别不固定,所以就可以算作无类别embedding了),两个样本同时计算噪声,最终将两者的噪声进行加权求和。
具体来说就是如下式子:

通过改变s,就能改变类别引导的强度,如果将复制那份的标签换做真实的标签,其实就可以实现多标签引导(s在零到一的范围内)或者标签引导同时进行标签反向引导(s大于一)。

作者在实验中发现对于vae产生的特征图的4个通道中的三个进行如上计算就可以了。(但是要对s系数进行相应的调整)具体来说如果对全部4个通道进行调整就需要将s相对于三通道时取四分之三。
实事求是讲,这有什么意义吗?(我在看代码的时候被这里搞蒙了)

作者还发现如果把s调整的太大了,那么稳定性会上升,多样性会下降。这很符合直觉。


实验部分

其实DiT的框架已经基本上讲完了,剩下的部分就是DiT相较于unet多么多么厉害,DiT内部的消融实验。这一部分我们简单讲一下。

首先就是DiT相较于unet的高效性,这张图的右边很明显的显示出来了,相同的计算量,transformer有更好的效果。
但实事求是地讲,对于cnn这种复用参数的模型比较计算量有点耍流氓的感觉。
cnn都把参数压缩成卷积核了,参数量大大减少,表现差一点不是很合理吗。。。

关于DiT内部scaling的实验我们不看上面左边的图,直接看下面这张。

很直观的一张图,有点scaling law 的感觉了,但是可惜最后有一点点收敛了。
另外一点是我很好奇作者为什么不实验1个像素变为1个token的配置呢?

众所周知啊,ddpm的时序一般是不会走完的,也就是说会在采样的时候抽几个噪声等级进行去噪,以减少采样时间,所以很显然我们可以看一看采样时间对表现的影响,也就是test scaling law了。

结果倒是挺标致的,基本上符合test time scaling law吧。
可以得到的结论是,较小的模型增加训练时间来提升效果是一件性价比很低的事。


总结

DiT的结构并不复杂,通过vae的encoder将图片下采样到小尺寸的特征图,从而可以将特征图以较高分辨率进行划分,然后将划分后的信息损失较少的的tokens送入transformer进行处理,处理的过程中利用类别标签进行adanorm引导。最后得到的特征图再送入vae的decoder进行解码,产生样本。

但是我觉得仍然不够简介,在图像领域的transformer似乎永远都甩不开cnn进行独立的工作。