quaterion.dataset.similarity_data_loader 模块¶
- class GroupSimilarityDataLoader(dataset: Dataset[SimilarityGroupSample], **kwargs)[source]¶
基类:
SimilarityDataLoader
[SimilarityGroupSample
]用于处理表示为
SimilarityGroupSample
数据的数据加载器。- classmethod collate_labels(batch: List[SimilarityGroupSample]) Dict[str, Tensor] [source]¶
标签的整理函数
将标签转换为张量,适用于直接传递给损失函数和指标估计器。
- 参数:
batch –
SimilarityGroupSample
列表- 返回:
整理后的标签 –
groups – 每个特征对象的组 ID
示例
>>> GroupSimilarityDataLoader.collate_labels( ... [ ... SimilarityGroupSample(obj="orange", group=0), ... SimilarityGroupSample(obj="lemon", group=0), ... SimilarityGroupSample(obj="apple", group=1) ... ] ... ) {'groups': tensor([0, 0, 1])}
- classmethod flatten_objects(batch: List[SimilarityGroupSample], hash_ids: List[int]) Tuple[List[Any], List[int]] [source]¶
从相似度样本中检索并枚举对象。
每个独立对象都应该用作编码器的输入。此外,将 hash_id 与每个特征关联起来,如果样本中存在多个特征 - 则根据输入生成新的唯一 ID。
- 参数:
batch – 相似度样本列表
hash_ids – 相似度样本的伪随机 ID
- 返回:
用于编码器整理的输入特征列表
与每个特征关联的 ID 列表
- batch_size: int | None¶
- drop_last: bool¶
- num_workers: int¶
- pin_memory: bool¶
- pin_memory_device: str¶
- prefetch_factor: int | None¶
- timeout: float¶
- class PairsSimilarityDataLoader(dataset: Dataset[SimilarityPairSample], **kwargs)[source]¶
基类:
SimilarityDataLoader
[SimilarityPairSample
]用于处理表示为
SimilarityPairSample
数据的数据加载器。- classmethod collate_labels(batch: List[SimilarityPairSample]) Dict[str, Tensor] [source]¶
SimilarityPairSample
标签的整理函数将标签转换为张量,适用于直接传递给损失函数和指标估计器。
- 参数:
batch –
SimilarityPairSample
列表- 返回:
整理后的标签 –
labels - 每个输入对的得分张量
pairs - 与对应标签关联的特征的 ID 偏移量对
subgroups - 每个特征的子组 ID
示例
>>> labels_batch = PairsSimilarityDataLoader.collate_labels( ... [ ... SimilarityPairSample( ... obj_a="1st_pair_1st_obj", obj_b="1st_pair_2nd_obj", score=1.0, subgroup=0 ... ), ... SimilarityPairSample( ... obj_a="2nd_pair_1st_obj", obj_b="2nd_pair_2nd_obj", score=0.0, subgroup=1 ... ), ... ] ... ) >>> labels_batch['labels'] tensor([1., 0.]) >>> labels_batch['pairs'] tensor([[0, 2], [1, 3]]) >>> labels_batch['subgroups'] tensor([0., 1., 0., 1.])
- classmethod flatten_objects(batch: List[SimilarityPairSample], hash_ids: List[int]) Tuple[List[Any], List[int]] [source]¶
从相似度样本中检索并枚举对象。
每个独立对象都应该用作编码器的输入。此外,将 hash_id 与每个特征关联起来,如果样本中存在多个特征 - 则根据输入生成新的唯一 ID。
- 参数:
batch – 相似度样本列表
hash_ids – 相似度样本的伪随机 ID
- 返回:
用于编码器整理的输入特征列表
与每个特征关联的 ID 列表
- batch_size: int | None¶
- drop_last: bool¶
- num_workers: int¶
- pin_memory: bool¶
- pin_memory_device: str¶
- prefetch_factor: int | None¶
- timeout: float¶
- class SimilarityDataLoader(dataset: Dataset, **kwargs)[source]¶
基类:
DataLoader
,Generic
[T_co
]DataLoader
的特殊版本,用于处理相似度样本。SimilarityDataLoader 将自动为调试目的分配一个虚拟 collate_fn,一旦数据加载器用于训练,它将被覆盖。
必需的 collate 函数应通过覆盖
get_collate_fn()
为每个编码器单独定义。- 参数:
dataset – 输出相似度样本的数据集
**kwargs – 直接传递给
__init__()
的参数
- classmethod collate_labels(batch: List[T_co]) Dict[str, Tensor] [source]¶
标签的整理函数
将标签转换为张量,适用于直接传递给损失函数和指标估计器。
- 参数:
batch – 相似度样本列表
- 返回:
整理后的标签
- classmethod flatten_objects(batch: List[T_co], hash_ids: List[int]) Tuple[List[Any], List[int]] [source]¶
从相似度样本中检索并枚举对象。
每个独立对象都应该用作编码器的输入。此外,将 hash_id 与每个特征关联起来,如果样本中存在多个特征 - 则根据输入生成新的唯一 ID。
- 参数:
batch – 相似度样本列表
hash_ids – 相似度样本的伪随机 ID
- 返回:
用于编码器整理的输入特征列表
与每个特征关联的 ID 列表
- classmethod pre_collate_fn(batch: List[T_co])[source]¶
在实际整理之前应用于批次数据的函数。
将批次数据分割为特征(预测的参数)和标签(目标)。然后,编码器特定的 collate_fn 将仅应用于特征列表。损失函数直接使用此函数提供的标签,无需任何额外转换。
- 参数:
batch – 相似度样本列表
- 返回:
特征的 ID
特征批次
标签批次
- set_label_cache_mode(mode: LabelCacheMode)[source]¶
管理标签缓存的工作方式
- set_salt(salt)[source]¶
为 IndexingDataset 分配一个新的盐值 (salt)。这有助于区分训练和验证数据集的缓存序列键。
- 参数:
salt – 用于索引生成的盐值 (salt)
- set_skip_read(skip: bool)[source]¶
禁用 IndexingDataset 中读取项。如果缓存已满且使用了序列键 - 则无需第二次读取数据集项。
- 参数:
skip – 如果为 True - 不读取项,仅读取索引
- batch_size: int | None¶
- drop_last: bool¶
- property full_cache_used¶
- num_workers: int¶
- property original_params: Dict[str, Any]¶
原始数据集的初始化参数。
- pin_memory: bool¶
- pin_memory_device: str¶
- prefetch_factor: int | None¶
- timeout: float¶