quaterion.loss.online_contrastive_loss 模块

class OnlineContrastiveLoss(margin: float | None = 0.5, distance_metric_name: Distance = Distance.COSINE, mining: str | None = 'hard')[来源]

基类: GroupLoss

实现对比损失 (Contrastive Loss),定义见 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

ContrastiveLoss 不同,本类支持在线对挖掘(online pair mining),即它会在运行时(on-the-fly)构建正样本对和负样本对,因此您无需自行构建此类对。相反,它首先计算批次中的所有可能对,然后分别筛选出有效的正样本对和有效的负样本对。支持用于在线对挖掘的 batch-all 和 batch-hard 策略。

参数:

margin – 用于将负样本推开的裕量值 (margin value)。可选,默认为 0.5
  • distance_metric_name – 距离函数的名称,例如 Distance。可选,默认为 COSINE

  • mining (str, optional) – 对挖掘策略。取值可为 “all”“hard”。默认为 “hard”

  • forward(embeddings: Tensor, groups: LongTensor) Tensor[来源]

通过运行时(on-the-fly)构建对来计算对比损失。

embeddings – 形状: (batch_size, vector_length) - 嵌入向量批次

margin – 用于将负样本推开的裕量值 (margin value)。可选,默认为 0.5
  • groups – 形状 (batch_size,) 与 embeddings 相关的标签批次

  • 返回值:

torch.Tensor – 标量损失值。

get_config_dict()[来源]

用于保存和加载配置。

配置对象必须是 JSON 可序列化的。

Dict[str, Any] – 参数的 JSON 可序列化字典

torch.Tensor – 标量损失值。

training: bool


使用 Sphinx 构建,主题由 主题 提供,提供者为 Read the Docs

OnlineContrastiveLoss

GitHub

发现 Qdrant

相似性学习

探索如何利用相似性学习解决实际问题

学习相似性学习

找到面临类似问题的人,并获得解答