
Vit_block

A Transformer block from CroCo, similar to the basic block of ViT or Swin Transformer.

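import torch.nn as nn

# Attention, Mlp and DropPath are assumed to be defined in the same module as Block.


class Block(nn.Module):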

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
        super().__init__()
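        self.norm1 = norm_layer(dim)  # normalization applied before attention (pre-norm)
        self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)  # self-attention; `rope` passes in the rotary position embedding (RoPE)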
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, xpos):
        x = x + self.drop_path(self.attn(self.norm1(x), xpos))  # residual around attention; xpos carries the token positions
        x = x + self.drop_path(self.mlp(self.norm2(x)))  # residual around the MLP
        return x


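In `forward`, the output of the attention branch is added back to the input $x$, forming a residual connection; the MLP branch is wrapped in a second residual connection in the same way. Each branch normalizes its input first (pre-norm) and passes through `drop_path` before the addition.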
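To make the residual layout concrete, here is a minimal sketch. It is not the CroCo implementation. It keeps the same pre-norm structure as `Block.forward`, but uses PyTorch's `nn.MultiheadAttention` in place of CroCo's `Attention` (so there is no `xpos` and no RoPE) and a hypothetical `TinyDropPath` in place of `DropPath`. The names `TinyDropPath` and `PreNormBlock` are assumptions made for illustration only.

```python
import torch
import torch.nn as nn


class TinyDropPath(nn.Module):
    """Hypothetical stand-in for DropPath: per-sample stochastic depth."""

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Identity at inference time or when nothing is dropped.
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        # Rescale so the expected value of the branch is unchanged.
        return x * mask / keep_prob


class PreNormBlock(nn.Module):
    """Illustrative block: same residual layout as Block, without RoPE."""

    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # Assumption: nn.MultiheadAttention replaces CroCo's Attention.
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.drop_path = TinyDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop_path(attn_out)                 # residual 1: attention
        x = x + self.drop_path(self.mlp(self.norm2(x)))  # residual 2: MLP
        return x


torch.manual_seed(0)
block = PreNormBlock(dim=32, num_heads=4, drop_path=0.1).eval()
tokens = torch.randn(2, 5, 32)  # (batch, tokens, dim)
with torch.no_grad():
    out = block(tokens)
print(out.shape)  # torch.Size([2, 5, 32])
```

Because each branch is added to its own input, the block preserves the input shape, and a dropped branch (during training) reduces that sub-layer to the identity for the affected sample.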

This article is licensed by the author under CC BY 4.0.