快捷方式

quaterion.loss.contrastive_loss 模块

class ContrastiveLoss(distance_metric_name: Distance = Distance.COSINE, margin: float = 0.5, size_average: bool = True)[source]

Bases: PairwiseLoss

对比损失。

输入预期为两个文本和一个标签,标签值为 0 或 1。如果标签 == 1,则两个嵌入之间的距离减小。如果标签 == 0,则嵌入之间的距离增加。

更多信息

http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

参数:
  • distance_metric_name – 函数名称,例如,Distance。可选,默认为 COSINE

  • margin – 负样本(标签 == 0)之间的距离应至少等于间隔值。

  • size_average – 按 mini-batch 的大小平均。

forward(embeddings: Tensor, pairs: LongTensor, labels: Tensor, subgroups: Tensor, **kwargs) Tensor[source]

计算损失值。

参数:
  • embeddings – 嵌入批次,嵌入的前半部分是配对中第一个对象的嵌入,后半部分是配对中第二个对象的嵌入。

  • pairs – 配对中对应对象的索引。

  • labels – 正样本和负样本的得分。

  • subgroups – 用于区分可作为负样本和不可作为负样本的子组

  • **kwargs – 损失调用泛化的附加关键字参数

返回:

Tensor – 平均或求和后的损失值

get_config_dict() Dict[str, Any][source]

用于保存和加载的配置。

配置对象必须是 JSON 可序列化的。

返回:

Dict[str, Any] – JSON 可序列化的参数字典

training: bool

Qdrant

了解更多关于 Qdrant 向量搜索项目和生态系统

探索 Qdrant

相似性学习

探索如何使用相似性学习解决实际问题

学习相似性学习

社区

找到与您遇到类似问题的人,并获得解答

加入社区