PyTorch model roundtrip¶
PyTorch's nn.Module is fully picklable — architecture and weights travel together through Python's serialization protocol. Skyward uses cloudpickle under the hood, which means you can send an untrained model to a remote worker, train it there, and get the trained model back. No state_dict files, no checkpoints, no manual save/load — the model object itself is the transport.
This guide walks through the full cycle: build locally, train remotely, evaluate locally.
The model¶
A standard nn.Module — nothing special required for serialization:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))
When cloudpickle serializes this object, it captures both the class definition and the instance state (all parameter tensors). On the remote side, it reconstructs the exact same object with the exact same weights.
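To see the roundtrip in miniature, here is a self-contained sketch using the stdlib pickle module on a toy module (a simplification: Skyward uses cloudpickle, which additionally ships the class definition itself, so the remote worker doesn't need your source code):

```python
import pickle

import torch
import torch.nn as nn


# A tiny stand-in for the guide's MNISTClassifier.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


model = TinyNet()
blob = pickle.dumps(model)       # one payload: architecture + parameter tensors
restored = pickle.loads(blob)

# The restored object is a real TinyNet with identical weights.
assert isinstance(restored, TinyNet)
for p1, p2 in zip(model.parameters(), restored.parameters()):
    assert torch.equal(p1, p2)
```

Within one process, stdlib pickle is enough; cloudpickle earns its keep when the payload crosses to a machine that has never imported your module.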
Loading data locally¶
MNIST is loaded on the local machine and sent as tensors to the remote worker. The worker doesn't need torchvision — it only needs torch to work with the tensors it receives:
def load_mnist() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Load MNIST and flatten images to 784-d vectors."""
    from torchvision import datasets, transforms

    transform = transforms.ToTensor()
    train_ds = datasets.MNIST("/tmp/mnist", train=True, download=True, transform=transform)
    test_ds = datasets.MNIST("/tmp/mnist", train=False, download=True, transform=transform)

    x_train = train_ds.data.float().view(-1, 784) / 255.0
    y_train = train_ds.targets
    x_test = test_ds.data.float().view(-1, 784) / 255.0
    y_test = test_ds.targets
    return x_train, y_train, x_test, y_test
The 60k training images (each 28x28) are flattened to 784-d vectors and normalized. Both x_train and y_train are regular tensors — cloudpickle handles them the same way it handles any Python object.
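The flatten-and-scale step can be checked in isolation on a synthetic batch (three fake "images" here instead of the real 60k):

```python
import torch

# Fake batch of three 28x28 "images" with raw byte values, like MNIST's data tensor.
raw = torch.randint(0, 256, (3, 28, 28)).float()

# Same transformation as load_mnist: flatten each image, scale into [0, 1].
x = raw.view(-1, 784) / 255.0

assert x.shape == (3, 784)
assert 0.0 <= float(x.min()) and float(x.max()) <= 1.0
```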
The remote training function¶
The function decorated with @sky.function receives the model as an argument and returns it after training:
@sky.function
def train(
    model: nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    epochs: int,
    lr: float,
) -> nn.Module:
    """Train the model on the remote worker and return it with learned weights."""
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    loader = DataLoader(
        TensorDataset(x.to(device), y.to(device)),
        batch_size=64,
        shuffle=True,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        for batch_x, batch_y in loader:
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (output.argmax(1) == batch_y).sum().item()
            total += batch_y.size(0)

        acc = 100.0 * correct / total
        print(f"  epoch {epoch + 1}/{epochs}: loss={total_loss / len(loader):.4f}, acc={acc:.1f}%")

    return model.cpu()
The type signature tells the story: nn.Module goes in, nn.Module comes out. The optimizer modifies the model's parameters in-place during training. The final model.cpu() ensures all tensors are on CPU before serialization — this matters when training on GPU, since CUDA tensors can't deserialize on a machine without a GPU.
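The device-reset step is easy to verify on its own. A minimal sketch (using a bare nn.Linear as a stand-in for the trained classifier):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)

# On the remote worker the model may have been moved to GPU for training...
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# ...so always call .cpu() before returning: the pickled payload then
# contains only CPU tensors, safe to deserialize anywhere.
model = model.cpu()
assert all(p.device.type == "cpu" for p in model.parameters())
```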
Pinning the torch version¶
Torch tensors use pickle's __reduce_ex__ protocol to serialize their raw storage. The binary format can change between torch versions — a tensor pickled with torch 2.10 may not deserialize correctly on torch 2.8 (or vice versa). Since the model travels both directions through cloudpickle, the local and remote torch versions must match:
image=sky.Image(
    pip=[f"torch=={TORCH_VERSION}"],
    pip_indexes=[
        sky.PipIndex(
            url="https://download.pytorch.org/whl/cpu",
            packages=["torch"],
        ),
    ],
),
TORCH_VERSION strips the build suffix (e.g. +cpu, +cu128) so the version pin works across wheel variants. The PipIndex scopes the PyTorch wheel index to the torch package only, preventing it from affecting other dependencies.
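The guide doesn't show how TORCH_VERSION is defined, but the suffix-stripping it describes amounts to splitting on "+" — a hypothetical helper:

```python
def base_torch_version(version: str) -> str:
    """Strip the local build suffix (+cpu, +cu128, ...) from a torch version string."""
    return version.split("+", 1)[0]


assert base_torch_version("2.5.1+cu121") == "2.5.1"
assert base_torch_version("2.5.1+cpu") == "2.5.1"
assert base_torch_version("2.5.1") == "2.5.1"

# In the example script, TORCH_VERSION would then be derived from the
# locally installed torch, e.g.:
# TORCH_VERSION = base_torch_version(torch.__version__)
```

The result ("2.5.1") matches wheels on any index regardless of which build variant is installed locally.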
The full cycle¶
Build the model locally, evaluate it (random accuracy), train remotely, evaluate again:
if __name__ == "__main__":
    x_train, y_train, x_test, y_test = load_mnist()

    model = MNISTClassifier()
    print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")

    print("Before training:")
    evaluate(model, x_test, y_test)

    with sky.Compute(
        provider=sky.Container(),
        vcpus=4,
        memory_gb=2,
        image=sky.Image(
            pip=[f"torch=={TORCH_VERSION}"],
            pip_indexes=[
                sky.PipIndex(
                    url="https://download.pytorch.org/whl/cpu",
                    packages=["torch"],
                ),
            ],
        ),
    ) as compute:
        print("\nTraining remotely...")
        trained_model: nn.Module = train(model, x_train, y_train, epochs=10, lr=1e-3) >> compute

        print("\nAfter training:")
        evaluate(trained_model, x_test, y_test)
The untrained model starts at ~10% accuracy (random chance for 10 classes). After >> compute dispatches the training to the remote worker and returns the trained model, local evaluation shows the learned accuracy — proving the weights survived the roundtrip.
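The ~10% baseline is just the expected accuracy of random guessing over 10 balanced classes, which a quick simulation confirms:

```python
import random

random.seed(0)
n = 100_000

# Random labels and random predictions over 10 classes.
labels = [random.randrange(10) for _ in range(n)]
preds = [random.randrange(10) for _ in range(n)]

accuracy = sum(p == y for p, y in zip(preds, labels)) / n
# Hovers around 0.10, matching the untrained model's starting point.
assert 0.08 < accuracy < 0.12
```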
Local evaluation¶
The evaluate function runs on your local machine using the model that came back from the cloud:
@torch.no_grad()
def evaluate(model: nn.Module, x: torch.Tensor, y: torch.Tensor) -> None:
    """Evaluate the trained model locally."""
    model.eval()
    output = model(x)
    predictions = output.argmax(1)
    accuracy = 100.0 * (predictions == y).float().mean().item()
    print(f"  test samples : {len(y)}")
    print(f"  test accuracy: {accuracy:.1f}%")
No reconstruction, no load_state_dict — the returned object is a regular MNISTClassifier instance with trained weights, ready for inference.
Run the full example¶
git clone https://github.com/gabfssilva/skyward.git
cd skyward
uv run python guides/15_torch_model_roundtrip.py
What you learned:

- nn.Module is picklable — cloudpickle captures architecture + weights together, no manual serialization needed.
- Models as arguments and return values — send an untrained model in, get a trained model back via >>.
- Pin torch versions — the pickle format for tensors is version-sensitive; TORCH_VERSION keeps local and remote in sync.
- model.cpu() before returning — ensures CUDA tensors don't break deserialization on CPU-only machines.
- Data stays local until dispatch — load datasets on your machine, send as tensors; the worker only needs torch.