快捷方式

quaterion.eval.samplers.pair_sampler 模块

class PairSampler(sample_size: int = -1, distinguish: bool = False, encode_batch_size: int = 16, device: device | str | None = None, log_progress: bool = True)[source]

基类:BaseSampler

为基于对的任务执行 embeddings 和 targets 的选择。

Sampler 允许减少计算距离矩阵所需的时间和资源。它不计算形状为 (num_embeddings, num_embeddings) 的平方矩阵,而是选择 embeddings 并计算矩形形状的矩阵。

参数:
  • sample_size – int - 要选择的对象数量

  • distinguish – bool - 确定是比较所有对象之间的两两关系,还是仅比较 obj_aobj_b。如果为 true,则仅比较 obj_aobj_b。显著减少矩阵大小。

  • encode_batch_size – int - 编码期间使用的批次大小

accumulate(model: SimilarityModel, dataset: Sized)[source]

编码对象并使用相应的原始标签积累 embeddings

参数:
  • model – 用于编码对象的模型

  • dataset – 带尺寸的对象,例如 list, tuple, torch.utils.data.Dataset 等,用于积累

reset()[source]

重置积累的状态

sample(dataset: Sized, metric: PairMetric, model: SimilarityModel) Tuple[Tensor, Tensor][source]

为基于对的任务采样 embeddings 和 targets。

参数:
  • dataset – 带尺寸的对象,例如 list, tuple, torch.utils.data.Dataset 等,用于采样

  • metric – 用于计算最终标签表示的 PairMetric 实例

  • model – 用于编码对象的模型

返回:

torch.Tensor, torch.Tensor – 指标标签和计算出的距离矩阵

Qdrant

了解更多关于 Qdrant 向量搜索项目和生态系统的信息

探索 Qdrant

相似度学习

探索使用相似度学习解决实际问题

学习相似度学习

领英 推特 Discord

寻找遇到类似问题的人,并获得您问题的答案

加入社区