quaterion.loss.triplet_loss 模块¶
- class TripletLoss(margin: float | None = 0.5, distance_metric_name: Distance | None =Distance.COSINE, mining: str | None = 'hard', soft: bool | None = False)[source]¶
基类:
GroupLoss
实现了 Triplet Loss,定义见 https://arxiv.org/abs/1503.03832
支持 batch-all、batch-hard 和 batch-semihard 在线三元组挖掘策略。
- 参数:
margin – 用于推开负例的边距值。
distance_metric_name – 距离函数的名称,例如
Distance
。mining – 三元组挖掘策略。可选值包括 “all”, “hard”, “semi_hard”。
soft – 如果为 True,则使用 Hard Triplet Loss 的软边距变体。在所有其他情况下忽略。
- forward(embeddings: Tensor, groups: LongTensor) Tensor [source]¶
使用指定的 embedding 和标签计算 Triplet Loss。
- 参数:
embeddings – shape: (batch_size, vector_length) - 一批 embedding。
groups – shape: (batch_size,) - 与 embeddings 关联的一批标签
- 返回:
torch.Tensor – 标量损失值。
- xbm_loss(embeddings: Tensor, groups: LongTensor, memory_embeddings: Tensor, memory_groups: LongTensor) Tensor [source]¶
为该损失实现 XBM 损失计算。
- 参数:
embeddings – shape: (batch_size, vector_length) - 编码器的输出 embedding。
groups – shape: (batch_size,) - 与 embedding 关联的组 id。
memory_embeddings – shape: (memory_buffer_size, vector_length) - 存储在环形缓冲区中的 embedding
memory_groups – shape: (memory_buffer_size,) - 与 memory_embeddings 关联的组 id
- 返回:
Tensor – 零大小张量,XBM 损失值。
- training: bool¶