jax.ShapeDtypeStruct

class jax.ShapeDtypeStruct(shape, dtype, *, sharding=None, weak_type=False)

A container for the shape, dtype, and other static attributes of an array.

ShapeDtypeStruct is often used with jax.eval_shape(), which accepts it in place of a concrete array argument and returns its results as ShapeDtypeStruct objects.
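A minimal sketch of the eval_shape() pattern, assuming a recent JAX release. The function f and the shapes chosen are hypothetical. The inputs are described abstractly, so no device memory is allocated for them, and only the output shape and dtype are computed.

```python
import jax
import jax.numpy as jnp


def f(x, w):
    # Hypothetical computation: a matrix multiply followed by tanh.
    return jnp.tanh(x @ w)


# Abstract descriptions of the inputs; no arrays are materialized.
x_spec = jax.ShapeDtypeStruct((8, 128), jnp.float32)
w_spec = jax.ShapeDtypeStruct((128, 16), jnp.float32)

# eval_shape traces f abstractly and describes its output with a ShapeDtypeStruct.
out_spec = jax.eval_shape(f, x_spec, w_spec)
print(out_spec.shape, out_spec.dtype)  # (8, 16) float32
```

The same structs can also be passed to jax.jit(f).lower(...) to compile ahead of time without concrete inputs.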

Parameters:
  • shape – a sequence of integers representing an array shape

  • dtype – a dtype-like object

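  • sharding – (optional) a jax.Sharding object

  • weak_type – (optional) whether the dtype is weakly typed, as JAX does for Python scalars; defaults to False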

__init__(shape, dtype, *, sharding=None, weak_type=False)

Methods

__init__(shape, dtype, *[, sharding, weak_type])

update(**kwargs)

Attributes

shape

dtype

sharding

weak_type

layout

ndim

size
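A sketch of the derived ndim and size attributes, assuming a recent JAX release: ndim is the number of dimensions and size is the product of the dimensions. The nbytes computation is a hypothetical helper calculation, not an attribute, and it ignores any padding or device layout.

```python
import jax
import jax.numpy as jnp

spec = jax.ShapeDtypeStruct((4, 5, 6), jnp.int32)

print(spec.ndim)  # 3    (len(spec.shape))
print(spec.size)  # 120  (4 * 5 * 6)

# Hypothetical rough buffer-size estimate from the static attributes.
nbytes = spec.size * spec.dtype.itemsize
print(nbytes)  # 480
```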