nip.image_classification.data.ImageClassificationDataset

class nip.image_classification.data.ImageClassificationDataset(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, train: bool = True)

A dataset for the image classification task.

Uses a torchvision dataset and removes all classes except two, which are determined by hyper_params.image_classification.selected_classes.
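As a rough illustration of the filtering step, the sketch below keeps only the samples whose label is one of two selected classes and relabels them as 0 and 1. It uses plain Python lists in place of a torchvision dataset. The helper name `keep_two_classes` and the 0/1 relabelling convention are assumptions made for this example, not the library's implementation.

```python
def keep_two_classes(labels, selected_classes):
    """Return (kept_indices, binary_labels) for a two-class subset.

    Hypothetical helper: relabelling to 0/1 in the order of
    `selected_classes` is an assumption made for this sketch.
    """
    if len(selected_classes) != 2:
        raise ValueError("exactly two classes must be selected")
    class_to_binary = {selected_classes[0]: 0, selected_classes[1]: 1}
    kept_indices = []
    binary_labels = []
    for index, label in enumerate(labels):
        # Samples from any other class are dropped.
        if label in class_to_binary:
            kept_indices.append(index)
            binary_labels.append(class_to_binary[label])
    return kept_indices, binary_labels


labels = [3, 0, 5, 3, 1, 5]
kept, binary = keep_two_classes(labels, selected_classes=(3, 5))
print(kept)    # [0, 2, 3, 5]
print(binary)  # [0, 1, 0, 1]
```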

Shapes

The dataset is a TensorDict with the following keys:

  • “image” (dataset_size num_channels height width): The images in the dataset.

  • “x” (dataset_size max_message_rounds height width): The pixel features, which are all zeros.

  • “y” (dataset_size): The labels of the images.
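The shape conventions above can be written down as a small helper, which may be handy when checking a batch in a test. This is only a sketch: `expected_shapes` is a hypothetical function, the shapes are plain tuples rather than tensors, and the concrete sizes in the usage lines are made-up values.

```python
def expected_shapes(dataset_size, num_channels, height, width, max_message_rounds):
    """Map each documented key to the shape it is described as having."""
    return {
        # The images themselves.
        "image": (dataset_size, num_channels, height, width),
        # Pixel features: one plane per message round instead of per channel.
        "x": (dataset_size, max_message_rounds, height, width),
        # One label per image.
        "y": (dataset_size,),
    }


# Illustrative sizes only; they are not taken from the library.
shapes = expected_shapes(
    dataset_size=100, num_channels=3, height=32, width=32, max_message_rounds=8
)
print(shapes["image"])  # (100, 3, 32, 32)
print(shapes["x"])      # (100, 8, 32, 32)
print(shapes["y"])      # (100,)
```

Note that "x" shares the spatial dimensions of "image" but replaces the channel dimension with max_message_rounds.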

Methods Summary

__getitem__(index)

__getitems__(index)

__init__(hyper_params, settings, ...[, train])

__len__()

__repr__()

Return repr(self).

_download()

Download the raw data.

_get_pretrained_cache_dir(model_name)

Get the path to the directory with the cached pretrained embeddings.

_get_pretrained_metadata_path(model_name)

Get the path to the metadata file for the pretrained embeddings.

_get_pretrained_mmap_path(model_name)

Get the path to the memory-mapped tensor for the pretrained embeddings.

add_pretrained_embeddings(model_name, ...[, ...])

Add pretrained embeddings to the dataset and cache them.

build_tensor_dict()

Build the dataset as a TensorDict from the raw data.

build_torch_dataset(*, transform)

Build the TorchVision dataset.

get_pretrained_embedding_dtype(model_name)

Get the dtype of the embeddings for a pretrained model.

get_pretrained_embedding_feature_shape(...)

Get the feature shape of the embeddings for a pretrained model.

load_pretrained_embeddings(model_name)

Load cached embeddings for a pretrained model.

Attributes

binarification_method

The method used to binarify the dataset.

binarification_seed

The seed to use for shuffling the dataset before merging.

device

The device on which the dataset is stored.

instance_keys

The keys specifying the input instance.

keys

The keys (field names) in the dataset.

pretrained_embeddings_dir

The path to the directory containing cached pretrained model embeddings.

pretrained_model_names

The names of the pretrained models for which we have computed embeddings.

processed_dir

The path to the directory containing the processed data.

raw_dir

The path to the directory containing the raw data.

selected_classes

The two classes selected for binary classification.

x_dtype

y_dtype

Methods

__getitem__(index: None | int | slice | str | Tensor | List[Any] | Tuple[Any, ...]) → TensorDict | Tensor

__getitems__(index: None | int | slice | str | Tensor | List[Any] | Tuple[Any, ...]) → TensorDict | Tensor

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, train: bool = True)

__len__() → int

__repr__() → str

Return repr(self).

_download()

Download the raw data.

_get_pretrained_cache_dir(model_name: str) → Path

Get the path to the directory with the cached pretrained embeddings.

Parameters:

model_name (str) – The name of the pretrained model.

Returns:

cache_dir (Path) – The path to the cache directory.

_get_pretrained_metadata_path(model_name: str) → Path

Get the path to the metadata file for the pretrained embeddings.

Parameters:

model_name (str) – The name of the pretrained model.

Returns:

metadata_path (Path) – The path to the metadata file.

_get_pretrained_mmap_path(model_name: str) → Path

Get the path to the memory-mapped tensor for the pretrained embeddings.

Parameters:

model_name (str) – The name of the pretrained model.

Returns:

mmap_path (Path) – The path to the memory-mapped tensor.

add_pretrained_embeddings(model_name: str, full_embeddings: Tensor, overwrite_cache: bool = False)

Add pretrained embeddings to the dataset and cache them.

Parameters:
  • model_name (str) – The name of the pretrained model.

  • full_embeddings (Tensor) – The embeddings generated from the full original dataset, before any rearrangement or filtering.

  • overwrite_cache (bool, default=False) – Whether to overwrite the cached embeddings if they already exist.
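Because full_embeddings covers the full original dataset, the rows belonging to the samples that survive filtering and rearrangement have to be picked out of it. A minimal sketch of that idea follows, assuming the original index of each retained sample is known. The names `select_embeddings` and `kept_indices` are hypothetical, and nested lists stand in for tensors.

```python
def select_embeddings(full_embeddings, kept_indices):
    """Pick the embedding rows for the retained samples, in dataset order.

    Hypothetical helper: `kept_indices[i]` is assumed to be the position in
    the original dataset of the i-th sample in the filtered dataset.
    """
    for index in kept_indices:
        if not 0 <= index < len(full_embeddings):
            raise IndexError(f"index {index} is outside the full dataset")
    return [full_embeddings[index] for index in kept_indices]


# Four original samples, each with a 2-dimensional embedding.
full_embeddings = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
# The filtered dataset keeps original samples 2 and 0, in that order.
kept_indices = [2, 0]
subset = select_embeddings(full_embeddings, kept_indices)
print(subset)  # [[2.0, 2.1], [0.0, 0.1]]
```

Passing embeddings computed on an already filtered or reordered dataset would misalign rows and samples under this scheme, which is one reason to supply the full, original-order embeddings.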

build_tensor_dict() → TensorDict

Build the dataset as a TensorDict from the raw data.

Returns:

dataset (TensorDict) – The dataset as a TensorDict, with the keys “image”, “x”, and “y”.

build_torch_dataset(*, transform: Any | None) → TorchVisionDatasetWrapper

Build the TorchVision dataset.

Parameters:

transform (Optional[Any]) – The transform to apply to the images.

Returns:

dataset (TorchVisionDatasetWrapper) – The TorchVision dataset.

get_pretrained_embedding_dtype(model_name: str) → dtype

Get the dtype of the embeddings for a pretrained model.

Parameters:

model_name (str) – The name of the pretrained model.

Returns:

dtype (torch.dtype) – The dtype of the embeddings.

get_pretrained_embedding_feature_shape(model_name: str) → Size

Get the feature shape of the embeddings for a pretrained model.

The feature shape is the tuple of dimensions of the embeddings excluding the batch dimension.

Parameters:

model_name (str) – The name of the pretrained model.

Returns:

shape (torch.Size) – The feature shape of the embeddings, excluding the batch dimension.
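To make the definition of the feature shape concrete, the sketch below drops the leading batch dimension from a full embedding shape. It works on plain tuples, and `feature_shape` is a hypothetical helper. With a real tensor the equivalent expression would be `embeddings.shape[1:]`.

```python
def feature_shape(embedding_shape):
    """Drop the leading batch dimension from a full embedding shape."""
    if len(embedding_shape) < 1:
        raise ValueError("the shape must include a batch dimension")
    return tuple(embedding_shape[1:])


# 1000 samples with a flat 512-dimensional embedding each.
print(feature_shape((1000, 512)))       # (512,)
# 1000 samples with a 64 x 7 x 7 feature map each.
print(feature_shape((1000, 64, 7, 7)))  # (64, 7, 7)
```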

load_pretrained_embeddings(model_name: str)

Load cached embeddings for a pretrained model.

Parameters:

model_name (str) – The name of the pretrained model.

Raises:

CachedPretrainedEmbeddingsNotFound – If the cached embeddings are not found.
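One plausible way to combine load_pretrained_embeddings and add_pretrained_embeddings is a load-or-compute pattern: try the cache first, and compute and add the embeddings only on a miss. The sketch below shows that control flow with stand-ins so that it runs on its own. `StandInDataset`, `ensure_embeddings`, the locally defined exception class, and the model name "example_model" are all hypothetical, and the in-memory cache only imitates the behaviour described above.

```python
class CachedPretrainedEmbeddingsNotFound(Exception):
    """Local stand-in for the exception named in the documentation."""


class StandInDataset:
    """Hypothetical in-memory stand-in; not the real dataset class."""

    def __init__(self):
        self._cache = {}  # plays the role of the on-disk cache
        self.loaded = {}  # embeddings currently attached to the dataset

    def load_pretrained_embeddings(self, model_name):
        if model_name not in self._cache:
            raise CachedPretrainedEmbeddingsNotFound(model_name)
        self.loaded[model_name] = self._cache[model_name]

    def add_pretrained_embeddings(self, model_name, full_embeddings, overwrite_cache=False):
        # Assumed semantics: an existing cache entry is kept unless overwritten.
        if overwrite_cache or model_name not in self._cache:
            self._cache[model_name] = full_embeddings
        self.loaded[model_name] = self._cache[model_name]


def ensure_embeddings(dataset, model_name, compute_fn):
    """Load cached embeddings, computing and caching them on a miss."""
    try:
        dataset.load_pretrained_embeddings(model_name)
        return "loaded from cache"
    except CachedPretrainedEmbeddingsNotFound:
        dataset.add_pretrained_embeddings(model_name, compute_fn())
        return "computed and cached"


dataset = StandInDataset()
first = ensure_embeddings(dataset, "example_model", lambda: [[0.1, 0.2]])
second = ensure_embeddings(dataset, "example_model", lambda: [[9.9, 9.9]])
print(first)   # computed and cached
print(second)  # loaded from cache
```

The second call never invokes its compute function, so the expensive forward pass through the pretrained model would be skipped once a cache entry exists.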