quaterion.loss.arcface_loss 模块¶
- class ArcFaceLoss(embedding_size: int, num_groups: int, scale: float = 64.0, margin: float = 0.5)[source]¶
Bases:
GroupLoss
累加角边距损失,定义于 https://arxiv.org/abs/1801.07698
- 参数:
embedding_size – 编码器的输出维度。
num_groups – 数据集中的组数。
scale – 用于使交叉熵生效的缩放值。
margin – 用于将组推开的边距值。
- forward(embeddings: Tensor, groups: LongTensor) Tensor [source]¶
计算损失值
- 参数:
embeddings – 形状: (batch_size, vector_length) - 来自编码器的输出嵌入。
groups – 形状: (batch_size,) - 与嵌入相关的组 ID。
- 返回:
Tensor – 损失值。
- training: bool¶