nip.image_classification.pretrained_models.Resnet18Cifar10PretrainedModel

class nip.image_classification.pretrained_models.Resnet18Cifar10PretrainedModel(hyper_params: HyperParameters, settings: ExperimentSettings)

A ResNet-18 model pretrained on CIFAR-10, used to generate image embeddings for datasets.

Methods Summary

__init__(hyper_params, settings)

_get_transform(model)

Get the transform to apply to images before passing them to the model.

generate_dataset_embeddings(datasets[, ...])

Load the model and generate embeddings for the datasets.

Attributes

allow_other_datasets

base_model_name

dataset

embedding_channels

embedding_downscale_factor

embedding_height

embedding_width

name

timm_uri
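The attribute names suggest that the embedding of one image is a feature map of shape (embedding_channels, embedding_height, embedding_width), with its spatial size equal to the input size divided by embedding_downscale_factor. This reading is an assumption that this page does not confirm. The sketch below uses a hypothetical helper, `embedding_shape`, and illustrative numbers only: 32×32 inputs (the CIFAR-10 image size), 512 channels (the width of the last stage of a standard ResNet-18), and a made-up downscale factor of 8. The real attribute values may differ.

```python
def embedding_shape(
    image_height: int, image_width: int, channels: int, downscale_factor: int
) -> tuple[int, int, int]:
    """Return (channels, height, width) of the feature map for one image.

    Hypothetical helper: assumes the spatial size is the input size
    integer-divided by the downscale factor.
    """
    if downscale_factor <= 0:
        raise ValueError("downscale_factor must be positive")
    return (
        channels,
        image_height // downscale_factor,
        image_width // downscale_factor,
    )


# Illustrative values only, not read from the class.
print(embedding_shape(32, 32, channels=512, downscale_factor=8))  # (512, 4, 4)
```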

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings)

_get_transform(model: ResNet) → Any | None

Get the transform to apply to images before passing them to the model.

Returns:

transform (torchvision transform or None) – The transform to apply to images before passing them to the model
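Because `_get_transform` may return None, code that consumes its result has to treat "no transform" as passing the image through unchanged. The sketch below shows that pattern under simplifying assumptions: images are plain nested lists rather than tensors, and `make_normalize` is a hypothetical stand-in for a per-channel normalization transform. The mean and standard deviation values are placeholders, not the ones this model uses.

```python
from typing import Callable, Optional

# An "image" here is a list of channels, each a flat list of pixel values.
Image = list[list[float]]
Transform = Callable[[Image], Image]


def make_normalize(mean: list[float], std: list[float]) -> Transform:
    """Hypothetical stand-in for a per-channel normalization transform."""

    def normalize(image: Image) -> Image:
        # Subtract the channel mean and divide by the channel std.
        return [
            [(pixel - m) / s for pixel in channel]
            for channel, m, s in zip(image, mean, std)
        ]

    return normalize


def preprocess(image: Image, transform: Optional[Transform]) -> Image:
    """Apply the transform if there is one; otherwise return the image as-is."""
    return image if transform is None else transform(image)


image = [[0.5, 1.0], [0.0, 0.5], [1.0, 1.0]]  # 3 channels, 2 pixels each
normalize = make_normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
print(preprocess(image, normalize))  # [[0.0, 1.0], [-1.0, 0.0], [1.0, 1.0]]
print(preprocess(image, None) is image)  # True
```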

generate_dataset_embeddings(datasets: dict[str, ImageClassificationDataset], delete_model: bool = True) → dict[str, Tensor]

Load the model and generate embeddings for the datasets.

Parameters:
  • datasets (dict[str, ImageClassificationDataset]) – The datasets to generate embeddings for

  • delete_model (bool, default=True) – Whether to delete the model after generating the embeddings

Returns:

embeddings (dict[str, Tensor]) – The embeddings for each dataset
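The contract described above (one set of embeddings per named dataset, with the model optionally released afterwards) can be sketched as follows. This is not the library's implementation: `EmbeddingGenerator` and its `load_model` callable are hypothetical, datasets are plain lists, and the returned "tensors" are lists. Setting the model reference to None is what allows its memory to be reclaimed once nothing else holds it; with delete_model=False the loaded model is kept for later calls.

```python
from typing import Any, Callable, Optional

Model = Callable[[Any], Any]


class EmbeddingGenerator:
    """Hypothetical sketch of a lazily loaded embedding model."""

    def __init__(self, load_model: Callable[[], Model]):
        self._load_model = load_model
        self.model: Optional[Model] = None

    def generate_dataset_embeddings(
        self, datasets: dict[str, list[Any]], delete_model: bool = True
    ) -> dict[str, list[Any]]:
        # Load the model only when embeddings are actually requested.
        if self.model is None:
            self.model = self._load_model()
        # One list of embeddings per dataset, keyed by the same name.
        embeddings = {
            name: [self.model(item) for item in dataset]
            for name, dataset in datasets.items()
        }
        # Drop the reference so the model's memory can be reclaimed.
        if delete_model:
            self.model = None
        return embeddings


# Toy "model": embeds a number x as [x, x * x].
generator = EmbeddingGenerator(load_model=lambda: (lambda x: [x, x * x]))
result = generator.generate_dataset_embeddings({"train": [1, 2], "test": [3]})
print(result)  # {'train': [[1, 1], [2, 4]], 'test': [[3, 9]]}
print(generator.model is None)  # True
```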