quaterion.eval.accumulators.pair_accumulator 模块¶
- class PairAccumulator[source]¶
基类:
Accumulator
为基于对的任务累积嵌入、标签、对和子组。
跟踪当前大小以正确处理对。
- update(embeddings: Tensor, labels: Tensor, pairs: LongTensor, subgroups: Tensor, device=None)[source]¶
更新累加器状态。
将提供的嵌入和组移动到适当的设备并添加到累积状态。
- 参数:
embeddings – 要累积的嵌入
labels – 用于区分相似和不相似对象的标签。
pairs – 用于确定一对对象的索引
subgroups – 子组编号,用于确定哪些样本可以视为负样本
device – 存储计算出的嵌入和组的设备。
- property labels¶
将标签列表连接到 Tensor
有助于避免在累积期间为每个批次连接标签。相反,仅在调用时连接它。
- 返回:
torch.Tensor – 标签批次
- property pairs: LongTensor¶
将对列表连接到 Tensor
有助于避免在累积期间为每个批次连接对。相反,仅在调用时连接它。
- 返回:
torch.Tensor – 对批次
- property subgroups¶
将子组列表连接到 Tensor
有助于避免在累积期间为每个批次连接子组。相反,仅在调用时连接它。
- 返回:
torch.Tensor – 子组批次