quaterion.loss.center_loss 模块¶
- class CenterLoss(embedding_size: int, num_groups: int, lambda_c: float | None = 0.5)[源代码]¶
基类:
GroupLoss
中心损失,如论文“A Discriminative Feature Learning Approach for Deep Face Recognition”(http://ydwen.github.io/papers/WenECCV16.pdf)中所定义。它旨在最小化类内差异,同时保持不同类别特征的可分离性。
- 参数:
embedding_size – 编码器的输出维度。
num_groups – 数据集中的组(类别)数量。
lambda_c – 一个控制中心损失贡献的正则化参数。
- forward(embeddings: Tensor, groups: LongTensor) Tensor [源代码]¶
计算中心损失值。
- 参数:
embeddings – 形状 (batch_size, vector_length) - 来自编码器的输出嵌入。
groups – 形状 (batch_size,) - 与嵌入关联的组(类别)ID。
- 返回:
Tensor – 损失值。
- training: bool¶