quaterion.loss.cos_face_loss 模块¶
- class CosFaceLoss(embedding_size: int, num_groups: int, margin: float | None = 0.35, scale: float | None = 64.0)[源代码]¶
基类:
GroupLoss
大间隔余弦损失,定义见 https://arxiv.org/pdf/1801.09414.pdf
- 参数:
embedding_size – 编码器的输出维度。
num_groups – 数据集中的分组数量。
scale – 用于使交叉熵生效的缩放值。
margin – 用于推开分组的间隔值。
- forward(embeddings: Tensor, groups: LongTensor) Tensor [源代码]¶
计算损失值 :param embeddings: shape: (batch_size, vector_length) - 来自
编码器的输出嵌入。
- 参数:
groups – shape: (batch_size,) - 与嵌入相关的分组 ID。
- 返回:
Tensor – 损失值。
- training: bool¶