nip.image_classification.data.ImageClassificationDataset
- class nip.image_classification.data.ImageClassificationDataset(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, train: bool = True)
A dataset for the image classification task.
Uses a torchvision dataset and removes all classes apart from two (determined by hyper_params.image_classification.selected_classes).
Shapes
The dataset is a TensorDict with the following keys:
“image” (dataset_size num_channels height width): The images in the dataset.
“x” (dataset_size max_message_rounds height width): The pixel features, which are all zeros.
“y” (dataset_size): The labels of the images.
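As a rough illustration of the layout above, the following sketch computes the expected shape of each key from the dimension names. The helper expected_shapes and the example sizes are hypothetical and not part of nip; the real dataset stores tensors in a TensorDict rather than plain tuples.

```python
def expected_shapes(
    dataset_size: int,
    num_channels: int,
    height: int,
    width: int,
    max_message_rounds: int,
) -> dict[str, tuple[int, ...]]:
    """Map each dataset key to the shape described in the Shapes section.

    Hypothetical helper for illustration only; not part of the nip API.
    """
    return {
        # One image per sample, channels first.
        "image": (dataset_size, num_channels, height, width),
        # Pixel features: one plane per message round (all zeros in the dataset).
        "x": (dataset_size, max_message_rounds, height, width),
        # One label per sample.
        "y": (dataset_size,),
    }


# Example sizes chosen arbitrarily for the illustration.
shapes = expected_shapes(
    dataset_size=100,
    num_channels=3,
    height=32,
    width=32,
    max_message_rounds=8,
)
```

Note that "x" shares the spatial dimensions of "image" but replaces the channel dimension with max_message_rounds.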
Methods Summary
__getitem__(index)
__getitems__(index)
__init__(hyper_params, settings, ...[, train])
__len__()
__repr__(): Return repr(self).
Download the raw data.
_get_pretrained_cache_dir(model_name): Get the path to the directory with the cached pretrained embeddings.
_get_pretrained_metadata_path(model_name): Get the path to the metadata file for the pretrained embeddings.
_get_pretrained_mmap_path(model_name): Get the path to the memory-mapped tensor for the pretrained embeddings.
add_pretrained_embeddings(model_name, ...[, ...]): Add pretrained embeddings to the dataset and cache them.
build_tensor_dict(): Build the dataset as a TensorDict from the raw data.
build_torch_dataset(*, transform): Build the TorchVision dataset.
get_pretrained_embedding_dtype(model_name): Get the dtype of the embeddings for a pretrained model.
get_pretrained_embedding_feature_shape(model_name): Get the feature shape of the embeddings for a pretrained model.
load_pretrained_embeddings(model_name): Load cached embeddings for a pretrained model.
Attributes
binarification_method
The method used to binarify the dataset.
binarification_seed
The seed to use for shuffling the dataset before merging.
device
The device on which the dataset is stored.
instance_keys
The keys specifying the input instance.
keys
The keys (field names) in the dataset.
pretrained_embeddings_dir
The path to the directory containing cached pretrained model embeddings.
pretrained_model_names
The names of the pretrained models for which we have computed embeddings.
processed_dir
The path to the directory containing the processed data.
raw_dir
The path to the directory containing the raw data.
selected_classes
The two classes selected for binary classification.
x_dtype
y_dtype
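The class description says that all classes apart from the two in selected_classes are removed. A minimal sketch of that kind of filtering on a plain list of labels is shown here. The function select_two_classes and the choice to relabel the two classes as 0 and 1 are assumptions made for the illustration; the library's own behaviour depends on binarification_method and may differ.

```python
from typing import Sequence


def select_two_classes(
    labels: Sequence[int], selected_classes: tuple[int, int]
) -> tuple[list[int], list[int]]:
    """Keep only samples from two classes and relabel them as 0 and 1.

    Hypothetical helper for illustration; not part of the nip API.
    Returns the indices of the kept samples and their binary labels.
    """
    first, second = selected_classes
    kept_indices: list[int] = []
    binary_labels: list[int] = []
    for index, label in enumerate(labels):
        if label == first:
            kept_indices.append(index)
            binary_labels.append(0)
        elif label == second:
            kept_indices.append(index)
            binary_labels.append(1)
        # Samples from every other class are dropped.
    return kept_indices, binary_labels


original_labels = [3, 5, 0, 3, 7, 5]
kept, binary = select_two_classes(original_labels, selected_classes=(3, 5))
```

The kept indices can then be used to subset the images so that images and labels stay aligned.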
Methods
- __getitem__(index: None | int | slice | str | Tensor | List[Any] | Tuple[Any, ...]) -> TensorDict | Tensor
- __getitems__(index: None | int | slice | str | Tensor | List[Any] | Tuple[Any, ...]) -> TensorDict | Tensor
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, train: bool = True)
- _get_pretrained_cache_dir(model_name: str) -> Path
Get the path to the directory with the cached pretrained embeddings.
- Parameters:
model_name (str) – The name of the pretrained model.
- Returns:
cache_dir (Path) – The path to the cache directory.
- _get_pretrained_metadata_path(model_name: str) -> Path
Get the path to the metadata file for the pretrained embeddings.
- Parameters:
model_name (str) – The name of the pretrained model.
- Returns:
metadata_path (Path) – The path to the metadata file.
- _get_pretrained_mmap_path(model_name: str) -> Path
Get the path to the memory-mapped tensor for the pretrained embeddings.
- Parameters:
model_name (str) – The name of the pretrained model.
- Returns:
mmap_path (Path) – The path to the memory-mapped tensor.
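The three path helpers above suggest a per-model cache directory that holds a metadata file and a memory-mapped tensor. One way such a layout could be composed with pathlib is sketched here. The function names and the file names metadata.json and embeddings.mmap are assumptions for the illustration, not the names used by nip.

```python
from pathlib import Path


def pretrained_cache_dir(pretrained_embeddings_dir: Path, model_name: str) -> Path:
    """Directory holding the cached embeddings for one model (assumed layout)."""
    return pretrained_embeddings_dir / model_name


def pretrained_metadata_path(pretrained_embeddings_dir: Path, model_name: str) -> Path:
    """Metadata file inside the model's cache directory (hypothetical file name)."""
    return pretrained_cache_dir(pretrained_embeddings_dir, model_name) / "metadata.json"


def pretrained_mmap_path(pretrained_embeddings_dir: Path, model_name: str) -> Path:
    """Memory-mapped tensor inside the model's cache directory (hypothetical file name)."""
    return pretrained_cache_dir(pretrained_embeddings_dir, model_name) / "embeddings.mmap"


# Example root directory and model name, both made up for the illustration.
root = Path("processed") / "pretrained_embeddings"
cache_dir = pretrained_cache_dir(root, "example_model")
metadata_path = pretrained_metadata_path(root, "example_model")
mmap_path = pretrained_mmap_path(root, "example_model")
```

Deriving both file paths from the cache directory keeps everything for one model in a single place, which makes it simple to clear a model's cache.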
- add_pretrained_embeddings(model_name: str, full_embeddings: Tensor, overwrite_cache: bool = False)
Add pretrained embeddings to the dataset and cache them.
- build_tensor_dict() -> TensorDict
Build the dataset as a TensorDict from the raw data.
- Returns:
dataset (TensorDict) – The dataset as a TensorDict, with the keys “image”, “x”, and “y”.
- build_torch_dataset(*, transform: Any | None) -> TorchVisionDatasetWrapper
Build the TorchVision dataset.
- Parameters:
transform (Optional[Any]) – The transform to apply to the images.
- Returns:
dataset (TorchVisionDatasetWrapper) – The TorchVision dataset.
- get_pretrained_embedding_dtype(model_name: str) -> dtype
Get the dtype of the embeddings for a pretrained model.
- Parameters:
model_name (str) – The name of the pretrained model.
- Returns:
dtype (torch.dtype) – The dtype of the embeddings.
- get_pretrained_embedding_feature_shape(model_name: str) -> Size
Get the feature shape of the embeddings for a pretrained model.
The feature shape is the tuple of dimensions of the embeddings excluding the batch dimension.
- Parameters:
model_name (str) – The name of the pretrained model.
- Returns:
shape (torch.Size) – The shape of the embeddings.
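To make the definition of the feature shape concrete, the sketch below drops the leading batch dimension from a full embedding shape. It uses plain tuples instead of torch.Size, and the function name feature_shape and the example dimensions are hypothetical.

```python
def feature_shape(embedding_shape: tuple[int, ...]) -> tuple[int, ...]:
    """Return the embedding shape without its leading batch dimension.

    Hypothetical helper for illustration; not part of the nip API.
    """
    return tuple(embedding_shape[1:])


# A made-up batch of 1000 embeddings, each of shape (512, 7, 7).
full_shape = (1000, 512, 7, 7)
per_sample_shape = feature_shape(full_shape)
```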
- load_pretrained_embeddings(model_name: str)
Load cached embeddings for a pretrained model.
- Parameters:
model_name (str) – The name of the pretrained model.
- Raises:
CachedPretrainedEmbeddingsNotFound – If the cached embeddings are not found.
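The load-or-fail behaviour described above can be sketched with the standard library alone. In this illustration the exception EmbeddingsCacheMissing, the function load_cached_metadata and the file name metadata.json are hypothetical stand-ins; the real method loads a memory-mapped tensor and raises CachedPretrainedEmbeddingsNotFound.

```python
import json
import tempfile
from pathlib import Path


class EmbeddingsCacheMissing(FileNotFoundError):
    """Stand-in for the library's cache-miss exception (hypothetical)."""


def load_cached_metadata(cache_dir: Path) -> dict:
    """Load cached metadata, or raise if nothing has been cached yet."""
    metadata_path = cache_dir / "metadata.json"  # hypothetical file name
    if not metadata_path.is_file():
        raise EmbeddingsCacheMissing(f"No cached embeddings in {cache_dir}")
    with metadata_path.open() as metadata_file:
        return json.load(metadata_file)


with tempfile.TemporaryDirectory() as temporary_root:
    cache_dir = Path(temporary_root) / "example_model"

    # First attempt: nothing is cached, so the caller handles the miss.
    try:
        load_cached_metadata(cache_dir)
        cache_was_found = True
    except EmbeddingsCacheMissing:
        cache_was_found = False

    # Simulate caching the embeddings, then load again.
    cache_dir.mkdir()
    (cache_dir / "metadata.json").write_text(
        json.dumps({"dtype": "float32", "feature_shape": [512]})
    )
    metadata = load_cached_metadata(cache_dir)
```

A caller would typically catch the cache-miss exception, compute the embeddings, and store them with add_pretrained_embeddings before loading again.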