quaterion.loss.extras.pytorch_metric_learning_wrapper 模块¶
- 类 PytorchMetricLearningWrapper(loss: BaseMetricLossFunction, miner: BaseMiner | None = None)[source]¶
继承:
GroupLoss
提供一个简单的封装器,以便能够使用来自 pytorch-metric-learning 的损失函数和样本挖掘器。
您需要自己创建损失函数(以及可选的样本挖掘器)实例,并将这些实例传递给此封装器的构造函数。
注意
这是一个实验性功能,可能会发生更改、废弃或移除。
- 参数:
loss – 继承自 pytorch_metric_learning.losses.BaseMetricLossFunction 的损失对象实例。
miner – 继承自 pytorch_metric_learning.miners.BaseMetric 的样本挖掘器对象实例。
示例
class MyTrainableModel(quaterion.TrainableModel): ... def configure_loss(self): loss = pytorch_metric_learning.losses.TripletMarginLoss() miner = pytorch_metric_learning.miner.MultiSimilarityMiner() return quaterion.loss.PytorchMetricLearningWrapper(loss, miner)
- forward(embeddings, groups)[source]¶
- 参数:
embeddings – 形状: (batch_size, vector_length)
groups – 形状: (batch_size,) - 组,与 embeddings 关联
- 返回值:
Tensor – 零大小张量,损失值
- training: bool¶