What are plugins?¶
Skyward's plugin system is how you bring third-party frameworks into the compute pool. When you pass plugins=[sky.plugins.torch()] to a Compute pool, you are telling Skyward: install PyTorch on the remote workers, configure the distributed runtime before my function runs, and clean up when the worker stops. The plugin handles the environment setup, the lifecycle hooks, and the per-task wrapping — things you would otherwise do manually with Image(pip=[...]), environment variables, and boilerplate inside your @sky.function-decorated functions.
The key insight is that plugins operate at the pool level, not at the function level. A single plugin declaration on the pool affects every task dispatched to it. This is different from the decorator pattern you might be used to, where each function explicitly opts in to framework setup. With plugins, the pool is the unit of configuration: once you declare that a pool uses PyTorch with NCCL, every function dispatched to that pool gets PyTorch's distributed environment configured automatically.
The Plugin dataclass¶
A Plugin is a frozen dataclass with six optional hooks. Each hook corresponds to a different phase in the pool and worker lifecycle. You do not need to implement all six — most plugins use two or three.
```python
@dataclass(frozen=True, slots=True)
class Plugin:
    name: str
    transform: ImageTransform | None = None
    bootstrap: BootstrapFactory | None = None
    decorate: TaskDecorator | None = None
    around_app: AppLifecycle | None = None
    around_process: ProcessLifecycle | None = None
    around_client: ClientLifecycle | None = None
```
The hooks are:
transform modifies the Image before bootstrap. It receives the current Image and the Cluster metadata, and returns a new Image with additional pip packages, pip indexes, environment variables, or apt packages. This is how plugins install their dependencies on the remote worker. For example, the torch plugin appends "torch" to pip and adds PyTorch's CUDA wheel index. The keras plugin appends "keras" and sets KERAS_BACKEND in the environment. Since Image is a frozen dataclass, the transform returns a new copy via replace() — it never mutates the original.
bootstrap injects shell operations after the standard bootstrap phases (apt, pip, etc.). It receives the Cluster and returns a tuple of shell ops. The huggingface plugin uses this to run huggingface-cli login after pip packages are installed, so the worker is authenticated before any task runs. The mps plugin uses it to start the NVIDIA MPS daemon.
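A minimal sketch of a bootstrap hook, assuming the ops are plain shell command strings (the document only says "a tuple of shell ops", so the real op type may be an internal representation):

```python
def bootstrap(cluster):
    # Runs after the standard apt/pip phases on each worker;
    # authenticate before any task executes.
    # Assumption: ops are shell command strings.
    return ("huggingface-cli login --token $HF_TOKEN",)


ops = bootstrap(cluster=None)
print(ops)  # ('huggingface-cli login --token $HF_TOKEN',)
```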
decorate wraps each function decorated with @sky.function at execution time on the remote worker. It is a classic Python decorator: it takes a function and returns a function. This is for per-task logic that must run every time a function executes — things like logging, metrics collection, or framework-specific wrappers that depend on each call's arguments.
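For example, a decorate hook that times every task might look like this (the names here are illustrative, not part of Skyward's API):

```python
import time
from functools import wraps


def timing_decorator(fn):
    # Applied once per function; the wrapper body runs on every call.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            elapsed = time.perf_counter() - start
            print(f"{fn.__name__} took {elapsed:.4f}s")
    return wrapper


@timing_decorator
def double(x):
    return 2 * x


result = double(21)  # result == 42, and the wrapper prints the timing
```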
around_app is a context manager that runs once in the main worker process. It receives an InstanceInfo and returns a context manager. The context is entered at worker startup and stays active for the lifetime of the worker. This is designed for one-time, process-wide initialization — things that must happen exactly once and persist.
The state module (skyward.plugins.state) tracks which around_app hooks have been entered. It stores the context managers in a module-level dictionary and checks before entering — if the key already exists, it is a no-op. This makes the hook idempotent: even if multiple tasks execute on the same worker, each around_app is entered exactly once.
around_process is a context manager that runs once per executor subprocess. It receives an InstanceInfo and returns a context manager. This hook is lazy — it enters on the first task execution in each subprocess, after environment variables are propagated. Only relevant when executor="process". The torch plugin uses this for dist.init_process_group(), the jax plugin for jax.distributed.initialize(), the keras plugin for DataParallel distribution setup, and the cuml plugin for cuml.accel.install(). All are irreversible, process-global operations that should not be repeated per task.
The process state module (skyward.plugins.process_state) tracks which around_process hooks have been entered, with the same idempotency guarantees as around_app.
around_client is a context manager that runs on the client side, not the worker. It receives the Compute pool and the Cluster, and wraps the pool's entire active lifetime. The joblib and sklearn plugins use this to register the SkywardBackend as joblib's parallel backend, so that any Parallel(n_jobs=-1) call inside the with block dispatches work to the cluster instead of local processes.
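As an illustration of the joblib case, the usage pattern might look like this (a sketch following the document's API style; the exact factory signatures are assumptions):

```python
from joblib import Parallel, delayed

import skyward as sky  # assumed import alias, matching the document's examples

with sky.Compute(provider=sky.AWS(), plugins=[sky.plugins.joblib()]) as compute:
    # While the pool is active, SkywardBackend is joblib's parallel backend,
    # so this dispatches to the cluster instead of local processes.
    squares = Parallel(n_jobs=-1)(delayed(pow)(i, 2) for i in range(8))
```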
Builder API¶
You can construct plugins using the builder pattern instead of passing all hooks to the constructor:
```python
plugin = (
    Plugin.create("my-plugin")
    .with_image_transform(lambda img, cluster: replace(img, pip=(*img.pip, "my-lib")))
    .with_decorator(my_decorator)
    .with_around_app(my_lifecycle)
)
```
Each .with_* method returns a new Plugin instance (immutable — uses replace()). This is how the built-in plugins are implemented internally: the factory function (e.g., sky.plugins.torch()) defines the hooks as closures and chains them together with the builder.
How hooks execute¶
The hooks run at different points in the pool lifecycle, and the order matters.
When the pool starts (Compute.__enter__):
- `transform` hooks run first, in plugin order. Each transform receives the image returned by the previous one. The final image is used to generate the bootstrap script.
- `bootstrap` hooks run after the standard bootstrap phases complete on each worker. The ops are appended in plugin order.
- `around_client` hooks are entered on the client, in plugin order.
When a task executes on a worker:
- `around_app` hooks are entered at worker startup (idempotent — skipped if already active).
- `around_process` hooks are lazily entered on first task execution in each executor subprocess (idempotent — skipped if already active). Only relevant when `executor="process"`.
- `decorate` hooks wrap the function. If multiple plugins have decorators, they are chained: the first plugin's decorator is outermost, the last is innermost. The chaining uses `functools.reduce` over `reversed(decorators)`, so the first plugin listed in `plugins=[...]` runs first and the last runs last.
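The chaining rule can be verified with plain Python. This self-contained sketch mimics the described `functools.reduce` over `reversed(decorators)` (the names are illustrative, not Skyward internals):

```python
from functools import reduce


def make_decorator(tag, log):
    def decorator(fn):
        def wrapper(*args, **kwargs):
            log.append(tag)          # record which decorator ran, in order
            return fn(*args, **kwargs)
        return wrapper
    return decorator


log = []
decorators = [make_decorator("first", log), make_decorator("last", log)]


def task():
    return "done"


# reduce over reversed(...): the first-listed decorator ends up outermost.
wrapped = reduce(lambda fn, deco: deco(fn), reversed(decorators), task)
wrapped()
print(log)  # ['first', 'last']
```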
When the pool stops (Compute.__exit__):
- `around_client` contexts are exited in reverse order.
- `around_process` contexts are exited in reverse order when executor subprocesses shut down.
- `around_app` contexts are exited in reverse order when the worker process shuts down.
Plugin composition¶
Plugins compose naturally because each hook is independent. You can stack multiple plugins and their effects combine:
```python
with sky.Compute(
    provider=sky.AWS(),
    accelerator=sky.accelerators.A100(),
    nodes=4,
    plugins=[
        sky.plugins.torch(backend="nccl"),
        sky.plugins.huggingface(token="hf_xxx"),
    ],
) as compute:
    train() >> compute
```
The torch plugin adds PyTorch to pip and initializes DDP via around_process. The huggingface plugin adds transformers, datasets, and tokenizers to pip, sets HF_TOKEN, and runs huggingface-cli login. Their image transforms compose (PyTorch packages + HuggingFace packages), and their around_process hooks are entered independently in plugin order.
Order can matter. When using Keras with JAX, the JAX plugin should come first because its around_process initializes the distributed runtime that Keras depends on:
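A sketch of such a pool, following the document's API style (factory arguments are omitted; exact signatures may differ):

```python
with sky.Compute(
    provider=sky.AWS(),
    accelerator=sky.accelerators.A100(),
    nodes=2,
    plugins=[
        sky.plugins.jax(),    # first: initializes the distributed runtime
        sky.plugins.keras(),  # second: depends on JAX's device mesh
    ],
) as compute:
    train() >> compute
```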
The JAX plugin's around_process calls jax.distributed.initialize(), and Keras's around_process calls keras.distribution.set_distribution(DataParallel(...)). The distribution setup needs JAX's device mesh to already be visible, so JAX must initialize first. Since around_process hooks are entered in plugin order, listing JAX first ensures the correct sequence.
Built-in plugins¶
Skyward ships with nine plugins:
| Plugin | Primary Hooks | Purpose |
|---|---|---|
| `accelerate` | `transform`, `around_process` | Distributed training with FSDP, DeepSpeed, and mixed precision via Hugging Face Accelerate |
| `torch` | `transform`, `around_process` | PyTorch installation and DDP initialization |
| `jax` | `transform`, `around_process` | JAX installation and distributed initialization |
| `keras` | `transform`, `around_process` | Keras backend configuration and DataParallel |
| `huggingface` | `transform`, `bootstrap` | Transformers, datasets, tokenizers, and auth |
| `joblib` | `transform`, `around_client` | Distributed joblib parallel backend |
| `sklearn` | `transform`, `around_client` | Scikit-learn with distributed joblib |
| `cuml` | `transform`, `around_process` | GPU-accelerated scikit-learn via RAPIDS cuML |
| `mps` | `transform`, `bootstrap` | NVIDIA Multi-Process Service for GPU sharing |
Custom plugins¶
Building a custom plugin follows the same pattern as the built-in ones. Define your hooks as functions, then chain them with the builder:
```python
from dataclasses import replace
from functools import wraps

from skyward.plugins import Plugin


def my_framework() -> Plugin:
    def transform(image, cluster):
        return replace(image, pip=(*image.pip, "my-framework"))

    def decorate(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            setup_my_framework()
            return fn(*args, **kwargs)
        return wrapper

    return (
        Plugin.create("my-framework")
        .with_image_transform(transform)
        .with_decorator(decorate)
    )
```
Use it like any built-in plugin:
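For instance (a sketch in the document's API style; `my_task` stands in for any function decorated with `@sky.function`):

```python
with sky.Compute(
    provider=sky.AWS(),
    plugins=[my_framework()],
) as compute:
    result = my_task() >> compute
```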
Next steps¶
- PyTorch — DDP initialization and CUDA wheel management
- JAX — Distributed initialization with `around_process`
- Keras — Backend-agnostic training with DataParallel
- Distributed Training — How plugins fit into multi-node training
- Getting Started — First steps with Skyward