quaterion.utils.utils 模块¶
- get_anchor_negative_mask(labels_a: Tensor, labels_b: Tensor | None = None) BoolTensor [source]¶
创建有效的锚点-负样本对的 2D 掩码。
- get_anchor_positive_mask(labels_a: Tensor, labels_b: Tensor | None = None) BoolTensor [source]¶
创建有效的锚点-正样本对的 2D 掩码。
- get_masked_maximum(dists: Tensor, mask: Tensor, dim: int = 1) Tensor [source]¶
用于半难负样本挖掘的实用函数。
- 参数:
dists – 平铺的距离矩阵。
mask – 平铺的掩码。
dim – 操作维度。
- 返回:
torch.Tensor - 掩码最大值。
- get_masked_minimum(dists, mask, dim=1)[source]¶
用于半难负样本挖掘的实用函数。
- 参数:
dists – 平铺的距离矩阵。
mask – 平铺的掩码。
dim – 操作维度。
- 返回:
torch.Tensor - 掩码最大值。
- get_triplet_mask(labels: Tensor) Tensor [source]¶
为 batch-all 策略创建有效三元组的 3D 掩码。
给定形状为 (batch_size,) 的一批标签,可以形成的三元组数量为 batch_size^3,即 batch_size 的立方,可以用形状为 (batch_size, batch_size, batch_size) 的张量表示。然而,一个三元组是有效的,如果:labels[i] == labels[j] 且 labels[i] != labels[k],并且 i、j 和 k 是不同的索引。此函数根据上述给定标准计算一个掩码,指示所有可能的三元组中有哪些是实际有效的。
- 参数:
labels (Tensor) – 与批次中的嵌入相关的标签。形状:(batch_size,)
- 返回:
torch.Tensor – 三元组掩码。形状:(batch_size, batch_size, batch_size)
- info_value_of_dtype(dtype: dtype) finfo | iinfo [source]¶
返回给定 PyTorch 数据类型的 finfo 或 iinfo 对象。
不允许使用 torch.bool。
- 参数:
dtype – 需要返回信息值的 dtype
- 返回:
Union[torch.finfo, torch.iinfo] – 关于给定数据类型的信息
- 抛出:
TypeError – 如果传入 torch.bool,则抛出 TypeError
- iter_by_batch(sequence: Sized | Iterable | Dataset, batch_size: int, log_progress: bool = True)[source]¶
按批次遍历可索引对象或可迭代对象
尝试按索引遍历,如果失败,则通过可迭代接口遍历。
- l2_norm(inputs: Tensor, dim: int = 0) Tensor [source]¶
对张量应用 L2 范数归一化
- 参数:
inputs – 输入张量。
dim – 操作维度。
- 返回:
torch.Tensor – L2 范数归一化的张量