计算机视觉实践 unet lits 医学图像分割实践
前言
前段时间看了很多计算机视觉的里程碑式的工作,感觉知识储备还是增加了很多的。
但是因为大量的理论输入,感觉自己本就不多的实践技能由要疏忽了,所以趁着这次的课程作业的机会增加一下自己的实操经验。
因为目的是增加实操经验,所以在编写代码的时候我给自己设定了几个原则:
1.合理的调用深度:我曾经很喜欢cleanrl类型的一个py文件完成所有工作,并且尽量不定义函数,但是这是很不现实的,所以我给自己设定的原则是调用深度至多为一,也就是说函数之间不能嵌套(只考虑自定义函数)。并且项目不设计二级文件夹存放代码,所有代码文件都在根目录。
2.尽量少的库函数调用:在代码实现过程中尽量只是用pytorch和几个基础库,避免学习成本以及可能的依赖问题。当然主要原因是为了手搓增加实战经验。
因为是实践记录并且实现内容比较初级,所以我就有选择性的进行叙述。
实践仓库:abstcol/medical_image_segmentation: the medical image segmentation practice
nii文件格式
本次处理的数据格式是nii。
nii文件的结构是灰度图在扫描的维度上进行堆叠而形成的三维张量,这个特殊的文件结构为处理带来了特别的挑战。
我一开始尝试将所有的nii文件一次性读入内存之后通过自定义的dataset进行取item的操作,后来发现内存根本存不下这么多的图片文件。
这种操作其实来源于我之前写DRL代码的习惯,因为DRL的数据集一般都比较小,并且是收集后立刻就进行训练的,所以不会考虑内存这一层。
之后我尝试将nii文件的索引保存在dataset中,之后根据索引需要时再读取。具体的操作就是在数据集初始化时统计每个nii文件灰度图的数量,累加得到总的数据量,同时为每个数据区间构造字典映射到对应文件以及对应张数。
我后来放弃了这个方案,因为这样做就无法利用多进程读取数据了,会出现两个样本在同一个nii文件中所以两个进程同时读取一个文件的情况,会报错。我尝试了通过进程锁之类的操作规避这个问题,但那超出我的能力范围了,于是放弃。
(事后反思,其实在这个任务中的batchsize很小,完全不需要多进程读取数据)
最后我的选择是将nii文件预处理为一张张的png灰度图,直接当作普通的图片数据进行训练。
另外值得一提的是我找到的lits数据集的测试集是没有标注的,无法使用。
所以我并没有进行测试。
网络设计
直至这篇文章写作的时候,作者刚刚完成baseline的网络设计,也就是朴素的unet架构。
unet框架十分的简单,唯一让我头疼的就是子网络的存储或者说明明。
我把整个unet分成了五部分:
下采样的通道扩张,下采样的尺寸压缩,bottleneck的通道变化,上采样的通道压缩,上采样的尺寸扩张。
其中的每一部分都是模型块序列。
其中最麻烦的就是如何设计初始化的参数接口,我的方案是给出每一层的channel数即可。
另外值得一提的就是网络设计当中的正则化。
我一开始只使用了batch normalization 作为正则项。当然,其实我的batchsize很小,只有6,所以使用bn还是ln效果好值得商榷,但是我没有做仔细的实验,因为我不觉得在这里做实验得到的经验能够泛化到别的领域。
在这之后我尝试了增加dropout层作为更进一步的正则化,效果很不错。
虽然并没有让模型的表现更好,但dropout明显的降低了过拟合程度,这是很优良的性质,这意味着我们可以通过增加训练轮次得到更好的结果。
实验图如下:
(粉色是增加dropout的unet,紫色则是没加的)
数据预处理/数据增强
这一部分也是挺麻烦的。
因为我并没有使用nii的原始数据进行训练,而是先将nii切分为灰度图,再用灰度图进行训练,所以我的数据在一开始就已经是被归一化到0-255的(不包括分割图)
在正式训练的时候我只对图像进行了有限度的处理:
归一化到0-1,一定概率左右翻转,一定概率上下翻转,以及一定概率随机缩放。
归一化到0-1是预处理的范畴,这是为了网络训练的稳定性,它的作用我认为是很大的。
至于说后面三个数据增强方法的增加模型泛化性的功效我则是存疑,因为在我的认知里医学图像的分布是很固定的,基本都是人躺在固定的地方,然后扫描,所以增强后的分布我认为可能永远都不会再测试集碰到。
但是我还是认为这三个数据增强方法在本次实验中起到了避免过拟合,也就是正则化的作用。
nii文件中的数据间的correlation是很大的,往往相邻图片的变化只有一点点,很容易造成模型的过拟合。而数据增强则可以削减这种correlation。
还值得一提的就是torch vision并不支持稠密预测任务的数据增强,所以我自定义了一系列的预处理模块。这也是为了减少所用的额外的外部库。
损失函数
损失函数的选择也很有说法。
在sam的论文中我其实就了解了diceloss,但是一直对于它解决正负样本不平衡的功能没有什么概念,现在自己实操后才发现没它不行。
lits数据集的掩码区域是病变的部分,占总的图片的比例很小,所以如果仅仅使用二值交叉熵函数的话模型很容易忽略正样本(掩码区域)的贡献,一刀切大家都预测为零。
这其实可以通过为正样本增加权重来解决,pytorch也提供了这样的接口。但是一个很严重的问题就是权重增加多少才合适呢?也许一个医学领域的专家可以解决这个问题,但是我肯定是不知道如何设置这个权重超参的。
dice loss 则很好的解决了这个问题,既然普通的二值交叉熵中的正样本会被忽略,那就在加一个只考虑正样本的损失函数得了。这样一来正负样本的重要性就可以简单的通过二值交叉熵损失和diceloss 的权重来体现。
我觉得这是一种朴素方法论的体现:尽量避免设置受先验影响过大的参数。
评估指标
损失函数一般不作为评估指标(虽然dice loss 本身就是很好的评估指标),于是我这里将正样本面积的iou作为评估指标。
因为大量的样本都是没有掩码的,所以我这里的设置是在对iou进行批次的平均时,只考虑有掩码的样本对,如果一个批次中的所有样本都没有掩码,则不进行计算。
这里我为了避免指标估计影响训练,并没有个批次都进行指标估计,而是间隔十个样本估计一次指标。事后反思这是完全没有必要的,相较于训练的耗时,计算iou的耗时极短。
这里还有一个小小的抉择,关于评估集的iou计算是每一个epoch进行一次,还是单个epoch内固定间隔就进行一次?
我之所以认为这是抉择是因为数据集的样本数太多了,训练一个epoch极为耗时,并且任务简单,4个epoch就可以收敛。(这其中也有有很多样本过于相似的原因)
我的决定是每一个epoch进行一次评估,这虽然会导致模型能力估计在训练维度上的粗粒度,但能极大的减少评估所耗的时间。
稀疏化处理尝试
我在第一次实验中并没有使用初始数据集中的全部数据,而是对于nii文件中没10张灰度图抽一张作为数据集。
我的意图是减少数据集样本之间的correlation。我认为哪怕数据集缩小到1/10,通过增加训练轮次可以达到很好的效果。
之后实验结果很差,所以我放弃了这样做。
现在仔细思考我发现这是一个很不明智的做法:增加训练轮次相当于复制数据集,这还不如相似的样本呢。
ViT 架构尝试
用ViT作为backbone,decoder部分则是参照dpt论文([2103.13413] Vision Transformers for Dense Prediction)构造的。
但是我并没有完全参照dpt的架构。
在dpt原论文中的clstoken我直接舍弃了,毕竟原论文消融实验证明这东西根本不重要。
ViT的实现我是参照的FrancescoSaverioZuppichini/ViT: Implementing Vi(sion)T(transformer),这个仓库的实现比较简洁,但是还是有一点问题,那就是类定义嵌套的层数太多了,所以我做出了相应的修改。
至于dpt head 的实现则是根据gpt进行修改的,我实在不想在原作者那错综复杂的结构里抽取出我想要的结构,我觉得这太难了。
实验的结果比较的差,我推测这是因为我的位置编码是可学习参数的原因,所以尝试了用rope代替位置编码,效果果然变好了很多。
评估集表现如下:
实际上ViT的参数量比unet多出了50%(150m vs 100m),但是训练速度远小于unet,我想这是cnn参数复用所带来的必然结果。
其实ViT可以调整的地方还有很多,但是因为训练比较耗时,我就不进行了。
总结
这是一个很朴素的医学图像分割实现,主要是为了增加一点代码经验。
这次实现的问题主要是参数调整的接口不完善,特别是网络的超参得亲自到网络的文件里进行调整,没有集成到args文件。但如果集成到args里,args文件就会变的很复杂。
所以最好的方法是在每个net文件开头定义一个网络的args类,因为时间原因,这里就不实现了。
后补
完成之后和组员讨论发现我的数据处理就是错的。(难绷
lits数据集是一个多类别的语义分割,总共有三个类别:
0.后景
1.正常目标器官
2.癌变目标器官
我之前的处理是把他当作二分类来做的,并且把3当成后景处理。(所以验证集的分割图会有大块缺失,我当时还没在意。唉)
所以我这次把相关损失函数、网络进行了修改,主要就是把单通道的二值预测变成多通道,难点不多,就是操作的时候会有些小细节需要注意。
与此同时,我还做了一些小改动。
新的比较标准
之前的实验都是将unet和dpt训练相同epoch,这样做的依据是两者的参数量类似。
但是这样不太公平,因为cnn可以复用参数,所以unet的推理时间是dpt的两倍。
所以这次我将dpt的epoch设置为unet的两倍,
结果如下:
(以epoch为横轴)
(以时间为横坐标)
可以看的出来,这个dpt的表现还不错嘛,我甚至怀疑再给他几个epoch还能涨点。
新的损失函数
原来的损失函数中的diceloss 是在单个样本内部各个类别的损失求平均,之后再再各个样本间求平均,这很容易让本来就不容易出现的类别2:癌变目标器官的损失被进一步的稀释,所以实验中的2的iou一直不高,远远小于1.
为了解决被稀释的问题我将diceloss的平均顺序变换了一下,先对样本间求平均(注意,如果某个样本的某类标签不存在,就跳过他,防止这类标签被稀释),再在类别间求平均,这样可以防止稀缺掩码的损失被常见掩码稀释。
值得注意的是,我在这个版本中的diceloss中直接忽略了类别1的损失,后景的修正就交给交叉熵。
结果如下
我只能说效果很难评,虽然癌变组织有小小的涨点,但是普通组织降点了,更难以理解的是为什么dpt的训练这么不稳定?
我感觉cnn还是适合处理图像,transformer的embedding这一步的信息损失太大了,带来了表现(尤其是对小区域的分割)的下降。
这坚定了我的一个信念,与其成为微操领域大神,还不如多学一些好用的框架,从宏观上做出改变。
学习率计划
学习率的衰减计划应该是很重要的一个超参数,但是之前一直没有讨论,最近也是做了一些这个部分的实验,所以将这部分的分析补上。
首先尝试的是带有warm up 的余弦退火,warm up可以防止初始训练的不稳定,之后的学习率衰减也有着相同的作用。
效果如下:
(橙、黄色曲线表示同样配置即学习率不变,只是重复实验两次,蓝绿色线表示加了余弦退火)
从1、2两类的估计iou可以看出余弦退火的学习率计划在前期反而导致了2类正确率的下降,猜测是由于保守的学习策略加剧了样本类别不平衡的影响。
而在训练后期余弦退火确实能有效避免过拟合、训练不稳定等现象。
但是实事求是的讲,这里的对比图很难完全体现余弦学习率的作用。虽然余弦退火的训练多出了4个epoch,但是整体的正确率没有上升。
之后尝试的是带warm restart的余弦退火,它与传统余弦退火的区别是带有多次的学习率“重启”,在我看来这样的操作有着跳出局部最优的效果。
(蓝色线为带warm restart 的余弦退火)
但是重启带来的训练的不稳定,这是符合预期的。
在这之外我们尝试了ema对于模型集成的效果,结果如下:
(深绿色的线代表consine+ema)
可以看到效果非常的好,可以说是达到了训练稳定的同时保证了指标不下降。
这预示着也许cosine_restart+ema会是一个不错的选择!