quaterion.loss.multiple_negatives_ranking_loss 模块¶
- class MultipleNegativesRankingLoss(scale: float = 20.0, distance_metric_name: Distance = Distance.COSINE, symmetric: bool = False)[source]¶
基类:
PairwiseLoss
实现 Multiple Negatives Ranking Loss,如 https://arxiv.org/pdf/1705.00652.pdf 中所述
这个损失函数只适用于正样本对,例如,一个 anchor 和一个 positive。对于每一对,它使用批处理中其他对的正样本作为负样本,因此您无需担心指定负样本。它非常适合检索任务,例如问答检索、重复句子检索和跨模态检索。它接受 anchor 和 positive 嵌入对,以计算它们之间的相似度矩阵。然后,它最小化 softmax 归一化相似度得分的负对数似然。这优化了在给定 anchor 时检索正确正样本对的能力。
- 参数:
scale – 用于乘以相似度得分以使交叉熵起作用的缩放值。
distance_metric_name – 用于计算嵌入之间相似度的度量名称,例如
Distance
。可选,默认为COSINE
。如果为DOT_PRODUCT
,则 scale 必须为 1。symmetric – 如果为 True,则损失是对称的,即它在给定 positive 时也会考虑检索到正确的 anchor。
- forward(embeddings: Tensor, pairs: LongTensor, labels: Tensor, subgroups: Tensor, **kwargs) Tensor [source]¶
计算损失值。
- 参数:
embeddings – 嵌入批次,嵌入的前半部分是样本对中第一个对象的嵌入,后半部分是样本对中第二个对象的嵌入。
pairs – 样本对中对应对象的索引。
labels – 此损失函数会忽略此参数。标签将自动从 pairs 中生成。
subgroups – 此损失函数会忽略此参数。
**kwargs – 用于损失调用泛化的附加关键字参数
- 返回值:
Tensor – 标量损失值
- get_config_dict() Dict[str, Any] [source]¶
用于保存和加载目的的配置。
配置对象必须是 JSON 可序列化的。
- 返回值:
Dict[str, Any] – JSON 可序列化参数字典
- training: bool¶