对 NT-Xent 损失的直观解释,并逐步解释操作和我们在 PyTorch 中的实现

先来看一个公式

l

i

,

j

=

log

exp

(

sin

(

z

i

,

z

j

)

/

τ

)

k

=

1

2

N

1

[

k

i

]

exp

(

sin

(

z

i

,

z

k

)

/

τ

)

\mathbb{l}_{i,j}=-\log\frac{\exp(\sin(\mathbf{z}_i,\mathbf{z}_j)/\tau)}{\sum_{k=1}^{2N}1_{[k\neq i]}\exp(\sin(\mathbf{z}_i,\mathbf{z}_k)/\tau)}

li,j​=−log∑k=12N​1[k=i]​exp(sin(zi​,zk​)/τ)exp(sin(zi​,zj​)/τ)​

NT-Xent 损失

    在较高层次上,对比学习模型的输入来自 N 个底层图像的 2N 个图像。N 个底层图像中的每一个都使用一组随机图像增强进行增强,以生成 2 个增强图像。这就是我们最终在输入模型的单个训练批次中获得 2N 个图像的方式。

PyTorch 中 NT-Xent 损失的实现

    网上看到的许多NT-Xent 丢失的实现都是从头开始实现所有操作。此外,他们中的一些人实现损失函数的效率很低,更喜欢使用for 循环而不是 GPU 并行性。相反,我们将使用不同的方法。我们将根据 PyTorch 已经提供的标准交叉熵损失来实现此损失。为此,我们需要以 cross_entropy 可以接受的格式处理预测和真实标签。下面让我们看看如何执行此操作。

预测张量:首先,我们需要创建一个 PyTorch 张量来表示对比学习模型的输出。假设我们的批量大小为 8 (一张图片进行两次变换,所以2N=8)。我们将输入变量称为“x”。然后对x进行升维操作,这里的具体操作可以看我的另一篇博客:PyTorch 中所有样本对的余弦相似度快速计算

x = torch.randn(8, 2)

xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)

    如上所述,我们需要忽略每个特征向量的自相似性得分,因为它对模型的学习没有贡献,并且当我们想要计算交叉熵损失时,它会成为不必要的麻烦。为此,我们将定义一个变量“eye”,它是一个矩阵,主对角线上的元素值为 1.0,其余元素值为 0.0。我们可以使用以下命令创建这样的矩阵。

eye = torch.eye(8)

eye = eye.bool()#将其转换为布尔矩阵,以便我们可以使用此掩码矩阵索引到“xcs”变量。

y = xcs.clone()#将张量“xcs”克隆到名为“y”的张量中,以便稍后可以引用“xcs”张量。

y[eye] = float("-inf")#沿所有对余弦相似度矩阵的主对角线的值设置为-inf,这样当我们计算每行的 softmax 时,该值将不会产生任何影响。e的负无穷次方为0

ground truth (target tensor):对于我们使用的示例(2N=8),真实标签的样子如下:

tensor([1,0,3,2,5,4,7,6])

    很难理解?这是因为张量“y”中的以下索引对包含正对。这里需要对F.cosine_similarity()函数做一定了解,他有两个重要参数,(input,target)一般用全连接层的输出做input,(注意:全连接层的输出形状为[batch_size,type_num]。含义是第i个样本为第j类的概率。),target表示对应的真实标签的下标索引,所以input[i][target[i]]表示第i个样本预测正确的概率。这里可以参考这篇文章:【pytorch】交叉熵损失函数 F.cross_entropy()

上面的target张量在计算过程中的作用可以用这张图表示:     表示取图中打勾的索引元素,打勾的代表彼此互为正样本,即:

(0, 1), (1, 0) (2, 3), (3, 2) (4, 5), (5, 4) (6, 7), (7, 6)

    为了创建上面的张量,我们可以使用以下 PyTorch 代码,它将ground truth标签存储在变量“target”中。

target = torch.arange(8)

target[0::2] += 1

target[1::2] -= 1

交叉熵损失:我们拥有计算损失所需的所有材料了!唯一要做的就是调用 PyTorch 中的 cross_entropy API。

loss = F.cross_entropy(y / temperature, target, reduction="mean")

整合以上代码:

def nt_xent_loss(x, temperature):

assert len(x.size()) == 2

# Cosine similarity

xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)

xcs[torch.eye(x.size(0)).bool()] = float("-inf")

# Ground truth labels

target = torch.arange(8)

target[0::2] += 1

target[1::2] -= 1

# Standard cross-entropy loss

return F.cross_entropy(xcs / temperature, target, reduction="mean")

以上文章主要翻译自:NT-Xent (Normalized Temperature-Scaled Cross-Entropy) Loss Explained and Implemented in PyTorch感兴趣的可以看看原文,绝对精彩!

好文链接

评论可见,请评论后查看内容,谢谢!!!
 您阅读本篇文章共花了: