文章

dust3r论文阅读与代码

ViT encoder

在dust3r仓库中,ViT编码器的实现可以在文件 dust3r/model.py 中找到。该文件定义了一个名为 AsymmetricCroCo3DStereo 的类,该类包含了 _encode_image 和 _encode_image_pairs 方法,负责图像特征提取和编码工作。

1
2
3
4
5
6
7
def _encode_image(self, image, true_shape):
    x, pos = self.patch_embed(image, true_shape=true_shape)  # patch_embed负责将输入图像分割为小的 patch 并嵌入到特征空间中。
    assert self.enc_pos_embed is None
    for blk in self.enc_blocks:
        x = blk(x, pos)
    x = self.enc_norm(x)
    return x, pos, None

_encode_image 方法将输入图像分割为 patch,然后将其输入到 Transformer 编码器中。 位置编码信息 (pos) 被添加到特征中,并通过编码器块逐步处理。

patchify

1
2
3
# https://github.com/naver/dust3r/blob/a53d073bb07455f5a3e7fcbeaaea240c159c4f08/dust3r/model.py#L88
def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)

其在AsymmetricCroCo3DStereo的父类CroCoNet的__init___中初始化

1
2
        # croco/models/croco.py -> CroCoNet -> __init__()
        self._set_patch_embed(img_size, patch_size, enc_embed_dim)

get_patch_embed是一个工厂函数,根据参数patch_embed_cls动态地实例化不同的patch embedding模块。可以在下面的代码块中看到使用eval将函数名直接转换为python表达式进行执行。

1
2
3
4
def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
    assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed']
    patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim)  # eval('PatchEmbedDust3R')
    return patch_embed

PatchEmbedDust3R

1
2
3
4
5
6
7
8
9
10
11
class PatchEmbedDust3R(PatchEmbed):
    def forward(self, x, **kw):
        B, C, H, W = x.shape  # 确保图像的长宽能够被patch size整除
        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        x = self.proj(x)
        pos = self.position_getter(B, x.size(2), x.size(3), x.device)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x, pos

可以看到有两个assert来保证H, W必须是patch_size的整数倍。

输入经过一个线性层, 并获取位置编码pos, 如果flatten开关启动, 则需要将HW通道展开。最后输入归一化层。返回patchify后的特征x和位置编码pos.

位置编码方式是可选的,在dust3r/croco/models/croco.py中可以根据输入pos_embed选择cosine或RoPE{xx}等(如RoPE100)

Encoder

本文由作者按照 CC BY 4.0 进行授权