How to load a checkpoint into a new model?
Loading a pre-trained model is a key skill when working in machine learning, especially when we want to add efficiency to our workflow. In this segment, we will learn how to leverage the checkpoints we have previously saved to initialize a new model and optimizer in PyTorch, ready for additional training or inference!
How to load the checkpoint of our model?
First, we download the checkpoint from the Hugging Face Hub, where we previously uploaded our weights. Once it is in our Google Colab environment, we use PyTorch's torch.load
function to load the checkpoint. Here is the necessary code snippet:
checkpoint = torch.load('Weights/modelCheckpoint.pth')
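The step above can be sketched end to end. The snippet below first creates a small checkpoint file so the example is self-contained; the dictionary keys follow the article, but the tensor values, the epoch/loss numbers, and the local filename are placeholders.

```python
import torch

# Placeholder checkpoint like the one described in the article;
# in practice this file comes from your previous training run.
state = {
    'model_state_dict': {'embedding.weight': torch.zeros(10, 4)},
    'optimizer_state_dict': {},
    'epoch': 5,
    'loss': 0.42,
}
torch.save(state, 'modelCheckpoint.pth')

# map_location='cpu' lets us load a GPU-saved checkpoint on a CPU-only machine
checkpoint = torch.load('modelCheckpoint.pth', map_location='cpu')
print(sorted(checkpoint.keys()))  # ['epoch', 'loss', 'model_state_dict', 'optimizer_state_dict']
```

Passing `map_location='cpu'` is a useful habit in Colab, where the runtime may or may not have a GPU attached when you reload the checkpoint.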
How to initialize the new model?
Initializing the new model requires replicating the configuration of the original model. This includes specifying the number of classes, the vocabulary size and the embedding size. All of these must be identical to the original model to ensure that the loaded weights match properly.
num_classes = ...
vocabulary_size = ...
embedding_size = ...
model2 = MyModel(num_classes, vocabulary_size, embedding_size)
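As a concrete illustration, here is a minimal sketch of the kind of model class the article assumes. The real MyModel architecture is not shown in this segment, so the layers and the small hyperparameter values below are placeholders; only the constructor signature matches the article.

```python
import torch
import torch.nn as nn

# Placeholder architecture: a text classifier with an embedding layer
# and a linear head, matching the constructor used in the article.
class MyModel(nn.Module):
    def __init__(self, num_classes, vocabulary_size, embedding_size):
        super().__init__()
        self.embedding = nn.Embedding(vocabulary_size, embedding_size)
        self.fc = nn.Linear(embedding_size, num_classes)

    def forward(self, tokens):                      # tokens: (batch, seq_len)
        pooled = self.embedding(tokens).mean(dim=1)  # average word embeddings
        return self.fc(pooled)

# These must match the original model exactly, or load_state_dict
# will fail with a shape-mismatch error (placeholder values here)
num_classes = 4
vocabulary_size = 100
embedding_size = 16
model2 = MyModel(num_classes, vocabulary_size, embedding_size)
print(model2(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 4])
```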
Why do we need a new optimizer?
The optimizer is essential to adjust the model weights during training. Like the model, we must initialize a new optimizer and then load its state from the checkpoint.
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.2)
How do we ensure that our model and optimizer apply the correct weights?
Load the weights into the model
We use the load_state_dict
method to load the model state from the checkpoint dictionary:
model2.load_state_dict(checkpoint['model_state_dict'])
Fixing errors with the optimizer
Sometimes we may encounter errors when trying to load the optimizer state. If this happens, let's verify that we saved the optimizer state properly. If not, let's go back, correct it, and upload the corrected checkpoint to the Hugging Face Hub.
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
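The whole restore flow can be sketched as a save/load round trip. The variable names and the lr=0.2 value follow the article; the tiny Linear model and the local filename are placeholders standing in for the real model and the downloaded checkpoint.

```python
import torch
import torch.nn as nn

# Original model and optimizer (placeholders), saved as a checkpoint
model = nn.Linear(3, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, 'ckpt.pth')

# Fresh model and optimizer, then restore both states
model2 = nn.Linear(3, 2)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.2)
checkpoint = torch.load('ckpt.pth', map_location='cpu')
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])

# The restored weights now match the originals exactly
print(torch.equal(model.weight, model2.weight))  # True
```

Note that optimizer2.load_state_dict also restores hyperparameters such as the learning rate from the checkpoint's param_groups, overwriting whatever value was passed when optimizer2 was created.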
How do we handle the rest of the workflow?
If we are going to continue training, we also need the epoch counter and the loss. These are stored in the checkpoint as well, and we can retrieve them in the same way.
epoch2 = checkpoint['epoch']
loss2 = checkpoint['loss']
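Resuming from the restored values might look like the sketch below. The variable names follow the article; the checkpoint contents and num_epochs are placeholders.

```python
# Placeholder checkpoint values; in practice these come from torch.load
checkpoint = {'epoch': 5, 'loss': 0.42}
epoch2 = checkpoint['epoch']
loss2 = checkpoint['loss']

num_epochs = 8
for epoch in range(epoch2, num_epochs):   # continue where training stopped
    # ...one training pass per epoch would go here...
    print(f"epoch {epoch} (last recorded loss: {loss2})")
```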
How do we perform inference with the model loaded?
Preparation for inference
If the goal is inference, we generally move the model to the CPU, especially if a GPU is not needed, which is common for simple inference tasks.
model2.to('cpu')
Testing with a new example
It is always advisable to test our model with a new text to confirm that everything is set up correctly. We use a text pipeline to convert the text into a format that the model can process.
example = "text about garlic trees"
result = model2(text_pipeline(example))
print(result)
By following these steps, we ensure that our model is ready to perform inference with the previously trained weights and to store any future adjustments directly in the Hugging Face Hub. Continuing to experiment and tweak the model is essential to get the most out of it - keep discovering and learning new ways to apply this knowledge!