Evaluation of the Model
Preparing for Evaluation
Before starting the evaluation process on the test set, you need to ensure the following:
- Set the model to evaluation mode: use model.eval() to disable dropout and switch batch normalization to its running statistics, ensuring consistent behavior during evaluation;
- Disable gradient tracking: use torch.no_grad() to save memory and speed up computations, as gradients are not required during evaluation.
# Set the model to evaluation mode
model.eval()

# Disable gradient computation for evaluation
with torch.no_grad():
    # Forward pass on the test data
    test_predictions = model(X_test)
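On newer PyTorch versions (1.9 and later), torch.inference_mode() can be used in place of torch.no_grad() for a small additional speedup; a minimal sketch, assuming a recent PyTorch installation:

# Alternative on PyTorch 1.9+: inference_mode() disables gradient
# tracking plus some extra autograd bookkeeping, so it can be slightly faster
with torch.inference_mode():
    test_predictions = model(X_test)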
Converting Predictions
As mentioned previously, the output of the model consists of logits (raw scores). To get the predicted class labels, we use torch.argmax to extract the index of the maximum value along the class dimension.
# Convert logits to predicted class labels
predicted_labels = torch.argmax(test_predictions, dim=1)
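To make this concrete, here is a tiny standalone example with made-up logits for a 3-class problem: each row is one sample, and dim=1 selects the highest-scoring class in each row:

# Made-up logits for two samples and three classes (illustration only)
logits = torch.tensor([[1.2, 3.4, 0.5],
                       [2.1, 0.3, 0.9]])
print(torch.argmax(logits, dim=1))  # tensor([1, 0])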
Calculating Metrics
For classification problems, accuracy is a useful starting metric, provided the dataset is balanced.
# Calculate accuracy
correct_predictions = (predicted_labels == y_test).sum().item()
accuracy = correct_predictions / len(y_test) * 100
print(f"Test accuracy: {accuracy:.2f}%")
To gain deeper insights into model performance, you can calculate additional metrics such as precision, recall, and F1-score. You can learn more about these metrics and their formulas in this article.
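As an illustration, here is a minimal sketch of how these metrics could be computed with scikit-learn (an assumption; it is not used elsewhere in this lesson), using macro averaging so every class contributes equally:

# Assumes scikit-learn is installed; 'macro' averages each metric over classes
from sklearn.metrics import precision_score, recall_score, f1_score

y_true = y_test.cpu().numpy()
y_pred = predicted_labels.cpu().numpy()

print(f"Precision: {precision_score(y_true, y_pred, average='macro'):.2f}")
print(f"Recall: {recall_score(y_true, y_pred, average='macro'):.2f}")
print(f"F1-score: {f1_score(y_true, y_pred, average='macro'):.2f}")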
Full Implementation
import torch
import os

os.system('wget https://staging-content-media-cdn.codefinity.com/courses/1dd2b0f6-6ec0-40e6-a570-ed0ac2209666/section_3/model_training.py 2>/dev/null')
from model_training import model, X_test, y_test

# Set model to evaluation mode
model.eval()

# Disable gradient tracking
with torch.no_grad():
    # Forward pass
    test_predictions = model(X_test)

# Get predicted classes
predicted_labels = torch.argmax(test_predictions, dim=1)

# Calculate accuracy
correct_predictions = (predicted_labels == y_test).sum().item()
accuracy = correct_predictions / len(y_test) * 100
print(f"Test accuracy: {accuracy:.2f}%")