Integrate with PyTorch¶
PyTorch is a popular open source machine learning framework based on the Torch library, used for applications such as computer vision and natural language processing.
PyTorch enables fast, flexible experimentation and efficient production through a user-friendly front-end, distributed training, and an ecosystem of tools and libraries.
Instrument PyTorch with Comet to start managing experiments, creating dataset versions, and tracking hyperparameters for faster, easier reproducibility and collaboration.
Note: If you are using PyTorch with TensorBoard, see our TensorBoard integration.
Note: This integration also supports PyTorch Distributed Data Parallel (DDP). See PyTorch Distributed Data Parallel below.
Start logging¶
Connect Comet to your existing code by creating a Comet Experiment.
Add the following lines of code to your script or notebook:
import comet_ml
import torch
import torchvision
experiment = comet_ml.Experiment(
    api_key="<Your API Key>",
    project_name="<Your Project Name>",
)
# Your code here
Note
There are other ways to configure Comet. See Configure Comet for details.
Log automatically¶
After an Experiment has been created, Comet automatically logs the following PyTorch items, by default, with no additional configuration:
- Model and graph description
- Training loss
You can turn automatic logging on or off for any or all of these items. See Configure Comet for PyTorch for more details.
Note
Don't see what you need to log here? We have your back. You can manually log any kind of data to Comet using the Experiment object. For example, use experiment.log_image to log images, or experiment.log_audio to log audio.
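For instance, a few sample images per step could be logged manually. The following is a minimal sketch, assuming you already have an `experiment` and a batch of images converted to NumPy arrays; `log_sample_images` is a hypothetical helper, not part of comet_ml, and it relies only on `experiment.log_image` with its `name` and `step` arguments.

```python
from typing import Any, Iterable


def log_sample_images(experiment: Any, images: Iterable[Any], step: int, prefix: str = "sample") -> int:
    """Hypothetical helper: log each image as prefix_0, prefix_1, ... and return how many were logged."""
    count = 0
    for i, image in enumerate(images):
        # log_image accepts, among other inputs, NumPy arrays, PIL images and file paths
        experiment.log_image(image, name=f"{prefix}_{i}", step=step)
        count += 1
    return count
```

With the MNIST batches from the example below, something like `log_sample_images(experiment, images[:4].squeeze(1).cpu().numpy(), step=batch_idx)` would log four 28x28 grayscale digits; the `squeeze(1)` drops the single channel dimension so each item is a 2D array.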
End-to-end example¶
The following is a basic example of using Comet with PyTorch.
If you can't wait, check out the results of this example PyTorch experiment for a preview of what's to come.
Install dependencies¶
pip install comet_ml torch torchvision tqdm
Run the example¶
import comet_ml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm
comet_ml.init(project_name="comet-example-pytorch")
experiment = comet_ml.Experiment()
hyper_params = {"batch_size": 100, "num_epochs": 2, "learning_rate": 0.01}
experiment.log_parameters(hyper_params)
# MNIST Dataset
dataset = datasets.MNIST(
    root="./data/", train=True, transform=transforms.ToTensor(), download=True
)
# Data Loader (Input Pipeline)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=hyper_params["batch_size"], shuffle=True
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        # Return log-probabilities; they are paired with NLLLoss below
        output = F.log_softmax(x, dim=1)
        return output
def train(model, optimizer, criterion, dataloader, epoch):
    model.train()
    total_loss = 0
    correct = 0
    for batch_idx, (images, labels) in enumerate(tqdm(dataloader)):
        optimizer.zero_grad()
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        pred = outputs.argmax(dim=1, keepdim=True)  # index of the max log-probability
        loss.backward()
        optimizer.step()
        # Compute train accuracy
        batch_correct = pred.eq(labels.view_as(pred)).sum().item()
        batch_total = labels.size(0)
        total_loss += loss.item()
        correct += batch_correct
        # Log batch_accuracy to Comet; step is each batch
        experiment.log_metric("batch_accuracy", batch_correct / batch_total)
    # loss.item() is already a per-batch mean, so average over the number of batches
    total_loss /= len(dataloader)
    correct /= len(dataloader.dataset)
    experiment.log_metrics({"accuracy": correct, "loss": total_loss}, epoch=epoch)
model = Net().to(device)
# Loss and Optimizer
# The model outputs log-probabilities, so NLLLoss is the matching loss
# (CrossEntropyLoss would apply log_softmax a second time)
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=hyper_params["learning_rate"])
# Train the Model
with experiment.train():
    print("Running Model Training")
    for epoch in range(hyper_params["num_epochs"]):
        train(model, optimizer, criterion, dataloader, epoch)
# Flush and close the Experiment (required when running in a notebook)
experiment.end()
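The example above only trains. A held-out evaluation step could be added in the same style; the sketch below is a hypothetical `evaluate` function (not part of the original example) that computes accuracy without gradients and logs it inside Comet's `experiment.test()` context, mirroring how `experiment.train()` groups the training metrics. It assumes the model returns one score per class, as `Net` does.

```python
import torch


def evaluate(model, dataloader, device, experiment, epoch=None):
    """Compute classification accuracy on a held-out set and log it to Comet."""
    model.eval()  # switch off training-only behavior such as dropout
    correct = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            pred = model(images).argmax(dim=1)  # index of the highest score per sample
            correct += (pred == labels).sum().item()
    accuracy = correct / len(dataloader.dataset)
    # Log under the test context, as `with experiment.train():` does for training
    with experiment.test():
        experiment.log_metric("accuracy", accuracy, epoch=epoch)
    return accuracy
```

For MNIST, a test loader can be built like the training one from `datasets.MNIST(root="./data/", train=False, transform=transforms.ToTensor(), download=True)`, then `evaluate(model, test_dataloader, device, experiment, epoch=epoch)` can be called after each epoch.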
Try it out!¶
Don't just take our word for it, try it out for yourself.
- For more examples using PyTorch, see our examples GitHub repository.
- Run the end-to-end example above in Colab.
PyTorch model saving and loading¶
Comet provides user-friendly helpers that let you save your models and load them back.
Saving a model¶
To save a PyTorch model, use the comet_ml.integration.pytorch.log_model helper:
from comet_ml import Experiment
from comet_ml.integration.pytorch import log_model
import torch.nn as nn
experiment = Experiment()
class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        ...
    def forward(self, x):
        ...
        return x
# Initialize model
model = TheModelClass()
# Train model (replace with your own training loop)
train(model)
# Save the model for inference
log_model(experiment, model, model_name="TheModel")
The model file is saved as an Experiment Model, which is visible in the Experiment's Assets tab. From there, you can register it in the Model Registry.
The previous code snippet is tailored for inference. If you want to log a general checkpoint so you can resume training later, replace the last line of the snippet with:
# Save the model for Resume Training
model_checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    # ...add any other values you want to restore later
}
log_model(experiment, model_checkpoint, model_name="TheModel")
comet_ml.integration.pytorch.log_model uses torch.save under the hood; consult the official PyTorch documentation for more details and for instructions on more advanced use cases.
Check out the reference documentation for more details.
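Because log_model relies on torch.save, the checkpoint dictionary above goes through an ordinary PyTorch save/load round trip. The sketch below reproduces that round trip locally with `torch.save` and `torch.load` on an in-memory buffer; it does not call Comet, and the tiny `nn.Linear` model, the SGD optimizer, and the epoch value are stand-ins chosen only for illustration.

```python
import io

import torch
import torch.nn as nn

# Stand-in model and optimizer; replace with your own
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One illustrative training step so the weights differ from their initial values
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

checkpoint = {
    "epoch": 3,  # illustrative value
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),  # store a plain float rather than a tensor
}

# Save to an in-memory buffer instead of a file, then rewind it for reading
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# Rebuild fresh objects and restore their state from the saved checkpoint
restored = torch.load(buffer)
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1)
new_model.load_state_dict(restored["model_state_dict"])
new_optimizer.load_state_dict(restored["optimizer_state_dict"])
start_epoch = restored["epoch"] + 1  # continue after the last completed epoch
```

Storing `loss.item()` instead of the loss tensor keeps the checkpoint free of autograd state; the original snippet's `"loss": loss` also works with torch.save.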
Loading a model¶
Once you have saved a model using comet_ml.integration.pytorch.log_model, you can load it back with its counterpart, comet_ml.integration.pytorch.load_model.
Here is how you can load a model from the Model Registry for inference:
from comet_ml.integration.pytorch import load_model
import torch.nn as nn
class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        ...
    def forward(self, x):
        ...
        return x
# Initialize model
model = TheModelClass()
# Load the model state dict from the Comet Model Registry
model.load_state_dict(load_model("registry://WORKSPACE/TheModel:1.2.4"))
model.eval()
prediction = model(...)
You can load a PyTorch model from various sources:
- file://data/my-model: load the state_dict from the file path data/my-model (relative path).
- file:///path/to/my-model: load the state_dict from the file path /path/to/my-model (absolute path).
- registry://<workspace>/<registry_name>: load the state_dict from the Model Registry entry identified by the workspace and registry name, using its latest version.
- registry://<workspace>/<registry_name>:version: load the state_dict from the Model Registry entry identified by the workspace, registry name, and explicit version.
- experiment://<experiment_key>/<model_name>: load the state_dict from an Experiment, identified by the Experiment key and the model name.
- experiment://<workspace>/<project_name>/<experiment_name>/<model_name>: load the state_dict from an Experiment, identified by the workspace name, project name, experiment name, and model name.
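When these URIs are assembled from variables, small formatting helpers keep them consistent. The functions below are hypothetical conveniences, not part of comet_ml; they only build strings that follow the registry:// and experiment:// formats listed above.

```python
from typing import Optional


def registry_uri(workspace: str, registry_name: str, version: Optional[str] = None) -> str:
    """Build a registry:// URI; without a version, the latest version is loaded."""
    uri = f"registry://{workspace}/{registry_name}"
    return f"{uri}:{version}" if version is not None else uri


def experiment_key_uri(experiment_key: str, model_name: str) -> str:
    """Build an experiment:// URI from an Experiment key and a model name."""
    return f"experiment://{experiment_key}/{model_name}"


def experiment_name_uri(workspace: str, project_name: str, experiment_name: str, model_name: str) -> str:
    """Build an experiment:// URI from workspace, project, experiment name, and model name."""
    return f"experiment://{workspace}/{project_name}/{experiment_name}/{model_name}"
```

For example, `model.load_state_dict(load_model(registry_uri("my-workspace", "TheModel", "1.2.4")))` would load the same kind of URI used in the inference snippet above.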
The previous code snippet is tailored for inference. If you want to load a general checkpoint to resume training, replace the last lines of the snippet with:
import torch
# Initialize model
model = TheModelClass()
# Recreate the optimizer used during training (Adam is only an example)
optimizer = torch.optim.Adam(model.parameters())
# Load the checkpoint from a Comet Experiment
checkpoint = load_model("experiment://e1098c4e1e764ff89881b868e4c70f5/TheModel")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
model.train()
comet_ml.integration.pytorch.load_model uses torch.load under the hood; consult the official PyTorch documentation for more details and for instructions on more advanced use cases.
Check out the reference documentation for more details.
PyTorch Distributed Data Parallel¶
Are you running distributed training with PyTorch? There is an example for logging PyTorch DDP with Comet in the comet-example repository.
Configure Comet for PyTorch¶
You can control which PyTorch items are logged automatically. Use any of the following methods:
In code, pass the corresponding arguments when creating the Experiment:
experiment = comet_ml.Experiment(
    log_graph=True,  # Can be True or False
    auto_metric_logging=True,  # Can be True or False
)
In the .comet.config file, add or remove these fields under the [comet_auto_log] section to enable or disable logging:
[comet_auto_log]
graph=true # can be true or false
metrics=true # can be true or false
With environment variables:
export COMET_AUTO_LOG_GRAPH=true # Can be true or false
export COMET_AUTO_LOG_METRICS=true # Can be true or false
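In a notebook, where `export` is not available, the same environment variables can be set from Python. The sketch below uses a hypothetical helper, `set_comet_auto_log`; it assumes Comet reads these variables when the Experiment is created, so, to be safe, it should run before `comet_ml` is imported.

```python
import os


def set_comet_auto_log(graph: bool, metrics: bool) -> None:
    """Hypothetical helper: set Comet's auto-logging environment variables from Python."""
    # The variables take the strings "true" or "false"
    os.environ["COMET_AUTO_LOG_GRAPH"] = "true" if graph else "false"
    os.environ["COMET_AUTO_LOG_METRICS"] = "true" if metrics else "false"


# Run before importing comet_ml and creating the Experiment
set_comet_auto_log(graph=True, metrics=False)
```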
For more information about configuring Comet, see Configure Comet.