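AutoFocusFormer code: https://github.com/apple/ml-autofocusformer

Principles behind the key parts of the paper: see my previous article.

Clustering is computed through three components: spatial clustering, local attention, and adaptive sampling. These are the core innovations of AutoFocusFormer.

0. Preface

1) This article singles out a few key pieces of code. Note that all of them can be found in the original repository; none of it is extra code.
2) The three core innovations are meant to be read alongside the paper walkthrough in my previous article.

1. Pipeline

The official code shows that the BasicLayer class in /models/aff_transformer.py lays out how the clustering is computed:
1) First comes spatial partitioning: the image tokens are divided into clusters, and tokens within a cluster are considered close to each other. This is done by the space_filling_cluster function.
2) For each token, find its few nearest clusters and collect the tokens of those clusters to form its neighborhood. This is done with knn_keops and a few gather operations.
3) In the attention computation, each token attends only to the other tokens in its neighborhood. This gives local attention.
4) Finally, the ClusterMerging layer samples tokens by importance, keeping the important tokens and dropping the unimportant ones. This gives progressive downsampling.

2. Spatial clustering code

  This is mainly implemented in the forward method of BasicLayer. It first calls space_filling_cluster to divide the tokens into clusters whose members are considered spatially close. Then, for each token, it finds the nearest few clusters and collects the tokens of those clusters into the token's neighborhood, using knn_keops and gather.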

class BasicLayer(nn.Module):
    """ AutoFocusFormer layer for one stage.

    Args:
        dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        cluster_size (int): Cluster size.
        nbhd_size (int): Neighbor size. If larger than or equal to number of tokens, perform global attention;
                            otherwise, rounded to the nearest multiples of cluster_size.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        alpha (float, optional): the weight to be multiplied with importance scores. Default: 4.0
        ds_rate (float, optional): downsampling rate, to be multiplied with the number of tokens. Default: 0.25
        reserve_on (bool, optional): whether to turn on reserve tokens in downsampling. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        layer_scale (float, optional): Layer scale initial parameter. Default: 0.0
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
    """

    def __init__(self, dim, out_dim, cluster_size, nbhd_size,
                 depth, num_heads, mlp_ratio,
                 alpha=4.0, ds_rate=0.25, reserve_on=True,
                 drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm,
                 layer_scale=0.0, downsample=None):
        super().__init__()
        self.dim = dim
        self.nbhd_size = nbhd_size
        self.cluster_size = cluster_size
        self.depth = depth

        # build blocks
        self.blocks = nn.ModuleList([
            ClusterTransformerBlock(dim=dim,
                                    num_heads=num_heads,
                                    mlp_ratio=mlp_ratio,
                                    drop=drop, attn_drop=attn_drop,
                                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                    layer_scale=layer_scale,
                                    norm_layer=norm_layer)
            for i in range(depth)])

        # merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, out_dim=out_dim, norm_layer=norm_layer, alpha=alpha, ds_rate=ds_rate, reserve_on=reserve_on)
        else:
            self.downsample = None

        # cache the clustering result for the first feature map since it is on grid
        self.pos, self.cluster_mean_pos, self.member_idx, self.cluster_mask, self.reorder = None, None, None, None, None

        # fc for importance scores
        if downsample is not None:
            self.prob_net = nn.Linear(dim, 1)

    def forward(self, pos, feat, h, w, on_grid, stride):
        """
        Args:
            pos - b x n x 2, token positions
            feat - b x n x c, token features
            h,w - max height and width of token positions
            on_grid - bool, whether the tokens are still on grid; True for the first feature map
            stride - int, "stride" of the current token set; starts with 2, then doubles in each stage
        """
        b, n, d = pos.shape
        c = feat.shape[2]
        assert self.cluster_size > 0, 'self.cluster_size must be positive'

        if self.nbhd_size >= n:
            global_attn = True
            member_idx, cluster_mask = None, None
        else:
            global_attn = False
            k = int(math.ceil(n / float(self.cluster_size)))  # number of clusters
            nnc = min(int(round(self.nbhd_size / float(self.cluster_size))), k)  # number of nearest clusters
            nbhd_size = self.cluster_size * nnc
            self.nbhd_size = nbhd_size  # if not global attention, then nbhd size is rounded to nearest multiples of cluster

        if global_attn:
            rel_pos = (pos[:, None, :, :]+rel_pos_width) - pos[:, :, None, :]  # b x n x n x d
        else:
            if k == n:
                # if number of clusters equal to number of tokens
                cluster_mean_pos = pos
                member_idx = torch.arange(n, device=feat.device).long().reshape(1, n, 1).expand(b, -1, -1)  # b x n x 1
                cluster_mask = None
            else:
                # perform clustering
                if on_grid:
                    if self.cluster_mean_pos is None:
                        self.pos, self.cluster_mean_pos, self.member_idx, self.cluster_mask, self.reorder = space_filling_cluster(pos, self.cluster_size, h, w, no_reorder=False)
                    pos, cluster_mean_pos, member_idx, cluster_mask = self.pos[:b], self.cluster_mean_pos[:b], self.member_idx[:b], self.cluster_mask
                    # reorder the tokens so that tokens in same cluster are stored together
                    feat = feat[torch.arange(b).to(feat.device).repeat_interleave(n), self.reorder[:b].view(-1)].reshape(b, n, c)
                    if cluster_mask is not None:
                        cluster_mask = cluster_mask[:b]
                else:
                    pos, cluster_mean_pos, member_idx, cluster_mask, reorder = space_filling_cluster(pos, self.cluster_size, h, w, no_reorder=False)
                    # reorder the tokens so that tokens in same cluster are stored together
                    feat = feat[torch.arange(b).to(feat.device).repeat_interleave(n), reorder.view(-1)].reshape(b, n, c)

            assert member_idx.shape[1] == k and member_idx.shape[2] == self.cluster_size, "member_idx shape incorrect!"

            nearest_cluster = knn_keops(pos, cluster_mean_pos, nnc)  # b x n x nnc

            # collect neighbor indices from nearest clusters
            m = self.cluster_size
            member_idx = member_idx.gather(index=nearest_cluster.view(b, -1, 1).expand(-1, -1, m), dim=1).reshape(b, n, nbhd_size)  # b x n x nnc*m
            if cluster_mask is not None:
                cluster_mask = cluster_mask.gather(index=nearest_cluster.view(b, -1, 1).expand(-1, -1, m), dim=1).reshape(b, n, nbhd_size)
            pos_ = pos.gather(index=member_idx.view(b, -1, 1).expand(-1, -1, d), dim=1).reshape(b, n, nbhd_size, d)
            rel_pos = pos_ - (pos.unsqueeze(2)-rel_pos_width)  # b x n x nbhd_size x d

        # compute indices in the position embedding lookup table
        pe_idx = (rel_pos[..., 1] * table_width + rel_pos[..., 0]).long()

        for i_blk in range(len(self.blocks)):
            blk = self.blocks[i_blk]
            feat = blk(feat=feat,
                       member_idx=member_idx,
                       cluster_mask=cluster_mask,
                       pe_idx=pe_idx,
                       global_attn=global_attn)

        if self.downsample is not None:
            learned_prob = self.prob_net(feat).sigmoid()  # b x n x 1
            reserve_num = math.ceil(h/(stride*2)) * math.ceil(w/(stride*2))
            pos, feat = self.downsample(pos=pos, feat=feat,
                                        member_idx=member_idx, cluster_mask=cluster_mask,
                                        learned_prob=learned_prob, stride=stride,
                                        pe_idx=pe_idx, reserve_num=reserve_num)

        return pos, feat

    def extra_repr(self) -> str:
        return f"dim={self.dim}, depth={self.depth}"
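
To make the neighborhood-size rounding in forward concrete, here is a minimal stand-alone sketch of just that arithmetic. The helper name neighborhood_plan and the example numbers (a 56 x 56 token grid, cluster_size 8, nbhd_size 50) are assumptions chosen for illustration, not values taken from the repository's configs.

```python
import math


def neighborhood_plan(n, cluster_size, nbhd_size):
    """Mirror the k / nnc / nbhd_size arithmetic of BasicLayer.forward (hypothetical helper)."""
    if nbhd_size >= n:
        # The neighborhood would cover every token, so the layer uses global attention.
        return {"global_attn": True}
    k = int(math.ceil(n / float(cluster_size)))                # number of clusters
    nnc = min(int(round(nbhd_size / float(cluster_size))), k)  # nearest clusters per token
    # The effective neighborhood size is a whole number of clusters.
    return {"global_attn": False, "k": k, "nnc": nnc, "nbhd_size": cluster_size * nnc}


plan = neighborhood_plan(n=56 * 56, cluster_size=8, nbhd_size=50)
print(plan)
```

With these assumed inputs the requested neighborhood of 50 tokens is rounded to 6 clusters of 8 tokens, i.e. 48 tokens per neighborhood.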

  The space_filling_cluster function is defined in point_utils.py; the knn_keops helper is listed right after it.

def space_filling_cluster(pos, m, h, w, no_reorder=False, sf_type='', use_anchor=True):
    """
    The balanced clustering algorithm based on space-filling curves
    In the case where number of tokens not divisible by cluster size,
    the last cluster will have a few blank spots, indicated by the mask returned
    Args:
        pos - b x n x 2, positions of tokens
        m - int, target size of the clusters
        h,w - int, height and width
        no_reorder - bool, if True, return the clustering based on the original order of tokens;
                            otherwise, reorder the tokens so that the same cluster stays together
        sf_type - str, can be 'peano' or 'hilbert', or otherwise, horizontal scanlines w/ alternating
                        direction in each row by default
        use_anchor - bool, whether to use space-filling anchors or not; if False, directly compute
                            space-filling curves on the token positions
    Returns:
        pos - b x n x 2, returned only if no_reorder is False; the reordered position of tokens
        cluster_mean_pos - b x k x 2, the clustering centers
        member_idx - b x k x m, the indices of tokens in each cluster
        cluster_mask - b x k x m, the binary mask indicating the paddings in last cluster (0 if padding)
        pos_ranking - b x n x 1, returned only if no_reorder is False; i-th entry is the idx of the token
                                rank i in the new order
    """
    with torch.no_grad():
        pos = pos.detach()

        if pos.dtype != torch.float:
            pos = pos.to(torch.float)
        b, n, d = pos.shape

        k = int(math.ceil(n/m))

        if use_anchor:
            patch_len = (h*w/k)**0.5
            num_patch_h = int(round(h / patch_len))
            num_patch_w = int(round(w / patch_len))
            patch_len_h, patch_len_w = h / num_patch_h, w / num_patch_w
            if sf_type == 'peano':
                num_patch_h = max(3, int(3**round(math.log(num_patch_h, 3))))
                patch_len_h = h / num_patch_h
                num_patch_w = int(round(w / h * 3) * (num_patch_h / 3))
                patch_len_w = w / num_patch_w
            elif sf_type == 'hilbert':
                num_patch_h = max(2, int(2**round(math.log(num_patch_h, 2))))
                patch_len_h = h / num_patch_h
                num_patch_w = int(round(w / h * 2) * (num_patch_h / 2))
                patch_len_w = w / num_patch_w
            hs = torch.arange(0, num_patch_h, device=pos.device)
            ws = torch.arange(0, num_patch_w, device=pos.device)
            ys, xs = torch.meshgrid(hs, ws)
            grid_pos = torch.stack([xs, ys], dim=2)  # h x w x 2
            grid_pos = grid_pos.reshape(-1, 2)

            # sort the grid centers to one line
            if sf_type == 'peano':
                order_grid_idx, order_idx = calculate_peano_order(num_patch_h, num_patch_w, grid_pos.unsqueeze(0))
                order_grid_idx = order_grid_idx[0]
                order_idx = order_idx[0]
            elif sf_type == 'hilbert':
                order_grid_idx, order_idx = calculate_hilbert_order(num_patch_h, num_patch_w, grid_pos.unsqueeze(0))
                order_grid_idx = order_grid_idx[0]
                order_idx = order_idx[0]
            else:
                order_mask = torch.ones_like(ys)  # h x w
                order_mask[1::2] = -1
                order_mask = order_mask * xs
                order_mask = order_mask + ys*w
                order_mask[1::2] += (w-1)
                order_mask = order_mask.reshape(-1)
                order_idx = order_mask.sort()[1]
                order_idx_src = torch.arange(len(order_idx)).to(pos.device)
                order_grid_idx = torch.zeros_like(order_idx_src)
                order_grid_idx.scatter_(index=order_idx, dim=0, src=order_idx_src)

            ordered_grid = grid_pos[order_idx]
            patch_len_hw = torch.Tensor([patch_len_w, patch_len_h]).to(pos.device)

            init_pos_means = ordered_grid * patch_len_hw + patch_len_hw/2 - 0.5
            nump = ordered_grid.shape[0]
            prev_means = torch.zeros_like(init_pos_means)
            prev_means[1:] = init_pos_means[:nump-1].clone()
            prev_means[0] = prev_means[1] - (prev_means[2]-prev_means[1])  # float('inf')
            next_means = torch.zeros_like(init_pos_means)
            next_means[:nump-1] = init_pos_means[1:].clone()
            next_means[-1] = next_means[-2] + (next_means[-2]-next_means[-3])  # float('inf')

            mean_assignment = (pos / patch_len_hw).floor()
            mean_assignment = mean_assignment[..., 0] + mean_assignment[..., 1] * num_patch_w
            mean_assignment = order_grid_idx.unsqueeze(0).expand(b, -1).gather(index=mean_assignment.long(), dim=1).unsqueeze(2)  # b x n x 1

            prev_mean_assign = prev_means.unsqueeze(0).expand(b, -1, -1).gather(index=mean_assignment.expand(-1, -1, d), dim=1)  # b x n x d
            next_mean_assign = next_means.unsqueeze(0).expand(b, -1, -1).gather(index=mean_assignment.expand(-1, -1, d), dim=1)  # b x n x d
            dist_prev = (pos-prev_mean_assign).pow(2).sum(-1)  # b x n
            dist_next = (pos-next_mean_assign).pow(2).sum(-1)
            dist_ratio = dist_prev / (dist_next + 1e-5)

            pos_ranking = mean_assignment * (dist_ratio.max()+1) + dist_ratio.unsqueeze(2)
            pos_ranking = pos_ranking.sort(dim=1)[1]  # b x n x 1

        else:
            if sf_type == 'peano':
                _, pos_ranking = calculate_peano_order(h, w, pos)
            elif sf_type == 'hilbert':
                _, pos_ranking = calculate_hilbert_order(h, w, pos)
            else:
                hs = torch.arange(0, h, device=pos.device)
                ws = torch.arange(0, w, device=pos.device)
                ys, xs = torch.meshgrid(hs, ws)
                order_mask = torch.ones_like(ys)  # h x w
                order_mask[1::2] = -1
                order_mask = order_mask * xs
                order_mask = order_mask + ys*w
                order_mask[1::2] += (w-1)
                order_mask = order_mask.reshape(-1)
                pos_idx = pos[..., 0] + pos[..., 1] * w
                order_mask = order_mask.gather(index=pos_idx.long().reshape(-1), dim=0).reshape(b, n)
                pos_ranking = order_mask.sort()[1]
            pos_ranking = pos_ranking.unsqueeze(2)

        pos = pos.gather(index=pos_ranking.expand(-1, -1, d), dim=1)  # b x n x d

        if k*m == n:
            cluster_mask = None
            cluster_mean_pos = pos.reshape(b, k, -1, d).mean(2)
        else:
            pos_pad = torch.zeros(b, k*m, d, dtype=pos.dtype, device=pos.device)
            pos_pad[:, :n] = pos.clone()
            cluster_mask = torch.zeros(b, k*m, device=pos.device).long()
            cluster_mask[:, :n] = 1
            cluster_mask = cluster_mask.reshape(b, k, m)
            cluster_mean_pos = pos_pad.reshape(b, k, -1, d).sum(2) / cluster_mask.sum(2, keepdim=True)

        if no_reorder:
            if k*m == n:
                member_idx = pos_ranking.reshape(b, k, m)
            else:
                member_idx = torch.zeros(b, k*m, device=pos.device, dtype=torch.int64)
                member_idx[:, :n] = pos_ranking.squeeze(2)
                member_idx = member_idx.reshape(b, k, m)
            return cluster_mean_pos, member_idx, cluster_mask
        else:
            member_idx = torch.arange(k*m, device=pos.device)
            member_idx[n:] = 0
            member_idx = member_idx.unsqueeze(0).expand(b, -1)  # b x k*m
            member_idx = member_idx.reshape(b, k, m)
            return pos, cluster_mean_pos, member_idx, cluster_mask, pos_ranking
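
The default ordering (neither 'peano' nor 'hilbert') uses horizontal scanlines that alternate direction on every row. The pure-Python sketch below reproduces only that ordering key on a small full grid and then cuts the ordered tokens into consecutive clusters. The names snake_order and chunk_clusters are hypothetical, and anchors, batching and padding masks are left out.

```python
def snake_order(h, w):
    """Return grid positions (x, y) ordered along scanlines that alternate direction per row."""
    def key(p):
        x, y = p
        # Even rows run left to right, odd rows right to left
        # (the same key that order_mask builds in space_filling_cluster).
        return x + y * w if y % 2 == 0 else -x + y * w + (w - 1)

    cells = [(x, y) for y in range(h) for x in range(w)]
    return sorted(cells, key=key)


def chunk_clusters(ordered, m):
    """Split the ordered tokens into consecutive clusters of at most m tokens."""
    return [ordered[i:i + m] for i in range(0, len(ordered), m)]


order = snake_order(h=2, w=3)
clusters = chunk_clusters(order, m=4)
print(order)
print(clusters)
```

Because 6 tokens do not divide evenly into clusters of 4, the last cluster here is short; the real function pads it to size m and reports the padding through cluster_mask.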

def knn_keops(query, database, k, return_dist=False):
    """
    Compute k-nearest neighbors using the Keops library
    Backward pass turned off; Keops does not provide backward pass for distance
    Args:
        query - b x n_ x c, the position of tokens looking for knn
        database - b x n x c, the candidate tokens for knn
        k - int, the number of neighbors to be found
        return_dist - bool, whether to return distance to the neighbors
    Returns:
        nn_idx - b x n_ x k, the indices of the knn
        nn_dist - b x n_ x k, if return_dist, the distance to the knn
    """
    b, n, c = database.shape
    with torch.no_grad():
        query = query.detach()
        database = database.detach()
        # Keops does not support half precision
        if query.dtype != torch.float32:
            query = query.to(torch.float32)
        if database.dtype != torch.float32:
            database = database.to(torch.float32)
        from pykeops.torch import LazyTensor
        query_ = LazyTensor(query[:, None, :, :])
        database_ = LazyTensor(database[:, :, None, :])
        dist = ((query_-database_) ** 2).sum(-1) ** 0.5  # b x n x n_
    if return_dist:
        nn_dist, nn_idx = dist.Kmin_argKmin(k, dim=1)  # b x n_ x k
        return nn_idx, nn_dist
    else:
        nn_idx = dist.argKmin(k, dim=1)  # b x n_ x k
        return nn_idx

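  gather is a PyTorch tensor operation that collects elements of the input tensor at the positions given by an index tensor along a chosen dimension.

3. Local attention

  When attention is computed, each token attends only to the other tokens in its neighborhood. ClusterAttention achieves this by taking member_idx and cluster_mask as inputs.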
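
As an illustration of how gather turns "nearest clusters" into a per-token neighborhood, the sketch below emulates gather along dim=1 with nested Python lists, where out[b][i][j] = src[b][index[b][i][j]][j]. The helper gather_dim1 and the tiny example data are assumptions; the real code runs the same pattern on tensors.

```python
def gather_dim1(src, index):
    """List-based equivalent of torch.gather(src, dim=1, index=index) for b x n x c data."""
    return [
        [[src[b][index[b][i][j]][j] for j in range(len(index[b][i]))]
         for i in range(len(index[b]))]
        for b in range(len(index))
    ]


m, nnc = 2, 2
# member_idx: 1 batch, 3 clusters, m = 2 token indices per cluster
member_idx = [[[0, 1], [2, 3], [4, 5]]]
# nearest_cluster: 4 tokens, each listing its nnc = 2 nearest clusters
nearest_cluster = [[[0, 1], [0, 1], [1, 2], [2, 1]]]

# Same shape trick as nearest_cluster.view(b, -1, 1).expand(-1, -1, m):
# one row per (token, nearest cluster) pair, repeated m times.
index = [[[c] * m for row in nearest_cluster[0] for c in row]]
flat = gather_dim1(member_idx, index)[0]  # (n*nnc) x m

# Reshape to n x (nnc*m): concatenate the clusters belonging to each token.
nbhd = [sum(flat[i * nnc:(i + 1) * nnc], []) for i in range(len(nearest_cluster[0]))]
print(nbhd)
```

Each row of nbhd is one token's neighborhood: the member indices of its nearest clusters, concatenated in order of proximity.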

class ClusterAttention(nn.Module):
    """
    Performs local attention on nearest clusters

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.pos_dim = 2
        self.num_heads = num_heads

        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2*dim)
        self.softmax = nn.Softmax(dim=-1)

        self.blank_k = nn.Parameter(torch.randn(dim))
        self.blank_v = nn.Parameter(torch.randn(dim))

        self.pos_embed = nn.Linear(self.pos_dim+3, num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, feat, member_idx, cluster_mask, pe_idx, global_attn):
        """
        Args:
            feat - b x n x c, token features
            member_idx - b x n x nbhd, token idx in each local nbhd
            cluster_mask - b x n x nbhd, binary mask for valid tokens (1 if valid)
            pe_idx - b x n x nbhd, idx for the pre-computed position embedding lookup table
            global_attn - bool, whether to perform global attention
        """
        b, n, c = feat.shape
        c_ = c // self.num_heads
        assert c == self.dim, "dim does not accord to input"
        h = self.num_heads

        # get qkv
        q = self.q(feat)  # b x n x c
        q = q * self.scale
        kv = self.kv(feat)  # b x n x 2c

        # get attention
        if global_attn:
            q = q.reshape(b, n, h, -1).permute(0, 2, 1, 3)  # b x h x n x c_
            kv = kv.view(b, n, h, 2, c_).permute(3, 0, 2, 1, 4)  # 2 x b x h x n x c_
            key, v = kv[0], kv[1]
            attn = q @ key.transpose(-1, -2)  # b x h x n x n
            mask = None
        else:
            nbhd_size = member_idx.shape[-1]
            m = nbhd_size
            q = q.reshape(b, n, h, -1).permute(0, 2, 1, 3)
            kv = kv.view(b, n, h, 2, c_).permute(3, 0, 2, 1, 4)  # 2 x b x h x n x c_
            key, v = kv[0], kv[1]
            attn = CLUSTENQKFunction.apply(q, key, member_idx)  # b x h x n x m
            mask = cluster_mask
            if mask is not None:
                mask = mask.reshape(b, 1, n, m)

        # position embedding
        global pre_table
        if not pre_table.is_cuda:
            pre_table = pre_table.to(pe_idx.device)
        pe_table = self.pos_embed(pre_table)  # 111 x 111 x h for img_size 224x224

        pe_shape = pe_idx.shape
        pos_embed = pe_table.gather(index=pe_idx.view(-1, 1).expand(-1, h), dim=0).reshape(*(pe_shape), h).permute(0, 3, 1, 2)

        attn = attn + pos_embed

        if mask is not None:
            attn = attn + (1-mask)*(-100)

        # blank token
        blank_attn = (q * self.blank_k.reshape(1, h, 1, c_)).sum(-1, keepdim=True)  # b x h x n x 1
        attn = torch.cat([attn, blank_attn], dim=-1)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        blank_attn = attn[..., -1:]
        attn = attn[..., :-1]
        blank_v = blank_attn * self.blank_v.reshape(1, h, 1, c_)  # b x h x n x c_

        # aggregate v
        if global_attn:
            feat = (attn @ v).permute(0, 2, 1, 3).reshape(b, n, c)
            feat = feat + blank_v.permute(0, 2, 1, 3).reshape(b, n, c)
        else:
            feat = CLUSTENAVFunction.apply(attn, v, member_idx).permute(0, 2, 1, 3).reshape(b, n, c)
            feat = feat + blank_v.permute(0, 2, 1, 3).reshape(b, n, c)

        feat = self.proj(feat)
        feat = self.proj_drop(feat)

        return feat

    def extra_repr(self) -> str:
        return f'dim={self.dim}, num_heads={self.num_heads}'

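  The forward function of ClusterAttention contains this snippet:

if global_attn:
    ...  # global attention (omitted here)
else:
    # local attention
    attn = CLUSTENQKFunction.apply(q, key, member_idx)
    mask = cluster_mask

  For local attention, CLUSTENQKFunction is called to compute the attention scores. It takes member_idx, which holds the neighbor indices of every token, and uses it to pick out the key vectors, so that scores are computed only within each neighborhood. cluster_mask is applied just before the softmax: invalid (padded) neighbors get -100 added to their scores, so their attention weights become negligible and the attention stays focused on real neighbors. (CLUSTENQKFunction is defined in /clusten/src/clusten.py.)

4. Adaptive sampling

  At the end of BasicLayer, tokens are sampled by importance: important tokens are kept and unimportant ones are dropped. This is done by the ClusterMerging layer, which has three parts:
(1) Compute a keep score for each token from a position prior and the token's learned importance score.
(2) Select the tokens to keep according to that score.
(3) Merge the neighborhood of each kept token into it, producing a new, sparser feature map.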
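
The masking step can be illustrated for a single query token and a single head. This is a simplified pure-Python sketch, not the CUDA kernel: the function name local_attention_weights and the toy vectors are assumptions, and the position embedding and the blank token of the real layer are omitted.

```python
import math


def local_attention_weights(q, keys, member_idx, mask):
    """Softmax attention of one query over its neighborhood only (hypothetical helper)."""
    scale = len(q) ** -0.5
    logits = []
    for j, valid in zip(member_idx, mask):
        score = sum(qi * ki for qi, ki in zip(q, keys[j])) * scale
        # Padded slots get -100 added before the softmax, as in attn + (1-mask)*(-100).
        logits.append(score + (1 - valid) * (-100))
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]  # 4 tokens, dim 2
# The neighborhood holds tokens 0 and 2 plus one padded slot (index 0, mask 0).
weights = local_attention_weights(q=[1.0, 0.0], keys=keys,
                                  member_idx=[0, 2, 0], mask=[1, 1, 0])
print(weights)
```

Tokens outside member_idx (tokens 1 and 3 here) never enter the computation at all, and the padded slot ends up with a weight that is effectively zero.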

class ClusterMerging(nn.Module):
    r""" Adaptive Downsampling.

    Args:
        dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        alpha (float, optional): the weight to be multiplied with importance scores. Default: 4.0
        ds_rate (float, optional): downsampling rate, to be multiplied with the number of tokens. Default: 0.25
        reserve_on (bool, optional): whether to turn on reserve tokens in downsampling. Default: True
    """

    def __init__(self, dim, out_dim, norm_layer=nn.LayerNorm, alpha=4.0, ds_rate=0.25, reserve_on=True):
        super().__init__()
        self.dim = dim
        self.pos_dim = 2
        self.alpha = alpha
        self.ds_rate = ds_rate
        self.reserve_on = reserve_on

        # pointconv
        inner_ch = 4
        self.weight_net = nn.Sequential(
            nn.Linear(self.pos_dim+3, inner_ch, bias=True),
            nn.LayerNorm(inner_ch),
            nn.GELU()
        )
        self.norm = norm_layer(inner_ch*dim)
        self.linear = nn.Linear(dim*inner_ch, out_dim)

    def forward(self, pos, feat, member_idx, cluster_mask, learned_prob, stride, pe_idx, reserve_num):
        """
        Args:
            pos - b x n x 2, token positions
            feat - b x n x c, token features
            member_idx - b x n x nbhd, token idx in each local nbhd
            cluster_mask - b x n x nbhd, binary mask for valid tokens (1 if valid)
            learned_prob - b x n x 1, learned importance scores
            stride - int, "stride" of the current feature map, 2,4,8 for the 3 stages respectively
            pe_idx - b x n x nbhd, idx for the pre-computed position embedding lookup table
            reserve_num - int, number of tokens to be reserved
        """
        b, n, c = feat.shape
        d = pos.shape[2]

        keep_num = int(n*self.ds_rate)

        # grid prior
        if stride == 2:  # no ada ds yet, no need ada grid
            grid_prob = ((pos % stride).sum(-1) == 0).float()  # b x n
        else:
            _, min_dist = knn_keops(pos, pos, 2, return_dist=True)  # b x n x 2
            min_dist = min_dist[:, :, 1]  # b x n
            ada_stride = 2**(min_dist.log2().ceil()+1)  # b x n
            grid_prob = ((pos.long() % ada_stride.unsqueeze(2).long()).sum(-1) == 0).float()  # b x n

        final_prob = grid_prob

        # add importance score
        if learned_prob is not None:
            lp = learned_prob.detach().view(b, n)
            lp = lp * self.alpha
            final_prob = final_prob + lp

        # reserve points on a coarse grid
        if self.reserve_on:
            reserve_mask = ((pos % (stride*2)).sum(-1) == 0).float()  # b x n
            final_prob = final_prob + (reserve_mask*(-100))
            sample_num = keep_num - reserve_num
        else:
            sample_num = keep_num

        # select topk tokens as merging centers
        sample_idx = final_prob.topk(sample_num, dim=1, sorted=False)[1]  # b x n_

        if self.reserve_on:
            reserve_idx = reserve_mask.nonzero(as_tuple=True)[1].reshape(b, reserve_num)
            idx = torch.cat([sample_idx, reserve_idx], dim=-1).unsqueeze(2)  # b x n_ x 1
        else:
            idx = sample_idx.unsqueeze(2)

        n = idx.shape[1]
        assert n == keep_num, "n not equal to keep num!"

        # gather pos, nbhd, nbhd position embedding, nbhd importance scores for topk merging locations
        pos = pos.gather(index=idx.expand(-1, -1, d), dim=1)  # b x n' x d

        nbhd_size = member_idx.shape[-1]
        member_idx = member_idx.gather(index=idx.expand(-1, -1, nbhd_size), dim=1)  # b x n' x m
        pe_idx = pe_idx.gather(index=idx.expand(-1, -1, nbhd_size), dim=1)  # b x n' x m
        if cluster_mask is not None:
            cluster_mask = cluster_mask.gather(index=idx.expand(-1, -1, nbhd_size), dim=1)  # b x n' x m
        if learned_prob is not None:
            lp = learned_prob.gather(index=member_idx.view(b, -1, 1), dim=1).reshape(b, n, nbhd_size, 1)  # b x n x m x 1

        # pointconv weights
        global pre_table
        if not pre_table.is_cuda:
            pre_table = pre_table.to(pe_idx.device)
        weights_table = self.weight_net(pre_table)  # 111 x 111 x ic

        weight_shape = pe_idx.shape
        inner_ch = weights_table.shape[-1]
        weights = weights_table.gather(index=pe_idx.view(-1, 1).expand(-1, inner_ch), dim=0).reshape(*(weight_shape), inner_ch)

        if learned_prob is not None:
            if cluster_mask is not None:
                lp = lp * cluster_mask.unsqueeze(3)
            weights = weights * lp
        else:
            if cluster_mask is not None:
                weights = weights * cluster_mask.unsqueeze(3)

        # merge features
        feat = CLUSTENWFFunction.apply(weights, feat, member_idx.view(b, n, -1)).reshape(b, n, -1)  # b x n x ic*c
        feat = self.norm(feat)
        feat = self.linear(feat)  # b x n x 2c

        return pos, feat

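  In ClusterMerging.forward, a grid_prob is computed first; it is the keep score derived from the token's position prior. BasicLayer also passes in learned_prob, the per-token importance score it obtains from prob_net followed by a sigmoid. Scaling learned_prob by alpha and adding it to grid_prob gives each token's final keep score, final_prob. topk on final_prob then selects the tokens to keep, and their indices are stored in sample_idx.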

sample_idx = final_prob.topk(sample_num, dim=1, sorted=False)[1]
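
The scoring and selection can be traced on a small grid with plain Python. This sketch follows the stride == 2 branch only; the function name select_merge_centers, the 8 x 8 grid and the hand-set importance score are assumptions for illustration, and a sorted list stands in for topk.

```python
import math


def select_merge_centers(pos, learned_prob, h, w, stride, alpha=4.0, ds_rate=0.25):
    """Pick tokens to keep: coarse-grid reserve tokens plus the top-scoring remaining tokens."""
    n = len(pos)
    keep_num = int(n * ds_rate)
    reserve_num = math.ceil(h / (stride * 2)) * math.ceil(w / (stride * 2))
    scores, reserve_idx = [], []
    for i, (x, y) in enumerate(pos):
        # Grid prior: 1 for tokens lying on the stride-aligned grid, else 0.
        grid_prob = 1.0 if (x % stride) + (y % stride) == 0 else 0.0
        score = grid_prob + alpha * learned_prob[i]
        if (x % (stride * 2)) + (y % (stride * 2)) == 0:
            # Reserve tokens are always kept, so push them out of the top-k race.
            reserve_idx.append(i)
            score += -100
        scores.append(score)
    assert len(reserve_idx) == reserve_num
    sample_num = keep_num - reserve_num
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    return ranked[:sample_num] + reserve_idx


pos = [(x, y) for y in range(8) for x in range(8)]  # 64 tokens on an 8 x 8 grid
learned_prob = [0.0] * 64
learned_prob[9] = 0.9                               # token at (1, 1) is judged important
selected = select_merge_centers(pos, learned_prob, h=8, w=8, stride=2)
print(sorted(selected))
```

In this toy run the off-grid token at (1, 1) is kept because its scaled importance outweighs the grid prior of 1, while the four coarse-grid reserve tokens are kept regardless of their scores.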

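  To keep a fixed set of anchor tokens at every scale, a reserve_mask is also folded into the final sampling: tokens on a coarse grid are excluded from the top-k selection and then appended to the kept set directly.

  Using the selected indices, the layer gathers the kept tokens' positions, neighbor indices, neighbor position-embedding indices, and related information. The gathered neighbors are then merged into each kept token with learned weights to produce the new feature representation. CLUSTENWFFunction is used here to implement the indexed access and feature merging efficiently. (CLUSTENWFFunction is defined in /clusten/src/clusten.py.)

  In summary, ClusterMerging builds progressively sparser feature maps in three steps: computing keep scores, selecting the tokens to keep, and merging their neighborhoods. This is what implements adaptive sampling.

(Notes: 1. The decoder head is not covered here. 2. On reproduction: this mainly follows the authors' approach, extracts the key parts, and improves on local attention; for AutoFocusFormer itself, see the official README.md.)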
