nip.image_classification.pretrained_models.Resnet18PretrainedModel
- class nip.image_classification.pretrained_models.Resnet18PretrainedModel(hyper_params: HyperParameters, settings: ExperimentSettings)
Base class for pretrained ResNet models using PyTorch Image Models (timm).
The model weights are hosted on Hugging Face and are loaded through the timm library.
Derived classes should define the class attributes below.
- Parameters:
hyper_params (HyperParameters) – The parameters for the experiment
settings (ExperimentSettings) – The settings for the experiment
- Class attributes:
dataset (str) – The name of the dataset the model was trained for
allow_other_datasets (bool, default=False) – Whether the model can be used for datasets other than the one it was trained on
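As an illustration of how a derived class might set these class attributes, here is a minimal sketch. It uses a hypothetical stand-in base class (`PretrainedModelSketch`) rather than the real one, and the `supports` helper and the dataset names are assumptions made for the example, not part of the documented API.

```python
class PretrainedModelSketch:
    """Hypothetical stand-in for the documented base class (not the nip code)."""

    # Class attributes that derived classes are expected to define.
    dataset: str
    allow_other_datasets: bool = False

    @classmethod
    def supports(cls, dataset_name: str) -> bool:
        """Hypothetical helper: may this model be used with ``dataset_name``?"""
        return cls.allow_other_datasets or dataset_name == cls.dataset


class Cifar10OnlySketch(PretrainedModelSketch):
    # Assumed dataset name, used only for illustration.
    dataset = "cifar10"


class GeneralPurposeSketch(PretrainedModelSketch):
    dataset = "cifar10"
    # Opt in to use with datasets other than the training dataset.
    allow_other_datasets = True
```

With the default `allow_other_datasets = False`, the first sketch class accepts only its own dataset, while the second accepts any dataset name.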
Methods Summary
__init__(hyper_params, settings)
_get_transform(model) – Get the transform to apply to images before passing them to the model.
generate_dataset_embeddings(datasets[, ...]) – Load the model and generate embeddings for the datasets.
Attributes
allow_other_datasets
base_model_name
dataset
embedding_channels
embedding_downscale_factor
embedding_height
Class property giving the height of the embedding.
embedding_width
Class property giving the width of the embedding.
name
timm_uri
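The attribute list indicates that `embedding_height` and `embedding_width` are class properties. The sketch below shows one way such properties could be built with a small descriptor. Everything in it is an assumption for illustration: the `classproperty` implementation, the `image_height` and `image_width` names, the 224-pixel input size, and the rule that the embedding size is the image size divided by `embedding_downscale_factor`. The values 32 and 512 match the final feature map of a standard ResNet-18, but the library's actual values are not stated here.

```python
class classproperty:
    """Minimal descriptor exposing a method as a read-only class-level property."""

    def __init__(self, getter):
        self.getter = getter

    def __get__(self, instance, owner):
        # Called for both Class.attr and instance.attr; always pass the class.
        return self.getter(owner)


class EmbeddingShapeSketch:
    # Assumed values (see the lead-in): 224x224 input, stride-32 backbone,
    # 512 output channels, as in a standard ResNet-18.
    image_height = 224
    image_width = 224
    embedding_channels = 512
    embedding_downscale_factor = 32

    @classproperty
    def embedding_height(cls) -> int:
        # Assumed relation: spatial size shrinks by the downscale factor.
        return cls.image_height // cls.embedding_downscale_factor

    @classproperty
    def embedding_width(cls) -> int:
        return cls.image_width // cls.embedding_downscale_factor
```

Because the descriptor passes the owning class to the getter, the values can be read without instantiating the class, for example `EmbeddingShapeSketch.embedding_height`.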
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings)
- _get_transform(model: ResNet) → Any | None
Get the transform to apply to images before passing them to the model.
- Returns:
transform (torchvision transform or None) – The transform to apply to images before passing them to the model
- generate_dataset_embeddings(datasets: dict[str, ImageClassificationDataset], delete_model: bool = True) → dict[str, Tensor]
Load the model and generate embeddings for the datasets.
- Parameters:
datasets (dict[str, ImageClassificationDataset]) – The datasets to generate embeddings for
delete_model (bool, default=True) – Whether to delete the model after generating the embeddings
- Returns:
embeddings (dict[str, Tensor]) – The embeddings for each dataset
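The input and output contract of `generate_dataset_embeddings` (a dictionary of named datasets in, a dictionary of embeddings with the same keys out, with the model optionally released afterwards) can be sketched in plain Python. This is not the library's implementation: `EmbedderSketch`, `load_toy_model`, and the list-based stand-ins for images and tensors are hypothetical, and treating "delete" as dropping the reference to the model is an assumption.

```python
from typing import Callable, Optional

Image = list[float]      # stand-in for an image tensor
Embedding = list[float]  # stand-in for an embedding tensor
Model = Callable[[Image], Embedding]


class EmbedderSketch:
    """Hypothetical stand-in; not the nip implementation."""

    def __init__(self, load_model: Callable[[], Model]):
        self._load_model = load_model
        self.model: Optional[Model] = None  # loaded lazily on first use

    def generate_dataset_embeddings(
        self, datasets: dict[str, list[Image]], delete_model: bool = True
    ) -> dict[str, list[Embedding]]:
        if self.model is None:
            self.model = self._load_model()
        # One entry per dataset, keyed by the same name as the input.
        embeddings = {
            name: [self.model(image) for image in images]
            for name, images in datasets.items()
        }
        if delete_model:
            # Drop the reference so the model's memory can be reclaimed.
            self.model = None
        return embeddings


def load_toy_model() -> Model:
    # Toy "model": embeds an image as [sum, max] of its pixel values.
    return lambda image: [sum(image), max(image)]


embedder = EmbedderSketch(load_toy_model)
result = embedder.generate_dataset_embeddings(
    {"train": [[1.0, 2.0], [3.0, 4.0]], "test": [[5.0, 6.0]]}
)
```

In this sketch, passing `delete_model=False` keeps the loaded model on the instance, which avoids reloading it if embeddings are generated again.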