class OnlineContrastiveLoss(margin: float | None = 0.5, distance_metric_name: Distance = Distance.COSINE, mining: str | None = 'hard')[来源]¶
-
基类:
GroupLoss
实现对比损失 (Contrastive Loss),定义见 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
与
ContrastiveLoss
不同,本类支持在线对挖掘(online pair mining),即它会在运行时(on-the-fly)构建正样本对和负样本对,因此您无需自行构建此类对。相反,它首先计算批次中的所有可能对,然后分别筛选出有效的正样本对和有效的负样本对。支持用于在线对挖掘的 batch-all 和 batch-hard 策略。参数:
- margin – 用于将负样本推开的裕量值 (margin value)。可选,默认为 0.5。
- 通过运行时(on-the-fly)构建对来计算对比损失。
embeddings – 形状: (batch_size, vector_length) - 嵌入向量批次
- 用于保存和加载配置。
配置对象必须是 JSON 可序列化的。
Dict[str, Any] – 参数的 JSON 可序列化字典
- torch.Tensor – 标量损失值。
training: bool¶