快捷方式

quaterion.main module

class Quaterion[source]

Bases: object

微调入口点

包含用于启动实际训练和评估过程的方法。

classmethod evaluate(evaluator: Evaluator, dataset: Sized | Iterable | Dataset, model: SimilarityModel) Dict[str, Tensor][source]

计算数据集上的指标

参数:
  • evaluator – 包含要使用的指标配置以及如何获取样本的对象

  • dataset – 可度量大小的对象,如 list, tuple, torch.utils.data.Dataset 等,用于计算指标

  • model – SimilarityModel 实例,用于执行对象编码

返回:

Dict[str, torch.Tensor] - 计算出的指标字典。其中 key - 指标名称,value - 指标估计值

classmethod fit(trainable_model: TrainableModel, trainer: Trainer | None, train_dataloader: SimilarityDataLoader, val_dataloader: SimilarityDataLoader | None = None, ckpt_path: str | None = None)[source]

处理训练流程

组合数据加载器,执行缓存和整个训练过程。

参数:
  • trainable_model – 要拟合的模型

  • trainerpytorch_lightning.Trainer 实例,用于内部处理拟合流程。如果传入 None,将使用 Quaterion.trainer_defaults() 创建 trainer。默认参数旨在作为模型学习的快速入门,我们鼓励用户在默认参数不能达到满意结果时尝试不同的参数。

  • train_dataloader – DataLoader 实例,用于在训练阶段检索样本

  • val_dataloader – 可选的 DataLoader 实例,用于在验证阶段检索样本

  • ckpt_path – 要从中恢复训练的检查点的路径/URL。如果路径下没有检查点文件,则抛出异常。如果从 epoch 中间检查点恢复,训练将从下一个 epoch 的开始处开始。

static trainer_defaults(trainable_model: TrainableModel | None = None, train_dataloader: SimilarityDataLoader | None = None)[source]

pytorch_lightning.Trainer 的合理默认参数

此函数为 Trainer 生成一组参数,这些参数被认为是 Quaterion 大多数用例的“推荐”参数。Quaterion 相似度学习训练过程具有与常规深度学习模型训练不同的特性。如果您需要针对特殊任务的特殊行为,可以覆盖这些默认参数。

如果您需要调整 Trainer 参数,请考虑覆盖默认参数

示例

trainer_kwargs = Quaterion.trainer_defaults(
    trainable_model=model,
    train_dataloader=train_dataloader
)
trainer_kwargs['logger'] = pl.loggers.WandbLogger(
    name="example_model",
    project="example_project",
)
trainer_kwargs['callbacks'].append(YourCustomCallback())
trainer = pl.Trainer(**trainer_kwargs)
参数:
  • trainable_model – 如果提供,我们将尝试根据模型配置调整默认参数

  • train_dataloader – 如果提供,Trainer 参数将根据数据集进行调整

返回:

pytorch_lightning.Trainer 的 kwargs

Qdrant

了解更多关于 Qdrant 向量搜索项目和生态系统的信息

探索 Qdrant

相似度学习

探索使用相似度学习解决实际问题

学习相似度学习

社区

找到与您遇到类似问题的人并获得问题解答

加入社区