jax.experimental.pallas.MemoryRef

class jax.experimental.pallas.MemoryRef(shape, dtype, memory_space)

Like jax.ShapeDtypeStruct, but additionally records the memory space the value lives in.

Parameters:
  shape (tuple[int, ...])
  dtype (jnp.dtype | dtypes.ExtendedDType)
  memory_space (Any)

__init__(shape, dtype, memory_space)

  Parameters:
    shape (tuple[int, ...])
    dtype (jnp.dtype | dtypes.ExtendedDType)
    memory_space (Any)

  Return type:
    None

Methods

  __init__(shape, dtype, memory_space)
  get_array_aval()
  get_ref_aval()

Attributes

  shape
  dtype
  memory_space