quaterion.eval.samplers.pair_sampler 模块¶
- class PairSampler(sample_size: int = -1, distinguish: bool = False, encode_batch_size: int = 16, device: device | str | None = None, log_progress: bool = True)[source]¶
基类:
BaseSampler
为基于对的任务执行 embeddings 和 targets 的选择。
Sampler 允许减少计算距离矩阵所需的时间和资源。它不计算形状为 (num_embeddings, num_embeddings) 的平方矩阵,而是选择 embeddings 并计算矩形形状的矩阵。
- 参数:
sample_size – int - 要选择的对象数量
distinguish – bool - 确定是比较所有对象之间的两两关系,还是仅比较 obj_a 与 obj_b。如果为 true,则仅比较 obj_a 与 obj_b。显著减少矩阵大小。
encode_batch_size – int - 编码期间使用的批次大小
- accumulate(model: SimilarityModel, dataset: Sized)[source]¶
编码对象并使用相应的原始标签积累 embeddings
- 参数:
model – 用于编码对象的模型
dataset – 带尺寸的对象,例如 list, tuple, torch.utils.data.Dataset 等,用于积累
- sample(dataset: Sized, metric: PairMetric, model: SimilarityModel) Tuple[Tensor, Tensor] [source]¶
为基于对的任务采样 embeddings 和 targets。
- 参数:
dataset – 带尺寸的对象,例如 list, tuple, torch.utils.data.Dataset 等,用于采样
metric – 用于计算最终标签表示的 PairMetric 实例
model – 用于编码对象的模型
- 返回:
torch.Tensor, torch.Tensor – 指标标签和计算出的距离矩阵