快捷方式

quaterion.dataset.similarity_data_loader 模块

class GroupSimilarityDataLoader(dataset: Dataset[SimilarityGroupSample], **kwargs)[source]

基类: SimilarityDataLoader[SimilarityGroupSample]

用于处理表示为 SimilarityGroupSample 数据的数据加载器。

classmethod collate_labels(batch: List[SimilarityGroupSample]) Dict[str, Tensor][source]

标签的整理函数

将标签转换为张量,适用于直接传递给损失函数和指标估计器。

参数:

batchSimilarityGroupSample 列表

返回:

整理后的标签

  • groups – 每个特征对象的组 ID

示例

>>> GroupSimilarityDataLoader.collate_labels(
...     [
...         SimilarityGroupSample(obj="orange", group=0),
...         SimilarityGroupSample(obj="lemon", group=0),
...         SimilarityGroupSample(obj="apple", group=1)
...     ]
... )
{'groups': tensor([0, 0, 1])}
classmethod flatten_objects(batch: List[SimilarityGroupSample], hash_ids: List[int]) Tuple[List[Any], List[int]][source]

从相似度样本中检索并枚举对象。

每个独立对象都应该用作编码器的输入。此外,将 hash_id 与每个特征关联起来,如果样本中存在多个特征 - 则根据输入生成新的唯一 ID。

参数:
  • batch – 相似度样本列表

  • hash_ids – 相似度样本的伪随机 ID

返回:

  • 用于编码器整理的输入特征列表

  • 与每个特征关联的 ID 列表

batch_size: int | None
dataset: Dataset[T_co]
drop_last: bool
num_workers: int
pin_memory: bool
pin_memory_device: str
prefetch_factor: int | None
sampler: Sampler | Iterable
timeout: float
class PairsSimilarityDataLoader(dataset: Dataset[SimilarityPairSample], **kwargs)[source]

基类: SimilarityDataLoader[SimilarityPairSample]

用于处理表示为 SimilarityPairSample 数据的数据加载器。

classmethod collate_labels(batch: List[SimilarityPairSample]) Dict[str, Tensor][source]

SimilarityPairSample 标签的整理函数

将标签转换为张量,适用于直接传递给损失函数和指标估计器。

参数:

batchSimilarityPairSample 列表

返回:

整理后的标签

  • labels - 每个输入对的得分张量

  • pairs - 与对应标签关联的特征的 ID 偏移量对

  • subgroups - 每个特征的子组 ID

示例

>>> labels_batch = PairsSimilarityDataLoader.collate_labels(
...     [
...         SimilarityPairSample(
...             obj_a="1st_pair_1st_obj", obj_b="1st_pair_2nd_obj", score=1.0, subgroup=0
...         ),
...         SimilarityPairSample(
...             obj_a="2nd_pair_1st_obj", obj_b="2nd_pair_2nd_obj", score=0.0, subgroup=1
...         ),
...     ]
... )
>>> labels_batch['labels']
tensor([1., 0.])
>>> labels_batch['pairs']
tensor([[0, 2],
        [1, 3]])
>>> labels_batch['subgroups']
tensor([0., 1., 0., 1.])
classmethod flatten_objects(batch: List[SimilarityPairSample], hash_ids: List[int]) Tuple[List[Any], List[int]][source]

从相似度样本中检索并枚举对象。

每个独立对象都应该用作编码器的输入。此外,将 hash_id 与每个特征关联起来,如果样本中存在多个特征 - 则根据输入生成新的唯一 ID。

参数:
  • batch – 相似度样本列表

  • hash_ids – 相似度样本的伪随机 ID

返回:

  • 用于编码器整理的输入特征列表

  • 与每个特征关联的 ID 列表

batch_size: int | None
dataset: Dataset[T_co]
drop_last: bool
num_workers: int
pin_memory: bool
pin_memory_device: str
prefetch_factor: int | None
sampler: Sampler | Iterable
timeout: float
class SimilarityDataLoader(dataset: Dataset, **kwargs)[source]

基类: DataLoader, Generic[T_co]

DataLoader 的特殊版本,用于处理相似度样本。

SimilarityDataLoader 将自动为调试目的分配一个虚拟 collate_fn,一旦数据加载器用于训练,它将被覆盖。

必需的 collate 函数应通过覆盖 get_collate_fn() 为每个编码器单独定义。

参数:
  • dataset – 输出相似度样本的数据集

  • **kwargs – 直接传递给 __init__() 的参数

classmethod collate_labels(batch: List[T_co]) Dict[str, Tensor][source]

标签的整理函数

将标签转换为张量,适用于直接传递给损失函数和指标估计器。

参数:

batch – 相似度样本列表

返回:

整理后的标签

classmethod flatten_objects(batch: List[T_co], hash_ids: List[int]) Tuple[List[Any], List[int]][source]

从相似度样本中检索并枚举对象。

每个独立对象都应该用作编码器的输入。此外,将 hash_id 与每个特征关联起来,如果样本中存在多个特征 - 则根据输入生成新的唯一 ID。

参数:
  • batch – 相似度样本列表

  • hash_ids – 相似度样本的伪随机 ID

返回:

  • 用于编码器整理的输入特征列表

  • 与每个特征关联的 ID 列表

load_label_cache(path: str)[source]
classmethod pre_collate_fn(batch: List[T_co])[source]

在实际整理之前应用于批次数据的函数。

将批次数据分割为特征(预测的参数)和标签(目标)。然后,编码器特定的 collate_fn 将仅应用于特征列表。损失函数直接使用此函数提供的标签,无需任何额外转换。

参数:

batch – 相似度样本列表

返回:

  • 特征的 ID

  • 特征批次

  • 标签批次

save_label_cache(path: str)[source]
set_label_cache_mode(mode: LabelCacheMode)[source]

管理标签缓存的工作方式

set_salt(salt)[source]

为 IndexingDataset 分配一个新的盐值 (salt)。这有助于区分训练和验证数据集的缓存序列键。

参数:

salt – 用于索引生成的盐值 (salt)

set_skip_read(skip: bool)[source]

禁用 IndexingDataset 中读取项。如果缓存已满且使用了序列键 - 则无需第二次读取数据集项。

参数:

skip – 如果为 True - 不读取项,仅读取索引

batch_size: int | None
dataset: Dataset[T_co]
drop_last: bool
property full_cache_used
num_workers: int
property original_params: Dict[str, Any]

原始数据集的初始化参数。

pin_memory: bool
pin_memory_device: str
prefetch_factor: int | None
sampler: Sampler | Iterable
timeout: float

Qdrant

了解更多关于 Qdrant 向量搜索项目和生态系统

发现 Qdrant

相似度学习

探索使用相似度学习解决实际问题

学习相似度学习

社区

寻找遇到类似问题的人并获得问题解答

加入社区