nip.scenario_base.data.TensorDictDataset
- class nip.scenario_base.data.TensorDictDataset(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, train: bool = True)
Base class for datasets based on tensordicts.
The dataset is stored as a memory-mapped tensordict. See https://pytorch.org/tensordict/saving.html
To subclass, implement the following members:
raw_dir (property): The path to the directory containing the raw data.
processed_dir (property): The path to the directory containing the processed data.
build_tensor_dict: Build the tensordict from the raw data.
_download (optional): Download the raw data.
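The subclassing contract above can be sketched with a small stand-in base class. This is only an illustration under stated assumptions: `SketchDatasetBase` and `ToyDataset` are hypothetical names, a plain `dict` stands in for a `TensorDict`, and the order in which the base class calls the hooks is assumed rather than taken from the nip source. The abstract method is named `build_tensor_dict`, following the signature listed under Methods.

```python
from abc import ABC, abstractmethod
from pathlib import Path


class SketchDatasetBase(ABC):
    """Hypothetical stand-in for TensorDictDataset (not the nip class)."""

    def __init__(self, train: bool = True):
        self.train = train
        # Assumption: the base class runs the optional download hook,
        # then builds the data from the raw files.
        self._download()
        self.data = self.build_tensor_dict()

    @property
    @abstractmethod
    def raw_dir(self) -> Path:
        """The path to the directory containing the raw data."""

    @property
    @abstractmethod
    def processed_dir(self) -> Path:
        """The path to the directory containing the processed data."""

    @abstractmethod
    def build_tensor_dict(self) -> dict:
        """Build the data from the raw files (a dict stands in for a TensorDict)."""

    def _download(self) -> None:
        """Optional hook; the default does nothing."""


class ToyDataset(SketchDatasetBase):
    """Hypothetical subclass that implements the required members."""

    @property
    def raw_dir(self) -> Path:
        return Path("data") / "toy" / "raw"

    @property
    def processed_dir(self) -> Path:
        # Keep the train and test splits in separate directories.
        split = "train" if self.train else "test"
        return Path("data") / "toy" / "processed" / split

    def build_tensor_dict(self) -> dict:
        # A real subclass would read files from raw_dir and return a TensorDict.
        size = 4 if self.train else 2
        return {"x": list(range(size)), "y": [i % 2 for i in range(size)]}


train_set = ToyDataset(train=True)
test_set = ToyDataset(train=False)
```

Because the required members are declared abstract in the stand-in, forgetting one of them fails at instantiation time rather than later during training.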
- Parameters:
hyper_params (HyperParameters) – The parameters for the experiment.
settings (ExperimentSettings) – The settings for the experiment.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
train (bool) – Whether to load the training or test set.
Methods Summary
__getitem__(index)
__getitems__(index)
__init__(hyper_params, settings, ...[, train])
__len__()
__repr__() – Return repr(self).
Download the raw data.
_get_pretrained_cache_dir(model_name) – Get the path to the directory with the cached pretrained embeddings.
_get_pretrained_metadata_path(model_name) – Get the path to the metadata file for the pretrained embeddings.
_get_pretrained_mmap_path(model_name) – Get the path to the memory-mapped tensor for the pretrained embeddings.
add_pretrained_embeddings(model_name, ...[, ...]) – Add pretrained embeddings to the dataset and cache them.
build_tensor_dict() – Build the tensordict from the raw data.
build_torch_dataset(**kwargs) – Build the base PyTorch dataset, from which the tensordict is constructed.
get_pretrained_embedding_dtype(model_name) – Get the dtype of the embeddings for a pretrained model.
get_pretrained_embedding_feature_shape(model_name) – Get the feature shape of the embeddings for a pretrained model.
load_pretrained_embeddings(model_name) – Load cached embeddings for a pretrained model.
Attributes
device – The device on which the dataset is stored.
instance_keys – The keys specifying the input instance.
keys – The keys (field names) in the dataset.
pretrained_embeddings_dir – The path to the directory containing cached pretrained model embeddings.
pretrained_model_names – The names of the pretrained models for which we have computed embeddings.
processed_dir – The path to the directory containing the processed data.
raw_dir – The path to the directory containing the raw data.
Methods
- __getitem__(index: None | int | slice | str | Tensor | List[Any] | Tuple[Any, ...]) → TensorDict | Tensor
- __getitems__(index: None | int | slice | str | Tensor | List[Any] | Tuple[Any, ...]) → TensorDict | Tensor
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, train: bool = True)
- _get_pretrained_cache_dir(model_name: str) → Path
Get the path to the directory with the cached pretrained embeddings.
- Parameters:
model_name (str) – The name of the pretrained model.
- Returns:
cache_dir (Path) – The path to the cache directory.
- _get_pretrained_metadata_path(model_name: str) → Path
Get the path to the metadata file for the pretrained embeddings.
- Parameters:
model_name (str) – The name of the pretrained model.
- Returns:
metadata_path (Path) – The path to the metadata file.
- _get_pretrained_mmap_path(model_name: str) → Path
Get the path to the memory-mapped tensor for the pretrained embeddings.
- Parameters:
model_name (str) – The name of the pretrained model.
- Returns:
mmap_path (Path) – The path to the memory-mapped tensor.
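The three path helpers above can be pictured as deriving every location from one per-model cache directory. The sketch below is an assumption, not the nip implementation: the function names, the file names `metadata.json` and `embeddings.memmap`, and the handling of slashes in model names are all hypothetical.

```python
from pathlib import Path


def pretrained_cache_dir(embeddings_dir: Path, model_name: str) -> Path:
    """Hypothetical: one cache directory per pretrained model."""
    # Model names such as "org/model" contain a slash; flatten it so that
    # each model maps to a single directory level (an assumption).
    return embeddings_dir / model_name.replace("/", "_")


def pretrained_metadata_path(embeddings_dir: Path, model_name: str) -> Path:
    """Hypothetical: the metadata file lives inside the cache directory."""
    return pretrained_cache_dir(embeddings_dir, model_name) / "metadata.json"


def pretrained_mmap_path(embeddings_dir: Path, model_name: str) -> Path:
    """Hypothetical: the memory-mapped tensor lives next to the metadata."""
    return pretrained_cache_dir(embeddings_dir, model_name) / "embeddings.memmap"


root = Path("processed") / "pretrained_embeddings"
cache_dir = pretrained_cache_dir(root, "org/resnet18")
metadata_path = pretrained_metadata_path(root, "org/resnet18")
mmap_path = pretrained_mmap_path(root, "org/resnet18")
```

Deriving both file paths from the same directory helper keeps the metadata and the tensor together, so removing one directory clears the cache for that model.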
- add_pretrained_embeddings(model_name: str, full_embeddings: Tensor, overwrite_cache: bool = False)
Add pretrained embeddings to the dataset and cache them.
- abstract build_tensor_dict() → TensorDict
Build the tensordict from the raw data.
- build_torch_dataset(**kwargs) → Dataset
Build the base PyTorch dataset, from which the tensordict is constructed.
Implementing this method is optional, but it is required when using pretrained models, which need direct access to the raw dataset.
- Parameters:
**kwargs – Additional keyword arguments to pass to the dataset class.
- get_pretrained_embedding_dtype(model_name: str) → dtype
Get the dtype of the embeddings for a pretrained model.
- Parameters:
model_name (str) – The name of the pretrained model.
- Returns:
dtype (torch.dtype) – The dtype of the embeddings.
- get_pretrained_embedding_feature_shape(model_name: str) → Size
Get the feature shape of the embeddings for a pretrained model.
The feature shape is the tuple of dimensions of the embeddings excluding the batch dimension.
- Parameters:
model_name (str) – The name of the pretrained model.
- Returns:
shape (torch.Size) – The feature shape of the embeddings, excluding the batch dimension.
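To make "feature shape" concrete, the sketch below drops the leading batch dimension from a full embedding shape. It uses plain tuples instead of `torch.Size`, and the helper name `feature_shape` is hypothetical.

```python
from typing import Tuple


def feature_shape(embedding_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return the embedding shape without its leading batch dimension."""
    if len(embedding_shape) == 0:
        raise ValueError("expected at least a batch dimension")
    return tuple(embedding_shape[1:])


# 1000 items, each embedded as a 512 x 7 x 7 feature map.
conv_features = feature_shape((1000, 512, 7, 7))  # (512, 7, 7)

# 1000 items, each embedded as a flat 768-dimensional vector.
flat_features = feature_shape((1000, 768))  # (768,)
```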
- load_pretrained_embeddings(model_name: str)
Load cached embeddings for a pretrained model.
- Parameters:
model_name (str) – The name of the pretrained model.
- Raises:
CachedPretrainedEmbeddingsNotFound – If the cached embeddings are not found.
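One way a caller might combine `load_pretrained_embeddings` with `add_pretrained_embeddings` is to try the cache first and compute the embeddings only on a miss. The sketch below uses stand-ins throughout: `SketchEmbeddingStore` is a hypothetical in-memory replacement for the dataset, the exception class is redefined locally, `ensure_embeddings` is an illustrative helper, and the behaviour of `overwrite_cache` is assumed.

```python
class CachedPretrainedEmbeddingsNotFound(Exception):
    """Local stand-in for the nip exception of the same name."""


class SketchEmbeddingStore:
    """Hypothetical in-memory stand-in for the dataset's embedding cache."""

    def __init__(self):
        self._cache = {}
        self.loaded_models = []

    def add_pretrained_embeddings(self, model_name, full_embeddings, overwrite_cache=False):
        # Assumption: an existing cache entry is kept unless overwrite_cache is set.
        if model_name in self._cache and not overwrite_cache:
            return
        self._cache[model_name] = full_embeddings

    def load_pretrained_embeddings(self, model_name):
        if model_name not in self._cache:
            raise CachedPretrainedEmbeddingsNotFound(model_name)
        self.loaded_models.append(model_name)


def ensure_embeddings(dataset, model_name, compute_fn):
    """Load cached embeddings, computing and caching them first on a miss.

    Returns True if the embeddings had to be computed.
    """
    try:
        dataset.load_pretrained_embeddings(model_name)
        return False
    except CachedPretrainedEmbeddingsNotFound:
        dataset.add_pretrained_embeddings(model_name, compute_fn())
        dataset.load_pretrained_embeddings(model_name)
        return True


compute_calls = []


def compute_toy_embeddings():
    # Placeholder for an expensive forward pass through a pretrained model.
    compute_calls.append(1)
    return [[0.1, 0.2], [0.3, 0.4]]


store = SketchEmbeddingStore()
first = ensure_embeddings(store, "toy-model", compute_toy_embeddings)
second = ensure_embeddings(store, "toy-model", compute_toy_embeddings)
```

In this sketch the second call finds the cached entry, so the expensive computation runs only once.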