Runtime

skyward.InstanceInfo

Bases: BaseModel

Cluster topology and node metadata for the current worker.

Parsed from the COMPUTE_POOL environment variable that Skyward injects into every remote worker process. Provides node indices, peer addresses, accelerator info, and convenience properties for common distributed patterns.

Examples:

>>> @sky.function
... def distributed_task():
...     info = sky.instance_info()
...     if info.is_head:
...         print(f"Head node of {info.total_nodes} nodes")
...     local_data = sky.shard(dataset)
...     return train(local_data)

node = Field(description='Index of this node (0 to total_nodes - 1)')

worker = Field(default=0, description='Worker index within this node (0 to workers_per_node - 1)')

total_nodes = Field(description='Total number of nodes in the pool')

workers_per_node = Field(default=1, description='Number of workers per node (e.g., 2 for MIG 3g.40gb)')

accelerators = Field(description='Number of accelerators on this node')

total_accelerators = Field(description='Total accelerators in the pool')

head_addr = Field(description='IP address of the head node')

head_port = Field(description='Port for head node coordination')

job_id = Field(description='Unique identifier for this pool execution')

peers = Field(description='Information about all peers')

accelerator = Field(default=None, description='Accelerator configuration')

network = Field(description='Network configuration')

global_worker_index property

Global index of this worker (0 to total_workers - 1).

total_workers property

Total number of workers across all nodes.

is_head property

True if this is the head worker (global_worker_index == 0).

hostname property

Current instance hostname.

current() classmethod

Get pool info from COMPUTE_POOL environment variable.
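The derived indices above can be modelled concretely. The sketch below is illustrative only: the COMPUTE_POOL wire format is not documented here (JSON is an assumption), PoolInfo covers only a few of InstanceInfo's fields, and the node-major worker numbering is a guess at how global_worker_index is derived.

```python
import json
import os
from dataclasses import dataclass


# Hypothetical minimal model of InstanceInfo, NOT skyward's implementation.
@dataclass
class PoolInfo:
    node: int
    worker: int
    total_nodes: int
    workers_per_node: int

    @property
    def global_worker_index(self) -> int:
        # Assumed derivation: workers are numbered node-major.
        return self.node * self.workers_per_node + self.worker

    @property
    def total_workers(self) -> int:
        return self.total_nodes * self.workers_per_node

    @property
    def is_head(self) -> bool:
        return self.global_worker_index == 0

    @classmethod
    def current(cls) -> "PoolInfo | None":
        raw = os.environ.get("COMPUTE_POOL")
        if raw is None:
            return None  # not running inside a pool
        return cls(**json.loads(raw))  # JSON encoding is an assumption


os.environ["COMPUTE_POOL"] = json.dumps(
    {"node": 1, "worker": 1, "total_nodes": 4, "workers_per_node": 2}
)
info = PoolInfo.current()
print(info.global_worker_index, info.total_workers, info.is_head)  # -> 3 8 False
```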

skyward.instance_info()

Return information about the current compute instance.

Must be called from within a @sky.function running on a remote node.

Returns:

    InstanceInfo | None: Cluster topology and node metadata, or None if not in a pool.

Examples:

>>> @sky.function
... def distributed_task(data):
...     info = sky.instance_info()
...     if info.is_head:
...         print(f"Head node of {info.total_nodes} nodes")
...     return process(data)

skyward.shard(*data, shuffle=False, seed=0, drop_last=False, node=None, total_nodes=None)

shard(
    data: list[T],
    /,
    *,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
    node: int | None = None,
    total_nodes: int | None = None,
) -> list[T]
shard(
    data: tuple[T, ...],
    /,
    *,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
    node: int | None = None,
    total_nodes: int | None = None,
) -> tuple[T, ...]
shard(
    data: npt.NDArray[Any],
    /,
    *,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
    node: int | None = None,
    total_nodes: int | None = None,
) -> npt.NDArray[Any]
shard(
    data: torch.Tensor,
    /,
    *,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
    node: int | None = None,
    total_nodes: int | None = None,
) -> torch.Tensor
shard(
    data1: T1,
    data2: T2,
    /,
    *,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
    node: int | None = None,
    total_nodes: int | None = None,
) -> tuple[T1, T2]
shard(
    data1: T1,
    data2: T2,
    data3: T3,
    /,
    *,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
    node: int | None = None,
    total_nodes: int | None = None,
) -> tuple[T1, T2, T3]
shard(
    data1: T1,
    data2: T2,
    data3: T3,
    data4: T4,
    /,
    *,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
    node: int | None = None,
    total_nodes: int | None = None,
) -> tuple[T1, T2, T3, T4]
shard(
    data1: T1,
    data2: T2,
    data3: T3,
    data4: T4,
    data5: T5,
    /,
    *,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
    node: int | None = None,
    total_nodes: int | None = None,
) -> tuple[T1, T2, T3, T4, T5]

Shard data across distributed nodes, preserving input type.

Return ONLY this node's portion of the data. Supports list, tuple, np.ndarray, torch.Tensor, and any Sequence.

Can accept multiple arrays at once — they will all be sharded with the same indices (useful for keeping x and y aligned).

Parameters:

    *data (Any): One or more arrays/sequences to shard.
    shuffle (bool): Shuffle with a synchronized seed across all nodes. Default: False.
    seed (int): Random seed for reproducible shuffling. Default: 0.
    drop_last (bool): Drop tail items so all nodes get an equal count. Default: False.
    node (int | None): Override the node index (for testing). Default: None.
    total_nodes (int | None): Override total_nodes (for testing). Default: None.

Returns:

    T | tuple[T, ...]: If given a single argument, this node's shard with the same type as the input; if given multiple arguments, a tuple of shards.

Examples:

>>> my_data = shard(full_dataset, shuffle=True, seed=42)
>>> x_train, y_train = shard(x_train, y_train)
>>> x_train, y_train, x_test, y_test = shard(x_train, y_train, x_test, y_test)
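The contract above can be sketched in pure Python. This is not skyward's implementation; in particular, the strided partition (indices[node::total_nodes]) is one plausible strategy, and a contiguous block split would satisfy the same contract.

```python
import random
from typing import Sequence, TypeVar

T = TypeVar("T")


# Illustrative sketch of shard()'s behavior for sequences only.
def shard_sketch(
    data: Sequence[T],
    *,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
    node: int = 0,
    total_nodes: int = 1,
) -> list[T]:
    indices = list(range(len(data)))
    if shuffle:
        # Every node seeds the same RNG, so the permutation is synchronized.
        random.Random(seed).shuffle(indices)
    if drop_last:
        # Truncate so every node receives exactly len(data) // total_nodes items.
        indices = indices[: len(data) - len(data) % total_nodes]
    return [data[i] for i in indices[node::total_nodes]]


xs = list(range(10))
ys = [x * x for x in xs]
# The same seed on every "node" keeps xs and ys aligned after shuffling.
x0 = shard_sketch(xs, shuffle=True, seed=42, node=0, total_nodes=3)
y0 = shard_sketch(ys, shuffle=True, seed=42, node=0, total_nodes=3)
assert all(y == x * x for x, y in zip(x0, y0))
```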

skyward.stdout(only)

Control stdout emission in distributed execution.

Silence stdout for workers that don't match the predicate. stderr is NOT affected — errors from any worker are always visible.

Parameters:

    only (OutputSpec): Predicate or "head" shortcut. Workers matching this emit stdout. Required.

      • "head": only the head worker (node == 0).
      • Callable[[InstanceInfo], bool]: custom predicate.

Returns:

    Callable: Decorator that wraps the function with stdout control.

Examples:

>>> @sky.stdout(only="head")
... @sky.function
... def train(data):
...     print("Only head node prints this")
...     return model.fit(data)
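The silencing mechanism can be approximated with the stdlib alone. stdout_if below is a hypothetical stand-in: the real decorator evaluates the predicate against InstanceInfo at run time, while here a plain boolean plays that role.

```python
import contextlib
import functools
import io


# Sketch of predicate-based stdout silencing, not skyward's decorator.
def stdout_if(emit: bool):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if emit:
                return fn(*args, **kwargs)
            # Swallow stdout; stderr is deliberately left untouched.
            with contextlib.redirect_stdout(io.StringIO()):
                return fn(*args, **kwargs)
        return wrapper
    return decorator


@stdout_if(False)
def noisy():
    print("you should not see this")
    return 42


assert noisy() == 42  # the return value is unaffected, only stdout is silenced
```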

skyward.stderr(only)

Control stderr emission in distributed execution.

Silence stderr for workers that don't match the predicate. Use with caution — silencing errors can hide problems.

Parameters:

    only (OutputSpec): Predicate or "head" shortcut. Workers matching this emit stderr. Required.

Returns:

    Callable: Decorator that wraps the function with stderr control.

skyward.silent(fn)

Silence both stdout and stderr completely.

Useful for functions that should never emit output regardless of rank.

Examples:

>>> @sky.silent
... @sky.function
... def quiet_task(data):
...     return process(data)

skyward.is_head(info)

True if this is the head worker (node == 0).

skyward.CallbackWriter

Bases: TextIO

Write-only stream adapter that forwards writes to a callback.

Implements the TextIO interface so it can replace sys.stdout or sys.stderr via contextlib.redirect_stdout.

Parameters:

    callback (Callable[[str], None]): Called with each string written to the stream. Required.

Methods: __init__(callback), write(s), getvalue(), read(n=-1), readline(limit=-1), flush(), close(), seekable(), readable(), writable()
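A minimal stand-in for this adapter can be built on io.TextIOBase, assuming only that write() forwards each string to the callback. CallbackWriterSketch is illustrative and does not reproduce skyward's class.

```python
import contextlib
import io
from typing import Callable


# Write-only text stream whose write() forwards to a callback.
class CallbackWriterSketch(io.TextIOBase):
    def __init__(self, callback: Callable[[str], None]) -> None:
        self._callback = callback

    def write(self, s: str) -> int:
        self._callback(s)
        return len(s)

    def writable(self) -> bool:
        return True


lines: list[str] = []
# Because the sketch is a text stream, redirect_stdout accepts it directly.
with contextlib.redirect_stdout(CallbackWriterSketch(lines.append)):
    print("hello")
assert "hello" in "".join(lines)
```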

skyward.redirect_output(callback)

Redirect stdout and stderr to a callback within a context.

Parameters:

    callback (Callable[[str], None]): Called with each string written to stdout or stderr. Required.

Yields:

    tuple[CallbackWriter, CallbackWriter]: The (stdout_writer, stderr_writer) pair.

Examples:

>>> lines = []
>>> with sky.redirect_output(lines.append):
...     print("captured")
>>> assert "captured\n" in lines