quaterion.main module¶
- class Quaterion[source]¶
Bases:
object
微调入口点
包含用于启动实际训练和评估过程的方法。
- classmethod evaluate(evaluator: Evaluator, dataset: Sized | Iterable | Dataset, model: SimilarityModel) Dict[str, Tensor] [source]¶
计算数据集上的指标
- 参数:
evaluator – 包含要使用的指标配置以及如何获取样本的对象
dataset – 可度量大小的对象,如 list, tuple, torch.utils.data.Dataset 等,用于计算指标
model – SimilarityModel 实例,用于执行对象编码
- 返回:
Dict[str, torch.Tensor] - 计算出的指标字典。其中 key - 指标名称,value - 指标估计值
- classmethod fit(trainable_model: TrainableModel, trainer: Trainer | None, train_dataloader: SimilarityDataLoader, val_dataloader: SimilarityDataLoader | None = None, ckpt_path: str | None = None)[source]¶
处理训练流程
组合数据加载器,执行缓存和整个训练过程。
- 参数:
trainable_model – 要拟合的模型
trainer – pytorch_lightning.Trainer 实例,用于内部处理拟合流程。如果传入 None,将使用
Quaterion.trainer_defaults()
创建 trainer。默认参数旨在作为模型学习的快速入门,我们鼓励用户在默认参数不能达到满意结果时尝试不同的参数。train_dataloader – DataLoader 实例,用于在训练阶段检索样本
val_dataloader – 可选的 DataLoader 实例,用于在验证阶段检索样本
ckpt_path – 要从中恢复训练的检查点的路径/URL。如果路径下没有检查点文件,则抛出异常。如果从 epoch 中间检查点恢复,训练将从下一个 epoch 的开始处开始。
- static trainer_defaults(trainable_model: TrainableModel | None = None, train_dataloader: SimilarityDataLoader | None = None)[source]¶
pytorch_lightning.Trainer 的合理默认参数
此函数为 Trainer 生成一组参数,这些参数被认为是 Quaterion 大多数用例的“推荐”参数。Quaterion 相似度学习训练过程具有与常规深度学习模型训练不同的特性。如果您需要针对特殊任务的特殊行为,可以覆盖这些默认参数。
如果您需要调整 Trainer 参数,请考虑覆盖默认参数
示例
trainer_kwargs = Quaterion.trainer_defaults( trainable_model=model, train_dataloader=train_dataloader ) trainer_kwargs['logger'] = pl.loggers.WandbLogger( name="example_model", project="example_project", ) trainer_kwargs['callbacks'].append(YourCustomCallback()) trainer = pl.Trainer(**trainer_kwargs)
- 参数:
trainable_model – 如果提供,我们将尝试根据模型配置调整默认参数
train_dataloader – 如果提供,Trainer 参数将根据数据集进行调整
- 返回:
pytorch_lightning.Trainer 的 kwargs