PyTorch model roundtrip

PyTorch's nn.Module is fully picklable — architecture and weights travel together through Python's serialization protocol. Skyward uses cloudpickle under the hood, which means you can send an untrained model to a remote worker, train it there, and get the trained model back. No state_dict files, no checkpoints, no manual save/load — the model object itself is the transport.

This guide walks through the full cycle: build locally, train remotely, evaluate locally.

The model

A standard nn.Module — nothing special required for serialization:

import torch
import torch.nn as nn


class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))

When cloudpickle serializes this object, it captures both the class definition and the instance state (all parameter tensors). On the remote side, it reconstructs the exact same object with the exact same weights.
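The principle can be seen with plain pickle and a minimal stand-in class (TinyModel is a hypothetical example, not part of this guide's code); cloudpickle extends the same mechanism to classes defined in local scopes and notebooks:

```python
import pickle


class TinyModel:
    """Stand-in for an nn.Module: the instance carries its own state."""

    def __init__(self):
        self.weights = [0.0, 0.0]  # "untrained" parameters

    def train_step(self):
        self.weights = [0.5, -0.3]  # pretend training updated them


model = TinyModel()
model.train_step()

# Serialize and reconstruct: the class reference and the instance
# state (the weights) travel together, just like a pickled nn.Module.
restored = pickle.loads(pickle.dumps(model))
print(restored.weights)  # → [0.5, -0.3]
```

The reconstructed object is a new instance with identical state — the same property the guide relies on when the trained model comes back from the worker.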

Loading data locally

MNIST is loaded on the local machine and sent as tensors to the remote worker. The worker doesn't need torchvision — it only needs torch to work with the tensors it receives:

def load_mnist() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Load MNIST and flatten images to 784-d vectors."""
    from torchvision import datasets

    # The raw .data tensors are read directly below, so no transform is needed.
    train_ds = datasets.MNIST("/tmp/mnist", train=True, download=True)
    test_ds = datasets.MNIST("/tmp/mnist", train=False, download=True)

    x_train = train_ds.data.float().view(-1, 784) / 255.0
    y_train = train_ds.targets
    x_test = test_ds.data.float().view(-1, 784) / 255.0
    y_test = test_ds.targets
    return x_train, y_train, x_test, y_test

The 60k training images (each 28x28) are flattened to 784-d vectors and normalized. Both x_train and y_train are regular tensors — cloudpickle handles them the same way it handles any Python object.
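The flatten-and-normalize step is plain uint8-to-float scaling; a torch-free sketch with illustrative values shows the same arithmetic as data.float().view(-1, 784) / 255.0:

```python
# Illustrative: two fake 2x2 "images" with uint8 pixel values
images = [
    [[0, 255], [128, 64]],
    [[255, 255], [0, 0]],
]

# Flatten each image row-major and scale pixels into [0, 1]
flat = [[px / 255.0 for row in img for px in row] for img in images]

print(len(flat), len(flat[0]))  # → 2 4
print(flat[0][:2])              # → [0.0, 1.0]
```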

The remote training function

A function decorated with @sky.function receives the model as an argument and returns it after training:

@sky.function
def train(
    model: nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    epochs: int,
    lr: float,
) -> nn.Module:
    """Train the model on the remote worker and return it with learned weights."""
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    loader = DataLoader(
        TensorDataset(x.to(device), y.to(device)),
        batch_size=64,
        shuffle=True,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0

        for batch_x, batch_y in loader:
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (output.argmax(1) == batch_y).sum().item()
            total += batch_y.size(0)

        acc = 100.0 * correct / total
        print(f"  epoch {epoch + 1}/{epochs}: loss={total_loss / len(loader):.4f}, acc={acc:.1f}%")

    return model.cpu()

The type signature tells the story: nn.Module goes in, nn.Module comes out. The optimizer modifies the model's parameters in-place during training. The final model.cpu() ensures all tensors are on CPU before serialization — this matters when training on GPU, since CUDA tensors can't deserialize on a machine without a GPU.

Pinning the torch version

Torch tensors use pickle's __reduce_ex__ protocol to serialize their raw storage. The binary format can change between torch versions — a tensor pickled with torch 2.10 may not deserialize correctly on torch 2.8 (or vice versa). Since the model travels both directions through cloudpickle, the local and remote torch versions must match:

TORCH_VERSION = torch.__version__.split("+")[0]
image=sky.Image(
    pip=[f"torch=={TORCH_VERSION}"],
    pip_indexes=[
        sky.PipIndex(
            url="https://download.pytorch.org/whl/cpu",
            packages=["torch"],
        ),
    ],
),

TORCH_VERSION strips the build suffix (e.g. +cpu, +cu128) so the version pin works across wheel variants. The PipIndex scopes the PyTorch wheel index to the torch package only, preventing it from affecting other dependencies.
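The stripping logic itself is a one-liner; base_version below is an illustrative helper (not Skyward API) showing how a PEP 440 local-version suffix is dropped:

```python
def base_version(version: str) -> str:
    """Drop a local-version suffix like +cpu or +cu128 from a version string."""
    return version.split("+")[0]


print(base_version("2.8.0+cu128"))  # → 2.8.0
print(base_version("2.8.0+cpu"))    # → 2.8.0
print(base_version("2.8.0"))        # → 2.8.0 (no suffix: unchanged)
```

Without the stripping, a local pin like torch==2.8.0+cu128 would fail to resolve against the CPU wheel index used by the remote image.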

The full cycle

Build the model locally, evaluate it (random accuracy), train remotely, evaluate again:

if __name__ == "__main__":
    x_train, y_train, x_test, y_test = load_mnist()

    model = MNISTClassifier()
    print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")
    print("Before training:")
    evaluate(model, x_test, y_test)

    with sky.Compute(
        provider=sky.Container(),
        vcpus=4,
        memory_gb=2,
        image=sky.Image(
            pip=[f"torch=={TORCH_VERSION}"],
            pip_indexes=[
                sky.PipIndex(
                    url="https://download.pytorch.org/whl/cpu",
                    packages=["torch"],
                ),
            ],
        ),
    ) as compute:
        print("\nTraining remotely...")
        trained_model: nn.Module = train(model, x_train, y_train, epochs=10, lr=1e-3) >> compute

    print("\nAfter training:")
    evaluate(trained_model, x_test, y_test)

The untrained model starts at ~10% accuracy (random chance for 10 classes). After >> compute dispatches the training to the remote worker and returns the trained model, local evaluation shows the learned accuracy — proving the weights survived the roundtrip.

Local evaluation

The evaluate function runs on your local machine using the model that came back from the cloud:

@torch.no_grad()
def evaluate(model: nn.Module, x: torch.Tensor, y: torch.Tensor) -> None:
    """Evaluate the trained model locally."""
    model.eval()
    output = model(x)
    predictions = output.argmax(1)
    accuracy = 100.0 * (predictions == y).float().mean().item()

    print(f"  test samples : {len(y)}")
    print(f"  test accuracy: {accuracy:.1f}%")

No reconstruction, no load_state_dict — the returned object is a regular MNISTClassifier instance with trained weights, ready for inference.

Run the full example

git clone https://github.com/gabfssilva/skyward.git
cd skyward
uv run python guides/15_torch_model_roundtrip.py

What you learned:

  • nn.Module is picklable — cloudpickle captures architecture + weights together, no manual serialization needed.
  • Models as arguments and return values — send an untrained model in, get a trained model back via >>.
  • Pin torch versions — the pickle format for tensors is version-sensitive; TORCH_VERSION keeps local and remote in sync.
  • model.cpu() before returning — ensures CUDA tensors don't break deserialization on CPU-only machines.
  • Data stays local until dispatch — load datasets on your machine, send as tensors; the worker only needs torch.