• 文档 >
  • Quaterion 快速入门
快捷方式

Quaterion 快速入门

Quaterion 构建在 PyTorch Lightning 之上 - 这是一个用于高性能 AI 研究的框架。它负责构建 ML 模型训练循环中的所有任务

除了 PyTorch Lightning 的功能外,Quaterion 还提供了一个用于定义以下内容的支架

  • 可微调的相似度学习模型 - 编码器和头部层

  • 用于表示相似度信息的数据集和数据加载器

  • 用于相似度学习的损失函数

  • 用于评估模型性能的指标

你需要了解一些概念才能开始使用 Quaterion

相似度样本和数据加载器

与传统的分类或回归不同,相似度学习不使用特定的目标值。相反,它依赖于对象之间的相似度信息。

Quaterion 提供了两种主要方法来表示这种“相似度”信息。

相似度对

SimilarityPairSample - 是一个用于表示对象之间成对相似度的 dataclass。

例如,如果你想训练一个食物相似度模型

data = [
    SimilarityPairSample(obj_a="cheesecake", obj_b="muffins", score=1.0),
    SimilarityPairSample(obj_a="cheesecake", obj_b="macaroons", score=1.0),
    SimilarityPairSample(obj_a="cheesecake", obj_b="candies", score=1.0),
    SimilarityPairSample(obj_a="lemon", obj_b="lime", score=1.0),
    SimilarityPairSample(obj_a="lemon", obj_b="orange", score=1.0),
]

当然,你还需要有负样本 - 有几种策略可以做到这一点

  • 或者显式指定负样本

negative_data = [
    SimilarityPairSample(obj_a="cheesecake", obj_b="lemon", score=0.0),
    SimilarityPairSample(obj_a="orange", obj_b="macaroons", score=0.0),
    SimilarityPairSample(obj_a="lime", obj_b="candies", score=0.0)
]
  • 或者通过使用子组让 quaterion 假设所有其他样本对都是负样本

data = [
    SimilarityPairSample(obj_a="cheesecake", obj_b="muffins", score=1.0, subgroup=10),
    SimilarityPairSample(obj_a="cheesecake", obj_b="macaroons", score=1.0, subgroup=10),
    SimilarityPairSample(obj_a="cheesecake", obj_b="candies", score=1.0, subgroup=10),
    SimilarityPairSample(obj_a="lemon", obj_b="lime", score=1.0, subgroup=11),
    SimilarityPairSample(obj_a="lemon", obj_b="orange", score=1.0, subgroup=11),
]

Quaterion 将假定所有具有不同子组的样本都是负样本。

相似度组SimilarityGroupSample

在以下场景中可能很有用

  • 在同一对象的多个表示上训练相似度。例如,同一辆汽车的多张照片。

  • 将标签转换为相似度样本 - 任何分类数据集都可以通过假设同一类别的对象相似而不同类别的对象不相似来转换为相似度数据集。

要使用 SimilarityGroupSample,你需要为属于同一组的对象分配相同的 group_id

示例

data = [
    SimilarityGroupSample(obj="elon_musk_1.jpg", group=555),
    SimilarityGroupSample(obj="elon_musk_2.jpg", group=555),
    SimilarityGroupSample(obj="elon_musk_3.jpg", group=555),
    SimilarityGroupSample(obj="leonard_nimoy_1.jpg", group=209),
    SimilarityGroupSample(obj="leonard_nimoy_2.jpg", group=209),
]

数据加载器SimilarityDataLoader 是一个知道如何正确处理 SimilaritySamples 的数据加载器。针对 SimilarityPairSampleSimilarityGroupSample 分别有 PairsSimilarityDataLoaderGroupSimilarityDataLoader

将你的数据集包装到 SimilarityDataLoader 的实现之一中,使其与相似度学习兼容

 import json

 from torch.utils.data import Dataset

 from quaterion.dataset.similarity_data_loader import (
   GroupSimilarityDataLoader,
   SimilarityGroupSample,
 )

# Consumes data in format:
# {"description": "the thing I use for soup", "label": "spoon"}
class JsonDataset(Dataset):
    def __init__(self, path: str):
        super().__init__()
        with open(path, "r") as f:
            self.data = [json.loads(line) for line in f.readlines()]

    def __getitem__(self, index: int) -> SimilarityGroupSample:
        item = self.data[index]
        return SimilarityGroupSample(obj=item, group=hash(item["label"]))

    def __len__(self) -> int:
        return len(self.data)

train_dataloader = GroupSimilarityDataLoader(JsonDataset('./my_data.json'), batch_size=128)
val_dataloader = GroupSimilarityDataLoader(JsonDataset('./my_data_val.json'), batch_size=128)

相似度模型和编码器SimilarityModel - 是一个模型类,用于管理所有可训练层。

相似度模型充当一个编码器,它由其他编码器和一个结合编码器组件输出的头部层组成。

