nip.image_classification.pretrained_models.Resnet18Cifar10PretrainedModel
- class nip.image_classification.pretrained_models.Resnet18Cifar10PretrainedModel(hyper_params: HyperParameters, settings: ExperimentSettings)
ResNet-18 model trained on CIFAR-10.
Methods Summary
__init__(hyper_params, settings)
_get_transform(model) – Get the transform to apply to images before passing them to the model.
generate_dataset_embeddings(datasets[, ...]) – Load the model and generate embeddings for the datasets.
Attributes
allow_other_datasets
base_model_name
dataset
embedding_channels
embedding_downscale_factor
embedding_height
embedding_width
name
timm_uri
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings)
- _get_transform(model: ResNet) → Any | None
Get the transform to apply to images before passing them to the model.
- Returns:
transform (torchvision transform or None) – The transform to apply to images before passing them to the model
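Because the return type is Any | None, calling code has to handle the case where no transform is returned. The following is a minimal sketch of that pattern only, not the library's implementation: apply_optional_transform and scale_pixels are hypothetical names, and plain lists stand in for image tensors so the example runs without torch or torchvision.

```python
from typing import Any, Callable, Optional


def apply_optional_transform(
    image: Any, transform: Optional[Callable[[Any], Any]]
) -> Any:
    """Apply `transform` if one was returned; otherwise pass the image through."""
    if transform is None:
        return image
    return transform(image)


# Hypothetical stand-in for a torchvision-style transform:
# scale 8-bit pixel values into the range [0, 1].
def scale_pixels(pixels: list[int]) -> list[float]:
    return [p / 255 for p in pixels]


raw = [0, 51, 255]
unchanged = apply_optional_transform(raw, None)       # no transform: input returned as-is
scaled = apply_optional_transform(raw, scale_pixels)  # transform applied
```

Treating None as "no preprocessing" keeps the caller free of special cases for models that do not ship a transform.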
- generate_dataset_embeddings(datasets: dict[str, ImageClassificationDataset], delete_model: bool = True) → dict[str, Tensor]
Load the model and generate embeddings for the datasets.
- Parameters:
datasets (dict[str, ImageClassificationDataset]) – The datasets to generate embeddings for
delete_model (bool, default=True) – Whether to delete the model after generating the embeddings
- Returns:
embeddings (dict[str, Tensor]) – The embeddings for each dataset
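The calling contract described above (a dictionary of named datasets in, a dictionary of embeddings with the same keys out, and an optional release of the model afterwards) can be sketched with plain-Python stand-ins. This is an illustration under stated assumptions, not the library's code: generate_embeddings_sketch, load_toy_model, and the toy "model" are hypothetical, and nested lists replace ImageClassificationDataset and Tensor.

```python
from typing import Callable

# Hypothetical type alias: a "model" maps one image (a list of floats)
# to one embedding (also a list of floats).
Embedder = Callable[[list[float]], list[float]]


def generate_embeddings_sketch(
    datasets: dict[str, list[list[float]]],
    load_model: Callable[[], Embedder],
    delete_model: bool = True,
) -> dict[str, list[list[float]]]:
    """Load a model once, embed every dataset, and key the results by dataset name."""
    model = load_model()
    embeddings = {
        name: [model(image) for image in images]
        for name, images in datasets.items()
    }
    if delete_model:
        # Drop the local reference so the model can be garbage collected.
        del model
    return embeddings


def load_toy_model() -> Embedder:
    # Toy stand-in for a pretrained network: embedding = [mean, max] of the pixels.
    return lambda image: [sum(image) / len(image), max(image)]


toy_datasets = {
    "train": [[0.0, 1.0], [0.5, 0.5]],
    "test": [[1.0, 3.0]],
}
result = generate_embeddings_sketch(toy_datasets, load_toy_model)
```

In this sketch the output keys always mirror the input keys, so downstream code can look up the embeddings for a split by the same name it used for the dataset.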