quaterion.train.xbm.xbm_buffer 模块¶
- 类 XbmBuffer(config: XbmConfig, embedding_size: int)[源码]¶
基类:
object
一个用于保存最近 N 个嵌入和目标值的缓冲区实现。
灵感来源于 https://github.com/msight-tech/research-xbm/blob/master/ret_benchmark/modeling/xbm.py
- 参数:
config – 用于配置 XBM 设置的 Config 类。
embedding_size – 为该模型配置的 EncoderHead 的输出维度。
- queue(embeddings: Tensor, targets: LongTensor) None [源码]¶
将批次嵌入和目标值排入缓冲区。
- 参数:
embeddings – 批次中的输出嵌入。
targets – 批次中的目标值。
- 属性 is_full: bool¶