nip.image_classification.pretrained_models.Resnet18PretrainedModel

class nip.image_classification.pretrained_models.Resnet18PretrainedModel(hyper_params: HyperParameters, settings: ExperimentSettings)

Base class for pretrained ResNet models using PyTorch Image Models (timm).

The model weights are hosted on the Hugging Face Hub and loaded with the timm library.

Derived classes should define the class attributes below.

Parameters:
  • hyper_params (HyperParameters) – The hyperparameters for the experiment

  • settings (ExperimentSettings) – The settings for the experiment

Class attributes:
  • dataset (str) – The name of the dataset the model was trained on

  • allow_other_datasets (bool, default=False) – Whether the model can be used for datasets other than the one it was trained on
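A hypothetical, simplified sketch of the class-attribute pattern: a derived class sets `dataset` (and optionally `allow_other_datasets`), and the base class rejects a mismatched dataset. The class names, the `dataset_name` argument, and the check itself are assumptions for illustration; the real constructor takes `hyper_params` and `settings` instead.

```python
class PretrainedModelSketch:
    # Class attributes that derived classes are expected to override
    dataset: str = ""
    allow_other_datasets: bool = False

    def __init__(self, dataset_name: str):
        # Hypothetical guard: accept a different dataset only when explicitly allowed
        if not self.allow_other_datasets and dataset_name != self.dataset:
            raise ValueError(
                f"{type(self).__name__} was trained on {self.dataset!r}, "
                f"not {dataset_name!r}"
            )
        self.dataset_name = dataset_name


class Cifar10Resnet18Sketch(PretrainedModelSketch):
    dataset = "cifar10"  # hypothetical training dataset


class AnyDatasetResnet18Sketch(Cifar10Resnet18Sketch):
    # Same model, but opted in to use on other datasets
    allow_other_datasets = True
```

Because the attributes live on the class, each derived class configures itself declaratively and the check runs unchanged for every subclass.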

Methods Summary

__init__(hyper_params, settings)

_get_transform(model)

Get the transform to apply to images before passing them to the model.

generate_dataset_embeddings(datasets[, ...])

Load the model and generate embeddings for the datasets.

Attributes

allow_other_datasets

base_model_name

dataset

embedding_channels

embedding_downscale_factor

embedding_height

Embedding height, defined as a class property.

embedding_width

Embedding width, defined as a class property.

name

timm_uri
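`embedding_height` and `embedding_width` are class properties, so they can be read from the class itself without constructing a model. A minimal class-property descriptor is sketched below. The `image_height` and `image_width` attributes, their values, and the formula dividing them by `embedding_downscale_factor` are assumptions for illustration, not the package's definitions.

```python
from typing import Any, Callable, Optional


class classproperty:
    """Read-only property computed from the class rather than an instance."""

    def __init__(self, fget: Callable[[type], Any]):
        self.fget = fget

    def __get__(self, instance: Any, owner: Optional[type] = None) -> Any:
        # Works for both Model.attr and model.attr: always pass the class
        cls = owner if owner is not None else type(instance)
        return self.fget(cls)


class EmbeddingShapeSketch:
    # Hypothetical values chosen for illustration
    image_height = 224
    image_width = 224
    embedding_downscale_factor = 32

    @classproperty
    def embedding_height(cls) -> int:
        # Assumed relation: spatial size shrinks by the downscale factor
        return cls.image_height // cls.embedding_downscale_factor

    @classproperty
    def embedding_width(cls) -> int:
        return cls.image_width // cls.embedding_downscale_factor
```

Since the getter receives the class, a subclass that overrides `image_height` automatically gets a matching `embedding_height`.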

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings)

_get_transform(model: ResNet) → Any | None

Get the transform to apply to images before passing them to the model.

Returns:

transform (torchvision transform or None) – The transform to apply to images before passing them to the model

generate_dataset_embeddings(datasets: dict[str, ImageClassificationDataset], delete_model: bool = True) → dict[str, Tensor]

Load the model and generate embeddings for the datasets.

Parameters:
  • datasets (dict[str, ImageClassificationDataset]) – The datasets to generate embeddings for

  • delete_model (bool, default=True) – Whether to delete the model after generating the embeddings

Returns:

embeddings (dict[str, Tensor]) – The embeddings for each dataset
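A simplified sketch of the load, embed, then optionally delete flow. Assumptions: each dataset is a plain image tensor of shape `(N, C, H, W)` rather than an `ImageClassificationDataset`, the model is built by a caller-supplied factory, and the model's output is used directly as the embedding (with a timm ResNet one might call `forward_features` instead). The class name `EmbeddingGeneratorSketch` and the `batch_size` parameter are hypothetical.

```python
from typing import Callable, Optional

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader, TensorDataset


class EmbeddingGeneratorSketch:
    def __init__(self, model_factory: Callable[[], nn.Module], batch_size: int = 64):
        self.model_factory = model_factory
        self.batch_size = batch_size
        self.model: Optional[nn.Module] = None

    @torch.inference_mode()
    def generate_dataset_embeddings(
        self, datasets: dict[str, Tensor], delete_model: bool = True
    ) -> dict[str, Tensor]:
        # Load the model lazily, only when embeddings are requested
        if self.model is None:
            self.model = self.model_factory()
        self.model.eval()

        embeddings: dict[str, Tensor] = {}
        for name, images in datasets.items():
            loader = DataLoader(TensorDataset(images), batch_size=self.batch_size)
            # Each batch from a TensorDataset is a 1-tuple: (images,)
            batches = [self.model(batch) for (batch,) in loader]
            embeddings[name] = torch.cat(batches, dim=0)

        if delete_model:
            # Drop the reference so the weights can be garbage-collected
            self.model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        return embeddings
```

Deleting the model once the embeddings exist frees its weights for the rest of the experiment; `delete_model=False` keeps it loaded for a later call.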