快捷方式

quaterion.loss.center_loss 模块

class CenterLoss(embedding_size: int, num_groups: int, lambda_c: float | None = 0.5)[源代码]

基类: GroupLoss

中心损失,如论文“A Discriminative Feature Learning Approach for Deep Face Recognition”(http://ydwen.github.io/papers/WenECCV16.pdf)中所定义。它旨在最小化类内差异,同时保持不同类别特征的可分离性。

参数:
  • embedding_size – 编码器的输出维度。

  • num_groups – 数据集中的组(类别)数量。

  • lambda_c – 一个控制中心损失贡献的正则化参数。

forward(embeddings: Tensor, groups: LongTensor) Tensor[源代码]

计算中心损失值。

参数:
  • embeddings – 形状 (batch_size, vector_length) - 来自编码器的输出嵌入。

  • groups – 形状 (batch_size,) - 与嵌入关联的组(类别)ID。

返回:

Tensor – 损失值。

training: bool

Qdrant

了解更多关于 Qdrant 向量搜索项目和生态系统的信息

探索 Qdrant

相似性学习

探索使用相似性学习解决实际问题

学习相似性学习

社区

寻找处理类似问题的人并获得问题解答

加入社区