SAM就是一类处理图像分割任务的通用模型。与以往只能处理某种特定类型图片的图像分割模型不同,SAM可以处理所有类型的图像。
在SAM出现前,基本上所有的图像分割模型都是专有模型。比如,在医学领域,有专门分割核磁图像的人工智能模型,也有专门分割CT影像的人工智能模型。但这些模型往往只在分割专有领域内的图像时,才具有良好性能,而在分割其他领域的图像时往往性能不佳。
沿着前两篇文章之后,本文讲下面带下划线的三个图像分割模型
DDIM
VisionTransformer
CLIP
DALL·E
MAE
SwinTransformerV2
StableDiffusion
BEiT-3
Midjourney V3
VisualChatGPT
GPT4
Midjourney V5
FastSAM
(中科院版SAM)
MobileSAM
- 在网络数据集上预训练的大语言模型具有强大的zero-shot(零样本)和few-shot(少样本)的泛化能力,这些"基础模型"可以推广到超出训练过程中的任务和数据分布,这种能力通过“prompt engineering”实现
具体就是输入prompt得到有效的文本输出,使用网络上的大量文本资料库进行缩放和训练后,发现这种零样本和少样本下的模型比微调模型效果还要好,数据集越大,效果越明显,比如GPT3 - 视觉任务上也对这种基础模型进行了探索,比如CLIP和ALIGN利用对比学习,将文本和图像编码进行了对齐,一旦训练完成,还可以扩展到下游任务,比如生成图像
而SAM(、)的目的是建立一个图像分割的基础模型,开发一个具有零样本能力的模型
1.2.1 image encoder的构成(ViT)与其编码实现
虽然按照常规来讲,一个图像编码器可以是任何输出图像嵌入的网络,但出于可扩展性的考虑,Meta最终利用MAE预训练的视觉Transformer (即ViT,具体来说是一个ViT-H/16,具有14×14的窗口注意力和四个等间距的全局注意力块,如果忘了ViT长啥样,可回顾此文第4部分),且对其进行了最小化调整以处理高分辨率输入
- 且该encoder在prompt encoder之前,对每张图像只运行一次输入(c,h,w)的图像,对图像进行缩放,按照长边缩放成1024,短边不够就pad,得到(c,1024,1024)的图像(相当于使用通过重新缩放图像并填充较短边获得的1024×1024的输入分辨率)
- 经过image encoder,得到对图像16倍下采样的feature(因此图像嵌入是64×64),大小为(256,64,64)
具体而言,为了减少通道维度,按照[Exploring Plain Vision Transformer Backbones for Object Detection],使用1×1卷积得到256个通道,然后使用3×3卷积也得到256个通道
至于其代码实现主要实现以下几个类
1.2.1.1 定义ImageEncoderViT类
一个是定义ImageEncoderViT类,这是一个基于Vision Transformer的图像编码器,该类从nn.Module继承
- 引入一系列库,且定义ImageEncoderViT类,这是一个基于Vision Transformer的图像编码器
- 创建Transformer的主体,包含多个Transformer block
- 定义前向传播函数
1.2.1.2 定义Block类
这是Transformer的基本组成模块,包括注意力机制和前馈神经网络,该类从nn.Module继承
- 定义Block类,以及分别创建:第一个归一化层、注意力层、第二个归一化层、MLP层
- 前向传播函数如下
1.2.1.3 定义Attention类
这是一个多头注意力机制的块,支持相对位置嵌入,该类从nn.Module继承
- 定义attention类
- 前向传播函数如下
1.2.1.4 定义两个函数 window_partition 和 window_unpartition
用于将输入的张量进行窗口划分和合并,这些函数在 Vision Transformer 的实现中用于实现窗口注意力机制
- window_partition
- window_unpartition
1.2.1.5 定义两个函数 get_rel_pos 和 add_decomposed_rel_pos
这两个函数 get_rel_pos 和 add_decomposed_rel_pos用于处理相对位置嵌入
在 Vision Transformer 的实现中,相对位置嵌入用于提供序列元素之间的相对位置信息,以帮助模型更好地捕捉序列中的关系。这些函数用于生成和应用相对位置嵌入
- get_rel_pos
- add_decomposed_rel_pos
1.2.1.6 定义一个 PatchEmbed 类
用于将图像转换为补丁嵌入。它使用卷积层将输入图像转换为指定维度的补丁嵌入表示。在前向传播中,输入经过卷积层进行投影,并调换维度的顺序,以使得输出为批量-高度-宽度-通道的形状
1.2.2 prompt encoder
分成2类:稀疏的(点/box/文本)、稠密的(mask)
- 稀疏的点、box、文本
对于point
映射到256维的向量,包含:代表点位置的 positional encoding,加2个代表该点是前景/背景的可学习的embedding
Sparse prompts are mapped to 256-dimensional vectorial embeddings as follows. A point is represented as the sum of a positional encoding [95] of the point’s location and one of two learned embeddings that indicate if the point is either in the foreground or background.
对于box
用一个embedding对表示:1) 可学习的embedding代表左上角,2) 可学习的embedding代表右下角
对于文本
通过CLIP模型进行文本编码 - 对于稠密的mask
用输入图像1/4分辨率的mask,然后用(2,2)卷积核,stride-2输出channel为4和16,再用(1,1)卷积核将channel升到256
We input masks at a 4× lower resolution than the input image, then downscale an additional 4× using two 2×2, stride-2 convolutions with output channels 4 and 16, respectively. A final 1×1 convolution maps the channel dimension to 256.
相当于以比输入图像低 4×的分辨率输入掩码,然后使用两个 2×2、步幅为 2 的卷积层将其进一步缩小 4×,输出通道分别为 4 和 16。 最终的1×1卷积将通道维度映射到256
此后,mask 和iamge embedding通过element-wise相乘(逐元素相乘,可以理解成mask的feature对image的feature进行加权)
其代码实现为
- 引入一些库
- 接着
- 之后
- 最后
1.2.3 mask decoder
mask decoder模块:在prompt embeddings中插入一个可学习的token,用于docoder的输出
对于下图的左侧部分,依次进行如下4个步骤(下图左下角到左上角,即从底至上)
- prompt toekns和output tokens进行self attn
self-attention on the tokens - 用得到的token(token作为Q),和image embedding进行 cross attn
cross-attention from tokens (as queries) to the image embedding - point-wise MLP 更新token
a point-wise MLP updates each token - 用image embedding(image embedding作为Q),和步骤3的token进行cross atten
cross-attention from the image embedding (as queries) to tokens
重复上述步骤2次,再将attn再通过残差进行连接,最终输出masks和iou scores,这段的代码实现为
- 首先
- 其次
- 接下来
对于下图的右侧部分
- 运行解码器后,通过两个转置卷积对更新的图像嵌入进行4倍上采样(现在相对于输入图像缩小了4倍)
After running the decoder, we upsample the updated image embedding by 4× with two transposed convolutional 16 layers (now it’s downscaled 4× relative to the input image) - 然后,token再次关注图像嵌入,即将更新的输出token嵌入传递给一个小型的三层MLP,该MLP输出一个与上采样图像嵌入的通道维度匹配的向量
Then, the tokens attend once more to the image embedding and we pass the updated output token embedding to a small 3-layer MLP that outputs a vector matching the channel dimension of the upscaled image embedding - 最后,用上采样图像嵌入和 MLP输出之间的空间,做逐点乘积来预测掩码
Finally, we predict a mask with a spatially point-wise product between the upscaled image embedding and the MLP’s output
其中,有几个问题值得提一下
- transformer使用的嵌入维度为256,MLP块 的内部尺寸较大,为2048,但是MLP仅应用于提示值相对较少(很少大于20)的prompt token
然而,在我们有64× 64图像嵌入的交叉注意力层中,为了提高计算效率,将查询、键和值的通道维度减少了2至128倍,所有的注意力层都使用了8个头
The transformer uses an embedding dimension of 256. The transformer MLP blocks have a large internal dimension of 2048, but the MLP is applied only to the prompt tokens for which there are relatively few (rarely greater than 20). However, in cross-attention layers where we have a 64×64 image embedding, we reduce the channel dimension of the queries, keys, and values by 2× to 128 for computational efficiency. All attention layers use 8 heads. - 用于上采样输出图像嵌入的转置卷积是2×2,步幅2,输出通道维度为64和32,并具有GELU激活,最后通过层归一化将它们分开
The transposed convolutions used to upscale the output image embedding are 2×2, stride 2 with output channel dimensions of 64 and 32 and have GELU activations. They are separated by layer normalization. - 为了解决输出模糊性问题(一个prompt可能生成多个mask,比如衣服上的一个点,既可以表示衣服,也表示穿衣服的人),预测输出多个masks 「即使用少量输出token并同时预测多个掩码,而不是预测单个掩码,默认情况下预测三个掩码,因为三层(整体、部分和子部分)通常足以描述嵌套的掩码,即three layers (whole, part, and subpart) are often enough to describe nested masks」
在训练过程中,只回传最小的loss,为了对mask进行排序,增加一个小的head预测mask和目标的iou
当输入多个提示时,生成的mask会比较接近,为了减少loss退化和确保获取明确的mask,此时只预测一个mask (作为第4个预测mask,只有多个提示时才预测,当单个提示时不用,即This is accomplished by adding a fourth output token for an additional mask prediction. This fourth mask is never returned for a single prompt and is the only mask returned for multiple prompts.)
其代码实现为 (定义一个MaskDecoder类,用于预测给定图像和提示嵌入的掩码,其使用的Transformer架构。同时,也定义了一个MLP类,即多层感知器网络)
- 先是
- 后是
在分别实现了上述结构后,在实际分割时便可以直接调用了
1.2.4 模型训练
训练时模拟交互分割的过程,从目标mask中随机选取前景点或者box,点是从gt mask选取,box增加长边10%的噪声,最大20像素
在第一次prompt预测mask之后,后续是从预测mask和gt mask有差异的区域采样点
- 如果新生成的点是FN,则作为前景
- 如果是FP,则作为背景
同时,将预测的mask(unthresholded mask logits代替二值化的mask,不过滤阈值,默认为0),作为prompt作为迭代
训练过程中,发现用8个采样点比较合适(对比16个,没有明显增益),为了鼓励模型从mask中获益,其中2个迭代不用新采样的点,总共11个迭代,一个是初始化的prompt输入,然后是8个上述迭代,再加2个不重新采样点的迭代(这样可以refine mask)。由于mask decoder比较轻,所以可以进行更多次的迭代
- loss
mask 用focal loss和dice loss进行线性组合,系数(20:1),iou 用mse loss - 训练时间
256 A100 GPUs,3-5天(A100价格6万左右,256个,1000多万,你懂的..)
- 辅助人工标注
通过SAM基于浏览器的交互式分割工具,通过“brush”和"eraser"工具,进行标注。模型可以实时输出mask,建议标注者优先标记他们命名的对象,按图层顺序标记,如果一个mask标记超过30s,先处理下一张
SAM先用公开数据集训练,然后再用新增的标注mask训练。随着数据越多,image-encoder的能力越强,retrained了6次。随着模型改进,每个mask平均标注时间从34s到14s,平均每张图像mask从22增加到44个。在这个过程中,从12万图像中,收集了430万个mask。 - 半自动
增加mask的多样性,首先检测出可信的mask,然后用预测mask填充图像,让标注者标注未标记的mask。为了检测可信的mask,先用第一步的mask训练了一个类别一样的box检测器。半自动过程中,从18万张图像中生成了590万个mask。用新收集的数据,重新训练模型,平均标注时间又回到了34s,因为新的mask都是比较有难度的。每张图像上mask从44增加到72。 - 全自动
利用前2步,得到的大量的和多样性的mask,结合模型可以根据不明确的输入也能输出有效的mask(参考mask encoder),对图像生成(32,32)个格网点,每个点预测一系列mask,如果一个点落在部分、子部分上,模型返回部分、子部分和整体的object。同时,通过预测的iou筛选 confident(可信的mask),选取一个stable的mask(稳定的mask,在相似的mask中,概率阈值在 0.5-δ和 0.5-δ之间);最后,通过nms过滤confident和stable中重复的mask
为了提高mask比较小的,还通过放大图像进行crop,处理多个mask覆盖的情况
最终在1100万数据集上,生成了11亿高质量的mask
数据情况
- 图片:从合作商获取1100万张图像,按短边重采样到1500像素
- mask:99.1%都是自动生成的,通过对比分析,自动生成的mask质量也是非常高的。为了评估质量,随机选500张图像(约5万个mask),让专业的标注人员进行标注,通过对比发现94%的mask有90%以上的iou
- 数据分布更广,从全世界获取数据,mask更多,数据偏向性较小
在上文第一部分,我们已经了解到
Segment Anything 的关键特征是基于提示的视觉 Transformer(ViT)模型,该模型是在一个包含来自 1100 万张图像且超过 10 亿个掩码的视觉数据集 SA-1B 上训练的,可以分割给定图像上的任何目标
尽管有上述优点,但由于 SAM 中的 ViT-H 图像编码器有 632M 个参数(基于提示的解码器只需要 387M 个参数),因此实际使用 SAM 执行任何分割任务的计算和内存成本都很高,这对实时应用来说具有挑战性
- 后续,研究者们也提出了一些改进策略:将默认 ViT-H 图像编码器中的知识提炼到一个微小的 ViT 图像编码器中,或者使用基于 CNN 的实时架构降低用于 Segment Anything 任务的计算成本
- 在最近的一项研究中,Meta 研究者提出了另外一种改进思路 —— 利用 SAM 的掩码图像预训练 (SAMI)。这是通过利用 MAE 预训练方法和 SAM 模型实现的,以获得高质量的预训练 ViT 编码器
该项研究对应的论文链接为:https://arxiv.org/pdf/2312.00863.pdf,此则为其论文主页:https://yformer.github.io/efficient-sam/
// 待更
尽管SAM成功解决了图像中的分割问题,但现有的视频分割模型和数据集在提供类似于“视频中分割任何东西”的能力方面仍显不足
对此,Meta随机发布了Segment Anything Model 2(SAM 2),用于视频和图像分割的统一模型(将图像视为单帧视频)
- 任务的输入可以是视频任意帧上的点、框或掩码,以定义感兴趣的分割部分,并为其预测时空掩码(即“masklet”)。一旦预测出masklet,可以通过在附加帧中提供提示来迭代地优化它
- 在单个图像和视频帧中生成感兴趣对象的分割掩码。SAM 2 配备了一个存储有关对象和先前交互信息的内存,这使得它能够在整个视频中生成掩码预测,并根据之前观察到的帧中存储的对象记忆上下文有效地纠正这些预测
且其流式架构是对视频领域的 SAM 的自然泛化,每次处理一个视频帧,配备了一个记忆注意模块,以关注目标对象的先前记忆。在应用于图像时,内存是空的,模型的行为类似于 SAM
SAM 2 解码器使用的帧嵌入并不是直接来自图像编码器,而是基于过去预测和提示帧的记忆进行条件化「如下图所示,对于给定的帧,分割预测以当前提示和/或先前观察到的记忆为条件。视频以流式方式处理,图像编码器逐帧处理,交叉关注来自之前帧的目标对象记忆。掩码解码器可选择性地接受输入提示,预测该帧的分割掩码。最后,记忆编码器将预测和图像编码器嵌入(图中未显示)转换,以便在未来帧中使用」
- 提示帧也可能相对于当前帧“来自未来”。帧的记忆由记忆编码器根据当前预测创建,并放置在记忆库中以供后续帧使用
- 记忆注意操作从图像编码器获取每帧嵌入,并在掩码解码器将其用于形成预测之前,将其条件化在记忆库上
3.1.1 图像编码器
为了实时处理任意长度的视频,作者采用流处理方法,在视频帧可用时进行处理
- 图像编码器在整个交互过程中只运行一次,其作用是提供表示每个帧的无条件标记(特征嵌入)
- 最终使用了MAE(He et al., 2022)预训练的Hiera(Ryali et al., 2023; Bolya etal., 2023)图像编码器,它是分层的,使得能够在解码过程中使用多尺度特征
3.1.2 记忆注意力
记忆注意力的作用是基于过去帧的特征和预测以及任何新的提示来调整当前帧的特征
- 作者堆叠了L个transformer块,第一个块以当前帧的图像编码作为输入。每个块执行自注意力,然后对存储在记忆库中的(提示/非提示)帧和对象指针(见下文)进行交叉注意力,然后是一个MLP
- 使用普通的注意力操作进行自注意力和交叉注意力,使得能够从最近在高效注意力内核方面的进展中受益(Dao, 2023)
3.1.3 提示编码器和掩码解码器
// 待更
3.1.4 记忆编码器
3.1.5 记忆库
3.1.6 训练
对于SAMURAI
- 其对应的论文为《》
- 其对应的项目地址为:
// 待更