quaterion.eval.base_metric 模块¶
- class BaseMetric(distance_metric_name: Distance = Distance.COSINE)[source]¶
基类:
object
评估指标的基类
提供距离矩阵计算的默认实现。
- 参数:
distance_metric_name – 用于计算距离或相似性矩阵的距离指标名称。可用名称可在
Distance
中找到。
- compute(*args, **kwargs) Tensor [source]¶
计算指标值
- 参数:
args –
metric. (*kwargs - 包含计算指标所需的嵌入和目标*) –
- 返回:
torch.Tensor - 计算出的指标值
- precompute(embeddings: Tensor, **targets) Tuple[Tensor, Tensor] [source]¶
准备计算所需的数据
基于组计算距离矩阵和最终标签。
- 参数:
embeddings – 用于计算指标值的嵌入
targets – 用于计算最终标签的对象
- 返回:
torch.Tensor, torch.Tensor - 标签和距离矩阵
- static prepare_labels(**targets) Tensor [source]¶
计算指标标签
- 参数:
**targets – 用于计算最终标签的对象。**targets 在 PairMetric 中包括 labels、pairs 和 subgroups,在 GroupMetric 中包括 groups。
- 返回:
*targets* – torch.Tensor - 在指标计算期间使用的标签
- raw_compute(distance_matrix: Tensor, labels: Tensor) Tensor [source]¶
在准备好的 distance_matrix 和 labels 上执行指标计算
此方法不进行任何数据和标签准备。假定 distance_matrix 已经计算,必要的更改(例如屏蔽元素到自身的距离)已经应用,并且相应的 labels 已经准备好。
- 参数:
distance_matrix – 准备好进行指标计算的距离矩阵
labels – 准备好进行指标计算且形状与 distance_matrix 相同的标签。对于 PairMetric,值取自 SimilarityPairSample.score;对于 GroupMetric,可能的值为 {0, 1}。
- 返回:
torch.Tensor - 计算出的指标值