quaterion.loss.group_loss 模块¶
- 类 GroupLoss(distance_metric_name: Distance = Distance.COSINE)[source]¶
基类:
SimilarityLoss
组损失的基类。
- 参数:
distance_metric_name – 距离函数的名称,例如,
Distance
。
- forward(embeddings: Tensor, groups: LongTensor) Tensor [source]¶
- 参数:
embeddings – 形状: (batch_size, vector_length)
groups – 形状: (batch_size,) - 与 embeddings 关联的组
- 返回:
Tensor – 零大小张量,损失值
- xbm_loss(embeddings: Tensor, groups: LongTensor, memory_embeddings: Tensor, memory_groups: LongTensor) Tensor [source]¶
为此损失实现 XBM 损失计算。
- 参数:
embeddings – 形状: (batch_size, vector_length) - 来自编码器的输出嵌入。
groups – 形状: (batch_size,) - 与嵌入关联的组 ID。
memory_embeddings – 形状: (memory_buffer_size, vector_length) - 存储在环形缓冲区中的嵌入
memory_groups – 形状: (memory_buffer_size,) - 与 memory_embeddings 关联的组 ID
- 返回:
Tensor – 零大小张量,XBM 损失值。
- training: bool¶