┌─────────────────────────────────────┐
│SimilarityModel                      │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │Encoder 1│ │Encoder 2│ │Encoder 3│ │
│ └────┬────┘ └────┬────┘ └────┬────┘ │
│      │           │           │      │
│      └────────┐  │  ┌────────┘      │
│               │  │  │               │
│           ┌───┴──┴──┴───┐           │
│           │   concat    │           │
│           └──────┬──────┘           │
│                  │                  │
│           ┌──────┴──────┐           │
│           │    Head     │           │
│           └─────────────┘           │
└─────────────────────────────────────┘

每个编码器接收原始对象数据作为输入,并生成一个嵌入 - 一个固定长度的张量。

将原始输入数据转换为适合神经网络的张量的规则在每个编码器的 collate_fn 函数中单独定义。

让我们定义我们的简单编码器

from os.path import join
from os import makedirs
from typing import Any, Dict, List, Union

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling

import torch.nn as nn
from torch import Tensor

from quaterion_models.heads import EncoderHead, SkipConnectionHead
from quaterion_models.encoders import Encoder
from quaterion_models.types import CollateFnType

class DescriptionEncoder(Encoder):
   def __init__(self, transformer: Transformer, pooling: Pooling):
       super().__init__()
       self.transformer = transformer
       self.pooling = pooling
       self.encoder = nn.Sequential(self.transformer, self.pooling)

   @property
   def trainable(self) -> bool:
       return False # Disable weights update for this encoder

   @property
   def embedding_size(self) -> int:
       return self.transformer.get_word_embedding_dimension()

   def forward(self, batch) -> Tensor:
       return self.encoder(batch)["sentence_embedding"]

   def collate_descriptions(self, batch: List[Any]) -> Tensor:
       descriptions = [record['description'] for record in batch]
       return self.transformer.tokenize(descriptions)

   def get_collate_fn(self) -> CollateFnType:
       return self.collate_descriptions

    @staticmethod
    def _pooling_path(path: str) -> str:
        return join(path, "pooling")

    @staticmethod
    def _transformer_path(path: str) -> str:
        return join(path, "transformer")

    def save(self, output_path: str):
        transformer_path = self._transformer_path(output_path)
        makedirs(transformer_path, exist_ok=True)

        pooling_path = self._pooling_path(output_path)
        makedirs(pooling_path, exist_ok=True)

        self.transformer.save(transformer_path)
        self.pooling.save(pooling_path)

   @classmethod
   def load(cls, input_path: str) -> Encoder:
       transformer = Transformer.load(join(input_path, 'transformer'))
       pooling = Pooling.load(join(input_path, 'pooling'))
       return cls(transformer=transformer, pooling=pooling)

编码器使用预训练层 transformerpooling 进行初始化。预训练组件的初始化在 Encoder 类外部定义。编码器设计用于推理服务的一部分,因此将训练相关的代码放在外面非常重要。

可训练模型TrainableModel。它包含定义 SimilarityModel 内容以及训练参数的方法。

from quaterion.loss import SimilarityLoss, TripletLoss
from quaterion import Quaterion, TrainableModel

from torch.optim import Adam

class Model(TrainableModel):
   def __init__(self, lr: float):
       self._lr = lr
       super().__init__()

   def configure_encoders(self) -> Union[Encoder, Dict[str, Encoder]]:
       pre_trained = SentenceTransformer("all-MiniLM-L6-v2")
       transformer, pooling = pre_trained[0], pre_trained[1]
       return DescriptionEncoder(transformer, pooling)

   def configure_head(self, input_embedding_size) -> EncoderHead:
       return SkipConnectionHead(input_embedding_size)

   def configure_loss(self) -> SimilarityLoss:
       return TripletLoss()

   def configure_optimizers(self):
       return Adam(self.model.parameters(), lr=self._lr)

TrainableModelpl.LightningModule 的后代,具有相同的功能。

训练Quaterion.fit 中进行。

model = Model(lr=0.01)

Quaterion.fit(
    trainable_model=model,
    trainer=None, # Use default trainer
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader
)

在最简单的情况下,我们可以使用默认的 trainer。你很可能需要更改训练参数,在这种情况下,我们建议覆盖默认的 trainer 参数

import pytorch_lightning as pl

trainer_kwargs = Quaterion.trainer_defaults()
trainer_kwargs['min_epochs'] = 10
trainer_kwargs['callbacks'].append(YourCustomCallback())
trainer = pl.Trainer(**trainer_kwargs)

Quaterion.fit(
    trainable_model=model,
    trainer=trainer, # Use custom trainer
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader
)

在 Pytorch Lightning 文档中阅读更多关于 pl.Trainer 的信息

训练完成后,我们可以保存 SimilarityModel 用于服务

model.save_servable("./my_similarity_model")

延伸阅读
  • 最小工作示例

  • 如需更深入的了解,请查看我们的端到端教程。

    框架高级功能教程

    Qdrant

    了解更多关于 Qdrant 向量搜索项目和生态系统的信息

    发现 Qdrant

    相似度学习

    探索使用相似度学习解决实际问题

    学习相似度学习

    社区

    找到处理类似问题的人并获得问题解答

    加入社区