Keras training
Keras 3 is backend-agnostic — the same model code runs on JAX, TensorFlow, or PyTorch. Skyward's keras plugin configures the backend on the remote worker before your function runs, and shard() handles data partitioning for multi-node training. This guide walks through training an MLP on MNIST across multiple cloud GPUs using Keras with JAX as the backend.
The keras plugin
Add sky.plugins.keras(backend="jax") to your pool's plugins. When using the JAX backend, also include sky.plugins.jax() for distributed initialization:
with sky.Compute(
    provider=sky.AWS(),
    accelerator=sky.accelerators.T4(),
    nodes=2,
    plugins=[sky.plugins.jax(), sky.plugins.keras(backend="jax")],
) as compute:
    ...  # dispatch your @sky.function calls here
The backend parameter sets KERAS_BACKEND on the remote worker before Keras is imported. This is critical — Keras reads the backend at import time, so the environment variable must be set first.
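The ordering constraint can be illustrated with a minimal sketch. This is not Skyward's implementation: set_keras_backend is a hypothetical helper that only mirrors the behavior described above, and it assumes the three backends named in this guide.

```python
import os

# Backends named in this guide. The helper below is hypothetical,
# not part of Skyward or Keras.
SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def set_keras_backend(backend: str) -> str:
    """Set KERAS_BACKEND for the current process and return the value."""
    if backend not in SUPPORTED_BACKENDS:
        raise ValueError(f"unsupported backend: {backend!r}")
    os.environ["KERAS_BACKEND"] = backend
    return os.environ["KERAS_BACKEND"]


# Step 1: choose the backend while Keras is still unimported.
active = set_keras_backend("jax")
print(f"KERAS_BACKEND={active}")

# Step 2: only now import Keras, which reads the variable once at import.
# import keras
```

Changing the variable after Keras has been imported does not switch the already-loaded backend, which is why the plugin has to act before your function runs.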
The function itself just uses @sky.function — the backend and distributed setup are handled by the plugins:
@sky.function
def train_mnist(epochs: int = 5, batch_size: int = 128) -> dict:
    """Train an MLP on this node's shard of MNIST."""
Loading and sharding data
Load the full dataset inside the function, then use shard() to get this node's portion:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
x_train, y_train = sky.shard(x_train, y_train, shuffle=True, seed=42)
keras.datasets.mnist.load_data() downloads the dataset on the remote worker. shard() then splits the training data so each node trains on a different subset — with 2 nodes, each gets half. The shuffle=True and seed=42 parameters ensure a deterministic, randomized split so both nodes agree on who gets which samples.
Note that sharding happens inside the function, after the data is loaded. The full dataset exists on every node (each one downloads it independently), and sharding selects each node's portion based on instance_info(). This is simpler than pre-splitting and distributing data from the client.
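To make the "both nodes agree" property concrete, here is a minimal sketch of seeded sharding. It is an assumption about how such a split can work, not Skyward's actual shard() implementation: shard_indices, node, and total_nodes are hypothetical names standing in for the values that instance_info() would supply.

```python
import random


def shard_indices(n_samples, node, total_nodes, shuffle=True, seed=42):
    """Return the sample indices assigned to one node (hypothetical helper)."""
    indices = list(range(n_samples))
    if shuffle:
        # A private RNG seeded identically on every node yields the same
        # permutation everywhere, so no coordination between nodes is needed.
        random.Random(seed).shuffle(indices)
    # Strided slice: node 0 takes positions 0, N, 2N, ...; node 1 takes 1, N+1, ...
    return indices[node::total_nodes]


node0 = shard_indices(10, node=0, total_nodes=2)
node1 = shard_indices(10, node=1, total_nodes=2)
print(node0, node1)
```

Because every node computes the same permutation from the same seed, the shards are disjoint and together cover the whole dataset, even though each node works it out on its own.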
Model definition
Define a standard Keras model — nothing Skyward-specific here:
model = keras.Sequential([
    keras.Input(shape=(784,)),  # Keras 3 prefers an explicit Input over input_shape=
    layers.Dense(256, activation="relu"),
    layers.Dropout(0.2),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.2),
    layers.Dense(10, activation="softmax"),
])
This is the same Keras Sequential API you'd use locally. The model runs on whatever backend the plugin configured — JAX in this case. If you switch to backend="torch", the same model definition produces a PyTorch-backed model.
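As a quick sanity check on the architecture, the number of trainable parameters can be worked out by hand. The sketch below uses plain Python rather than Keras; dense_params is a hypothetical helper, and the Dropout layers are left out because they contribute no parameters.

```python
def dense_params(inputs: int, units: int) -> int:
    """Weights (inputs * units) plus one bias per unit."""
    return inputs * units + units


layer_sizes = [784, 256, 128, 10]  # input width followed by each Dense layer
per_layer = [dense_params(i, o) for i, o in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)
```

The total should match the trainable parameter count that model.summary() reports for this model, regardless of which backend is active.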
Training
Compile and fit as usual:
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, verbose=1)
_, test_acc = model.evaluate(x_test, y_test, verbose=0)
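model.fit() runs on the remote GPU. Each node trains independently on its shard of the data, so with N nodes each node processes roughly 1/N of the samples per epoch and wall-clock training time drops accordingly (minus startup and download overhead). Because no gradients are exchanged, each node ends up with its own set of weights. Evaluation runs on the full test set: each node evaluates its own model and reports its own accuracy.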
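The scaling claim can be checked with simple arithmetic. The sketch below assumes an even split of the 60,000 MNIST training images and the default batch_size of 128; steps_per_epoch is a hypothetical helper, not a Keras or Skyward function.

```python
import math

TRAIN_SAMPLES = 60_000  # size of the MNIST training split


def steps_per_epoch(samples: int, nodes: int, batch_size: int = 128) -> int:
    """Batches one node runs per epoch, assuming an even split across nodes."""
    per_node = samples // nodes
    return math.ceil(per_node / batch_size)


one_node = steps_per_epoch(TRAIN_SAMPLES, nodes=1)
two_nodes = steps_per_epoch(TRAIN_SAMPLES, nodes=2)
print(one_node, two_nodes)
```

With two nodes, each runs 235 batches per epoch instead of 469, which is where the roughly halved epoch time comes from.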
For synchronized multi-node training with gradient averaging (similar to PyTorch DDP), Keras 3 provides the keras.distribution API. The keras plugin can configure it automatically when running with JAX on multiple nodes. For the independent per-shard training shown in this example, no extra configuration is needed.
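The difference between the two modes comes down to whether gradients are combined between steps. The toy sketch below is plain Python, not Keras or Skyward code: it fits y = w * x on two hypothetical shards and averages their gradients before each update, which is the idea behind synchronized data-parallel training.

```python
def gradient(w, xs, ys):
    """Gradient of mean squared error for the model y = w * x on one shard."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def synchronized_step(w, shards, lr):
    """One update: each node computes a local gradient, then all are averaged."""
    local_grads = [gradient(w, xs, ys) for xs, ys in shards]
    avg_grad = sum(local_grads) / len(local_grads)
    return w - lr * avg_grad


# Two made-up shards drawn from y = 2 * x, one per node.
shards = [
    ([1.0, 2.0], [2.0, 4.0]),
    ([3.0, 4.0], [6.0, 8.0]),
]

w = 0.0
for _ in range(50):
    w = synchronized_step(w, shards, lr=0.05)
print(w)
```

With equally sized shards, the averaged gradient equals the gradient over the combined data, so every node applies the same update and the weights stay identical across nodes. In the independent mode used in this guide, that averaging step is simply absent.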
Run the full example
git clone https://github.com/gabfssilva/skyward.git
cd skyward
uv run python guides/07_keras_training.py
What you learned:
- plugins=[sky.plugins.jax(), sky.plugins.keras(backend="jax")] sets the Keras backend and configures JAX distributed on the remote worker.
- shard() splits training data across nodes: each node trains on its own subset.
- Standard Keras API: Sequential, model.compile(), model.fit() work unchanged.
- Backend-agnostic: switch between JAX, TensorFlow, and PyTorch with one parameter.
- Data loads on the worker: no need to transfer datasets from your local machine.