快捷方式

quaterion.eval.accumulators.pair_accumulator 模块

class PairAccumulator[source]

基类: Accumulator

为基于对的任务累积嵌入、标签、对和子组。

跟踪当前大小以正确处理对。

reset()[source]

重置累加器状态

重置累加器状态和大小,累积的嵌入、标签、对和子组

update(embeddings: Tensor, labels: Tensor, pairs: LongTensor, subgroups: Tensor, device=None)[source]

更新累加器状态。

将提供的嵌入和组移动到适当的设备并添加到累积状态。

参数:
  • embeddings – 要累积的嵌入

  • labels – 用于区分相似和不相似对象的标签。

  • pairs – 用于确定一对对象的索引

  • subgroups – 子组编号,用于确定哪些样本可以视为负样本

  • device – 存储计算出的嵌入和组的设备。

property labels

将标签列表连接到 Tensor

有助于避免在累积期间为每个批次连接标签。相反,仅在调用时连接它。

返回:

torch.Tensor – 标签批次

property pairs: LongTensor

将对列表连接到 Tensor

有助于避免在累积期间为每个批次连接对。相反,仅在调用时连接它。

返回:

torch.Tensor – 对批次

property state: Dict[str, Tensor]

累积状态

返回:

Dict[str, torch.Tensor] - 字典,累积嵌入、标签、对、子组。

property subgroups

将子组列表连接到 Tensor

有助于避免在累积期间为每个批次连接子组。相反,仅在调用时连接它。

返回:

torch.Tensor – 子组批次

Qdrant

了解更多关于 Qdrant 向量搜索项目和生态系统

发现 Qdrant

相似性学习

探索如何使用相似性学习解决实际问题

学习相似性学习

社区

寻找遇到相似问题的人,并获得解答

加入社区