nip.image_classification.pretrained_models.Resnet18Cifar10PretrainedModel
- class nip.image_classification.pretrained_models.Resnet18Cifar10PretrainedModel(hyper_params: HyperParameters, settings: ExperimentSettings)
A ResNet-18 model pretrained on CIFAR-10.
Methods Summary
- __init__(hyper_params, settings)
- _get_transform(model) – Get the transform to apply to images before passing them to the model.
- generate_dataset_embeddings(datasets[, ...]) – Load the model and generate embeddings for the datasets.
Attributes
- allow_other_datasets
- base_model_name
- dataset
- embedding_channels
- embedding_downscale_factor
- embedding_height
- embedding_width
- name
- timm_uri

Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings)
- _get_transform(model: ResNet) → Any | None
Get the transform to apply to images before passing them to the model.
- Returns:
transform (torchvision transform or None) – The transform to apply to images before passing them to the model
- generate_dataset_embeddings(datasets: dict[str, ImageClassificationDataset], delete_model: bool = True) → dict[str, Tensor]
Load the model and generate embeddings for the datasets.
- Parameters:
datasets (dict[str, ImageClassificationDataset]) – The datasets to generate embeddings for
delete_model (bool, default=True) – Whether to delete the model after generating the embeddings
- Returns:
embeddings (dict[str, Tensor]) – The embeddings for each dataset
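Since `_get_transform` returns either a transform or `None`, code that consumes it has to handle both cases. The sketch below shows that pattern under stated assumptions: plain Python callables stand in for torchvision transforms, images are lists of pixel values, and `apply_optional_transform` and `scale_to_unit` are hypothetical helpers that are not part of `nip`. It also assumes that `None` means the images are passed to the model unmodified.

```python
from typing import Any, Callable, Optional

Transform = Callable[[Any], Any]


def apply_optional_transform(image: Any, transform: Optional[Transform]) -> Any:
    """Apply `transform` to `image`, or return `image` unchanged if it is None."""
    # Assumption: a None transform means the images need no preprocessing.
    if transform is None:
        return image
    return transform(image)


def scale_to_unit(image: list[float]) -> list[float]:
    # Hypothetical stand-in for a torchvision transform: map 0-255 pixels to [0, 1].
    return [pixel / 255.0 for pixel in image]


print(apply_optional_transform([0.0, 255.0], scale_to_unit))  # [0.0, 1.0]
print(apply_optional_transform([0.0, 255.0], None))  # [0.0, 255.0]
```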
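The contract documented for `generate_dataset_embeddings` (one set of embeddings per named dataset, with the model optionally deleted afterwards) can be sketched without PyTorch. This is an illustration under assumptions, not the `nip` implementation: `EmbeddingGeneratorSketch` and `load_toy_model` are hypothetical, datasets are plain lists of images, embeddings are lists of floats rather than `Tensor`s, and "deleting" the model is modelled by dropping the wrapper's reference to it.

```python
from typing import Callable, Optional

# Assumed toy types: an "image" is a flat list of pixel values and an
# "embedding" is a flat list of floats (the real method returns Tensors).
Image = list[float]
Embedding = list[float]
Model = Callable[[Image], Embedding]


class EmbeddingGeneratorSketch:
    """Hypothetical stand-in for a pretrained-model wrapper; not the nip class."""

    def __init__(self, load_model: Callable[[], Model]) -> None:
        self._load_model = load_model
        self.model: Optional[Model] = None

    def generate_dataset_embeddings(
        self, datasets: dict[str, list[Image]], delete_model: bool = True
    ) -> dict[str, list[Embedding]]:
        # Load the model only when embeddings are actually requested.
        if self.model is None:
            self.model = self._load_model()
        model = self.model
        # One list of embeddings per dataset, keyed by the input names.
        embeddings = {
            name: [model(image) for image in images]
            for name, images in datasets.items()
        }
        # Drop our reference so a large model can be garbage collected.
        if delete_model:
            self.model = None
        return embeddings


def load_toy_model() -> Model:
    # Toy "model": embed an image as [mean, max] of its pixel values.
    return lambda image: [sum(image) / len(image), max(image)]


generator = EmbeddingGeneratorSketch(load_toy_model)
result = generator.generate_dataset_embeddings(
    {"train": [[1.0, 3.0]], "test": [[2.0, 2.0]]}
)
print(result)  # {'train': [[2.0, 3.0]], 'test': [[2.0, 2.0]]}
print(generator.model)  # None
```

Passing `delete_model=False` keeps the loaded model on the wrapper, which in this sketch avoids reloading it on the next call; the trade-off is that the model's memory stays allocated.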