quaterion.eval.samplers.group_sampler 模块¶
- class GroupSampler(sample_size=-1, encode_batch_size=16, device: device | str | None = None, log_progress: bool = True)[源代码]¶
基类:
BaseSampler
为基于分组的任务执行嵌入和目标的选取。
- accumulate(model: SimilarityModel, dataset: Sized | Iterable | Dataset)[源代码]¶
编码对象并使用相应的原始标签累积嵌入
- 参数:
model – 用于编码模型的模型
dataset – 带大小的对象,例如列表、元组、torch.utils.data.Dataset 等,用于累积
- sample(dataset: Sized, metric: GroupMetric, model: SimilarityModel) Tuple[Tensor, Tensor] [源代码]¶
为基于分组的任务采样嵌入和目标。
- 参数:
dataset – 带大小的对象,例如列表、元组、torch.utils.data.Dataset 等,用于采样
metric – GroupMetric 实例,用于计算最终标签表示
model – 用于编码模型的模型
- 返回:
torch.Tensor, torch.Tensor – 度量标签和计算的距离矩阵