quaterion.eval.samplers.base_sampler 模块¶
- class BaseSampler(sample_size=-1, device: device | str | None = None, log_progress: bool = True)[source]¶
基类:
object
采样部分嵌入和目标,以便对部分数据执行指标计算
采样器允许减少计算距离矩阵所需的时间和资源。它选择嵌入并计算矩形形状的矩阵,而不是计算形状为 (num_embeddings, num_embeddings) 的方形矩阵。
- 参数
sample_size: 要选择的对象的数量。
- sample(dataset: Sized, metric: BaseMetric, model: SimilarityModel) Tuple[Tensor, Tensor] [source]¶
采样对象和标签以计算指标
- 参数:
dataset – 带大小的对象,例如 list, tuple, torch.utils.data.Dataset 等,用于采样
metric – 用于计算最终标签表示的指标实例
model – 用于编码对象的模型
- 返回:
Tensor, Tensor – 指标标签和计算的距离矩阵