jax.experimental.pallas.Slice

class jax.experimental.pallas.Slice(start, size, stride=1)

A slice defined by a start index and a size, with an optional stride (default 1).

Both the start index and the size can be either static (known at tracing and compilation time) or dynamic.
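To make the three fields concrete, here is a minimal plain-Python sketch of the index set that a start/size/stride triple describes. It is not the Pallas implementation. It assumes that `size` counts the selected elements (rather than the extent of the span) and that strides are positive; the helper name `covered_indices` is hypothetical.

```python
def covered_indices(start, size, stride=1):
    """Return the positions selected by a (start, size, stride) triple.

    Assumption: `size` is the number of selected elements, so the last
    position is start + (size - 1) * stride.
    """
    if stride < 1:
        # Assumption: only positive strides are meaningful here.
        raise ValueError("stride must be >= 1")
    return [start + i * stride for i in range(size)]


data = list(range(10))

# Contiguous window: 4 elements starting at index 2.
window = [data[i] for i in covered_indices(2, 4)]

# Strided window: 3 elements starting at index 1, taking every third one.
strided = [data[i] for i in covered_indices(1, 3, stride=3)]
```

Under this reading, `covered_indices(2, 4)` selects the same elements as the built-in `data[2:6]`. A dynamic start or size would be a traced value rather than a Python integer, which this static sketch does not model.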

__init__(start, size, stride=1)

Return type: None

Methods

__init__(start, size[, stride])

from_slice(slc, size)

tree_flatten()

tree_unflatten(aux_data, children)
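
The signature `from_slice(slc, size)` suggests a conversion from a built-in `slice`, given the length of the indexed dimension. A rough plain-Python sketch of such a conversion follows. The helper `start_size_stride` is hypothetical, and the logic is an assumption about the semantics, not the library's code.

```python
def start_size_stride(slc, dim_size):
    """Normalize a built-in slice against a dimension of length dim_size.

    Returns (start, size, stride), where size is the element count.
    Assumption: only positive steps are representable.
    """
    # slice.indices resolves None and negative bounds for this length.
    start, stop, step = slc.indices(dim_size)
    if step < 1:
        raise ValueError(f"slice must have a step >= 1 (found: {step})")
    size = len(range(start, stop, step))  # number of selected elements
    return start, size, step


full = start_size_stride(slice(None), 8)
tail = start_size_stride(slice(-3, None), 8)
every_third = start_size_stride(slice(1, None, 3), 10)
```

The dimension length is needed because open-ended or negative bounds such as `slice(-3, None)` only resolve to concrete positions once the length is known.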

Attributes

is_dynamic_size

is_dynamic_start

stride

start

size