jax.experimental.pallas.MemoryRef

class jax.experimental.pallas.MemoryRef(shape, dtype, memory_space)

Like jax.ShapeDtypeStruct, but also carries a memory space alongside the shape and dtype.

Parameters:
  • shape (tuple[int, ...])

  • dtype (jnp.dtype | dtypes.ExtendedDType)

  • memory_space (Any)

__init__(shape, dtype, memory_space)
Parameters:
  • shape (tuple[int, ...])

  • dtype (jnp.dtype | dtypes.ExtendedDType)

  • memory_space (Any)

Return type:

None

Methods

__init__(shape, dtype, memory_space)

get_array_aval()

get_ref_aval()

Attributes

shape

dtype

memory_space