quaterion.dataset.train_collator 模块¶
- class TrainCollator(pre_collate_fn: Callable, encoder_collates: Dict[str, CollateFnType], meta_extractors: Dict[str, Callable[[List[Any]], List[dict]]])[source]¶
基类:
object
功能对象,聚合执行训练批次 collate 所需的所有信息。
注意
应可序列化以便在工作进程之间发送。
- 参数:
pre_collate_fn – 将原始批次分割为 ids、features 和 labels 的函数。ids 用于跟踪同一元素的重复使用。features 通常是编码器的输入。labels 通常用于区分正样本和负样本。
encoder_collates – 编码器名称与其 collate 函数的映射