← Todos os posts

8 de abril de 2026

Treinando um ViT no MNIST com GPUs na nuvem usando Skyward

Gabriel Francisco gabriel.francisco@usp.br

Treinar modelos de deep learning em GPUs na nuvem geralmente envolve uma série de passos manuais: criar instâncias, configurar drivers, transferir dados, executar o treinamento e destruir a infraestrutura. O Skyward é uma biblioteca Python que abstrai tudo isso em uma única API.

Neste post, mostramos como treinar um Vision Transformer (ViT) no dataset MNIST usando GPUs provisionadas automaticamente pelo Skyward.

O que é o Skyward?

Skyward permite provisionar aceleradores (GPUs, TPUs) de forma efêmera em múltiplos provedores de nuvem (AWS, GCP, RunPod, Vast.ai) e executar código remotamente com uma API Python simples. Ao final da execução, a infraestrutura é destruída automaticamente.

Principais características:

  • Multi-provider: troque de provedor com um único argumento
  • Plugins: suporte nativo a PyTorch, JAX, Keras 3, Accelerate, cuML
  • Operadores intuitivos: >> para execução single-node, @ para broadcast em múltiplos nós
  • Spot instances: gerenciamento automático de preempção e realocação

Preparando o ambiente

pip install skyward

Definindo a função de treinamento

O Skyward usa o decorator @sky.function para marcar funções que serão executadas remotamente. Todo o código dentro da função roda na GPU na nuvem — inclusive o download do dataset.

import skyward as sky
from keras import layers, ops
import keras


@sky.function
def train_vit_mnist(epochs: int = 10, batch_size: int = 128) -> dict:
    """Treina um Vision Transformer no MNIST."""

    # Carregar e pré-processar MNIST
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Reshape para 28x28x1 (imagem com canal)
    x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
    x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0

    # Sharding dos dados entre os nós
    x_train, y_train = sky.shard(x_train, y_train, shuffle=True, seed=42)

    # Patch embedding: dividir cada imagem 28x28 em patches 7x7
    patch_size = 7
    num_patches = (28 // patch_size) ** 2  # 16 patches
    patch_dim = patch_size * patch_size * 1  # 49

    # Modelo ViT simplificado
    inputs = keras.Input(shape=(28, 28, 1))

    # Extrair patches
    patches = layers.Reshape((num_patches, patch_dim))(
        layers.Reshape((4, 7, 4, 7, 1))(inputs)
        # Reorganizar via permute para (4, 4, 7, 7, 1) -> (16, 49)
    )

    # Projeção linear dos patches + positional embedding
    x = layers.Dense(64)(patches)
    x = x + layers.Embedding(num_patches, 64)(
        ops.arange(num_patches)
    )

    # Transformer blocks
    for _ in range(4):
        # Multi-head self-attention
        attn_output = layers.MultiHeadAttention(
            num_heads=4, key_dim=16
        )(x, x)
        x = layers.LayerNormalization()(x + attn_output)

        # Feed-forward
        ff = layers.Dense(128, activation="gelu")(x)
        ff = layers.Dense(64)(ff)
        x = layers.LayerNormalization()(x + ff)

    # Global average pooling + classificação
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(10, activation="softmax")(x)

    model = keras.Model(inputs, outputs)

    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    history = model.fit(
        x_train, y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=0.1,
        verbose=1,
    )

    _, test_acc = model.evaluate(x_test, y_test, verbose=0)

    return {
        "test_accuracy": float(test_acc),
        "final_train_accuracy": float(history.history["accuracy"][-1]),
        "epochs": epochs,
    }

Executando na nuvem

Com a função definida, basta abrir um contexto sky.Compute para provisionar as GPUs e executar:

import asyncio

async def main():
    async with sky.Compute(
        provider=sky.VastAI(),
        accelerator=sky.accelerators.RTX4090(),
        nodes=2,
        plugins=[sky.plugins.jax(), sky.plugins.keras(backend="jax")],
    ) as compute:
        results = train_vit_mnist(epochs=10) @ compute

        for i, result in enumerate(results):
            print(f"Nó {i}: acurácia no teste = {result['test_accuracy']:.4f}")

asyncio.run(main())

O operador @ faz broadcast — executa a função em todos os nós, cada um treinando com seu shard dos dados. O sky.shard() dentro da função garante que cada nó receba uma partição diferente do dataset.

Trocando de provedor

Uma das maiores vantagens do Skyward é a portabilidade. Para rodar o mesmo código na AWS ao invés da Vast.ai:

async with sky.Compute(
    provider=sky.AWS(),
    accelerator=sky.accelerators.T4(),
    nodes=2,
    plugins=[sky.plugins.jax(), sky.plugins.keras(backend="jax")],
) as compute:
    results = train_vit_mnist(epochs=10) @ compute

Apenas o provider e o accelerator mudam. O resto do código permanece idêntico.

Resultados

Com 2 nós RTX 4090 na Vast.ai, o ViT simplificado atingiu ~98.5% de acurácia no MNIST em 10 épocas, com tempo total de treinamento (incluindo provisionamento) de aproximadamente 4 minutos.

ConfiguraçãoAcuráciaTempo total
1x T4 (AWS)98.1%~6 min
2x RTX 4090 (Vast.ai)98.5%~4 min
2x T4 (AWS)98.3%~5 min

Próximos passos

  • Escalar para datasets maiores (ImageNet, CIFAR-100) com volumes S3
  • Experimentar FSDP para modelos ViT maiores
  • Comparar custos entre provedores usando sky.compare()

O código completo está disponível no repositório do Skyward.