快捷方式

quaterion.loss.extras.pytorch_metric_learning_wrapper 模块

PytorchMetricLearningWrapper(loss: BaseMetricLossFunction, miner: BaseMiner | None = None)[source]

继承: GroupLoss

提供一个简单的封装器,以便能够使用来自 pytorch-metric-learning 的损失函数和样本挖掘器。

您需要自己创建损失函数(以及可选的样本挖掘器)实例,并将这些实例传递给此封装器的构造函数。

注意

这是一个实验性功能,可能会发生更改、废弃或移除。

注意

请参阅下方关于此封装器的快速使用示例,但请参考 pytorch-metric-learning 的文档以了解更多关于各自的 损失函数样本挖掘器 的信息。

参数:

示例

class MyTrainableModel(quaterion.TrainableModel):
    ...
    def configure_loss(self):
        loss = pytorch_metric_learning.losses.TripletMarginLoss()
        miner = pytorch_metric_learning.miner.MultiSimilarityMiner()
        return quaterion.loss.PytorchMetricLearningWrapper(loss, miner)
forward(embeddings, groups)[source]
参数:
  • embeddings – 形状: (batch_size, vector_length)

  • groups – 形状: (batch_size,) - 组,与 embeddings 关联

返回值:

Tensor – 零大小张量,损失值

training: bool

Qdrant

了解更多关于 Qdrant 向量搜索项目和生态系统的信息

探索 Qdrant

相似性学习

探索使用相似性学习解决实际问题

学习相似性学习

社区

找到处理类似问题的人,并获得您的问题的答案

加入社区