快捷方式

quaterion.loss.cos_face_loss 模块

class CosFaceLoss(embedding_size: int, num_groups: int, margin: float | None = 0.35, scale: float | None = 64.0)[源代码]

基类: GroupLoss

大间隔余弦损失,定义见 https://arxiv.org/pdf/1801.09414.pdf

参数:
  • embedding_size – 编码器的输出维度。

  • num_groups – 数据集中的分组数量。

  • scale – 用于使交叉熵生效的缩放值。

  • margin – 用于推开分组的间隔值。

forward(embeddings: Tensor, groups: LongTensor) Tensor[源代码]

计算损失值 :param embeddings: shape: (batch_size, vector_length) - 来自

编码器的输出嵌入。

参数:

groups – shape: (batch_size,) - 与嵌入相关的分组 ID。

返回:

Tensor – 损失值。

training: bool

Qdrant

了解更多关于 Qdrant 向量搜索项目和生态系统的信息

发现 Qdrant

相似性学习

探索使用相似性学习解决实际问题

学习相似性学习

社区

找到面临类似问题的人并获得解答

加入社区