快捷方式

quaterion.loss.triplet_loss 模块

class TripletLoss(margin: float | None = 0.5, distance_metric_name: Distance | None =Distance.COSINE, mining: str | None = 'hard', soft: bool | None = False)[source]

基类:GroupLoss

实现了 Triplet Loss,定义见 https://arxiv.org/abs/1503.03832

支持 batch-all、batch-hard 和 batch-semihard 在线三元组挖掘策略。

参数
  • margin – 用于推开负例的边距值。

  • distance_metric_name – 距离函数的名称,例如 Distance

  • mining – 三元组挖掘策略。可选值包括 “all”, “hard”, “semi_hard”

  • soft – 如果为 True,则使用 Hard Triplet Loss 的软边距变体。在所有其他情况下忽略。

forward(embeddings: Tensor, groups: LongTensor) Tensor[source]

使用指定的 embedding 和标签计算 Triplet Loss。

参数
  • embeddings – shape: (batch_size, vector_length) - 一批 embedding。

  • groups – shape: (batch_size,) - 与 embeddings 关联的一批标签

返回

torch.Tensor – 标量损失值。

get_config_dict()[source]

用于保存和加载目的的配置。

配置对象必须是 JSON 可序列化的。

返回

Dict[str, Any] – JSON 可序列化的参数字典

xbm_loss(embeddings: Tensor, groups: LongTensor, memory_embeddings: Tensor, memory_groups: LongTensor) Tensor[source]

为该损失实现 XBM 损失计算。

参数
  • embeddings – shape: (batch_size, vector_length) - 编码器的输出 embedding。

  • groups – shape: (batch_size,) - 与 embedding 关联的组 id。

  • memory_embeddings – shape: (memory_buffer_size, vector_length) - 存储在环形缓冲区中的 embedding

  • memory_groups – shape: (memory_buffer_size,) - 与 memory_embeddings 关联的组 id

返回

Tensor – 零大小张量,XBM 损失值。

training: bool

Qdrant

了解有关 Qdrant 向量搜索项目和生态系统的更多信息

发现 Qdrant

相似度学习

探索使用相似度学习解决实际问题

学习相似度学习

领英 推特 Discord

寻找面临类似问题的人并获得问题的答案

加入社区