quaterion.loss.contrastive_loss 模块¶
- class ContrastiveLoss(distance_metric_name: Distance = Distance.COSINE, margin: float = 0.5, size_average: bool = True)[source]¶
Bases:
PairwiseLoss
对比损失。
输入预期为两个文本和一个标签,标签值为 0 或 1。如果标签 == 1,则两个嵌入之间的距离减小。如果标签 == 0,则嵌入之间的距离增加。
- 参数:
- forward(embeddings: Tensor, pairs: LongTensor, labels: Tensor, subgroups: Tensor, **kwargs) Tensor [source]¶
计算损失值。
- 参数:
embeddings – 嵌入批次,嵌入的前半部分是配对中第一个对象的嵌入,后半部分是配对中第二个对象的嵌入。
pairs – 配对中对应对象的索引。
labels – 正样本和负样本的得分。
subgroups – 用于区分可作为负样本和不可作为负样本的子组
**kwargs – 损失调用泛化的附加关键字参数
- 返回:
Tensor – 平均或求和后的损失值
- get_config_dict() Dict[str, Any] [source]¶
用于保存和加载的配置。
配置对象必须是 JSON 可序列化的。
- 返回:
Dict[str, Any] – JSON 可序列化的参数字典
- training: bool¶