快捷方式

quaterion.loss.multiple_negatives_ranking_loss 模块

class MultipleNegativesRankingLoss(scale: float = 20.0, distance_metric_name: Distance = Distance.COSINE, symmetric: bool = False)[source]

基类: PairwiseLoss

实现 Multiple Negatives Ranking Loss,如 https://arxiv.org/pdf/1705.00652.pdf 中所述

这个损失函数只适用于正样本对,例如,一个 anchor 和一个 positive。对于每一对,它使用批处理中其他对的正样本作为负样本,因此您无需担心指定负样本。它非常适合检索任务,例如问答检索、重复句子检索和跨模态检索。它接受 anchor 和 positive 嵌入对,以计算它们之间的相似度矩阵。然后,它最小化 softmax 归一化相似度得分的负对数似然。这优化了在给定 anchor 时检索正确正样本对的能力。

注意

此损失函数会忽略 scoresubgroup 值,假定 obj_aobj_b 构成一个正样本对,例如 label = 1

参数:
  • scale – 用于乘以相似度得分以使交叉熵起作用的缩放值。

  • distance_metric_name – 用于计算嵌入之间相似度的度量名称,例如 Distance。可选,默认为 COSINE。如果为 DOT_PRODUCT,则 scale 必须为 1

  • symmetric – 如果为 True,则损失是对称的,即它在给定 positive 时也会考虑检索到正确的 anchor。

forward(embeddings: Tensor, pairs: LongTensor, labels: Tensor, subgroups: Tensor, **kwargs) Tensor[source]

计算损失值。

参数:
  • embeddings – 嵌入批次,嵌入的前半部分是样本对中第一个对象的嵌入,后半部分是样本对中第二个对象的嵌入。

  • pairs – 样本对中对应对象的索引。

  • labels – 此损失函数会忽略此参数。标签将自动从 pairs 中生成。

  • subgroups – 此损失函数会忽略此参数。

  • **kwargs – 用于损失调用泛化的附加关键字参数

返回值:

Tensor – 标量损失值

get_config_dict() Dict[str, Any][source]

用于保存和加载目的的配置。

配置对象必须是 JSON 可序列化的。

返回值:

Dict[str, Any] – JSON 可序列化参数字典

training: bool

Qdrant

了解更多关于 Qdrant 向量搜索项目和生态系统的信息

探索 Qdrant

相似性学习

探索使用相似性学习解决实际问题

学习相似性学习

社区

找到面临类似问题的人并获得您问题的答案

加入社区