quaterion.train.cache.cache_config 模块¶
- class CacheConfig(cache_type: ~quaterion.train.cache.cache_config.CacheType | None = CacheType.AUTO, mapping: ~typing.Dict[str, ~quaterion.train.cache.cache_config.CacheType] = <factory>, key_extractors: KeyExtractorType | ~typing.Dict[str, KeyExtractorType] = <factory>, batch_size: int | None = 32, num_workers: int | None = None, save_dir: str | None = None)[source]¶
基类:
object
确定缓存设置。
这个类应该被传递给
configure_caches()
- batch_size: int | None = 32¶
在缓存过程中用于 CacheDataLoader 的批大小。它不影响其他训练阶段。
- key_extractors: KeyExtractorType | Dict[str, KeyExtractorType]¶
编码器到键提取函数的映射,缓存不可哈希对象所需的。
- num_workers: int | None = None¶
在缓存过程中用于 CacheDataLoader 的 worker 数量。它不影响其他训练阶段。
- save_dir: str | None =None¶
如果提供此参数,缓存将保存到给定目录,并在多次启动之间重复使用。
- class CacheType(value)[source]¶
基类:
str
,Enum
用于缓存的可用 tensor 设备。
- AUTO = 'auto'¶
如果 CUDA 可用,则使用 CUDA,否则使用 CPU。
- CPU = 'cpu'¶
tensor 设备是 CPU。
- GPU = 'gpu'¶
tensor 设备是 GPU。
- NONE = 'none'¶
禁用缓存
- KeyExtractorType¶
从输入对象提取哈希值的函数类型。如果无法通过其他方式区分用于缓存的值,则需要此参数。
Callable
[[Any
],Hashable
]] 的别名