jax.scipy.spatial.transform.Slerp
- class jax.scipy.spatial.transform.Slerp(times, timedelta, rotations, rotvecs)
Spherical Linear Interpolation of Rotations.
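JAX implementation of scipy.spatial.transform.Slerp. Between each pair of consecutive key rotations, the interpolant is a rotation about a fixed axis at constant angular velocity.
Examples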
Create a Slerp instance from a series of rotations:
>>> import jax.numpy as jnp
>>> from jax.scipy.spatial.transform import Rotation, Slerp
>>> rots = jnp.array([[90, 0, 0],
...                   [0, 45, 0],
...                   [0, 0, -30]])
>>> key_rotations = Rotation.from_euler('zxy', rots, degrees=True)
>>> key_times = [0, 1, 2]
>>> slerp = Slerp.init(key_times, key_rotations)
>>> times = [0, 0.5, 1, 1.5, 2]
>>> interp_rots = slerp(times)
>>> interp_rots.as_euler('zxy')
Array([[ 1.5707963e+00,  0.0000000e+00,  0.0000000e+00],
       [ 8.5309029e-01,  3.8711953e-01,  1.7768645e-01],
       [-2.3841858e-07,  7.8539824e-01,  0.0000000e+00],
       [-5.6668043e-02,  3.9213133e-01, -2.8347540e-01],
       [ 0.0000000e+00,  0.0000000e+00, -5.2359891e-01]], dtype=float32)
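As a hedged reading of what the example computes: assuming this class follows the same constant-angular-velocity scheme as SciPy's Slerp, the interpolant on each interval between key times can be written as below. The symbols are introduced here for illustration; the correspondence of timedelta with $\Delta t_i$ and rotvecs with $\mathbf{v}_i$ is inferred from the field names, not stated on this page.

```latex
% Key times t_0 < t_1 < \dots < t_{n-1}, key rotations R_0, \dots, R_{n-1}.
% For t \in [t_i, t_{i+1}]:
\Delta t_i = t_{i+1} - t_i, \qquad
\mathbf{v}_i = \log\!\left(R_i^{-1} R_{i+1}\right), \qquad
\alpha = \frac{t - t_i}{\Delta t_i}, \qquad
R(t) = R_i \exp\!\left(\alpha\, \mathbf{v}_i\right)
```

Here log and exp convert between a rotation and its rotation vector (axis times angle), so $R(t_i) = R_i$ at $\alpha = 0$ and $R(t_{i+1}) = R_i R_i^{-1} R_{i+1} = R_{i+1}$ at $\alpha = 1$. With key_times = [0, 1, 2] in the example, every $\Delta t_i = 1$, and the rows at times 0, 1 and 2 reproduce the key rotations.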
- Parameters:
times (jnp.ndarray)
timedelta (jnp.ndarray)
rotations (Rotation)
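rotvecs (jnp.ndarray)
These are the fields of the underlying named tuple. Instances are normally created with Slerp.init(times, rotations), which derives all four fields from the key times and key rotations, rather than by passing the fields directly.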
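The parameter list above only names the stored fields. To show how four such fields can drive the interpolation, here is a minimal NumPy-only sketch of piecewise slerp on unit quaternions. It is not the JAX source: the class SlerpSketch and the helpers quat_mul, quat_inv, quat_to_rotvec and rotvec_to_quat are hypothetical, quaternions are assumed to be scalar-last (x, y, z, w) as in SciPy, and the meaning given to each field (rotations as the start of each interval, rotvecs as the relative rotation across it) is an assumption based on the field names.

```python
import numpy as np
from typing import NamedTuple

# Hypothetical helpers; quaternions are assumed scalar-last: (x, y, z, w).


def quat_mul(p, q):
    """Hamilton product p * q (apply q first, then p)."""
    px, py, pz, pw = np.moveaxis(np.asarray(p, dtype=float), -1, 0)
    qx, qy, qz, qw = np.moveaxis(np.asarray(q, dtype=float), -1, 0)
    return np.stack([
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
        pw * qw - px * qx - py * qy - pz * qz,
    ], axis=-1)


def quat_inv(q):
    """Inverse of a unit quaternion, i.e. its conjugate."""
    return np.asarray(q, dtype=float) * np.array([-1.0, -1.0, -1.0, 1.0])


def quat_to_rotvec(q):
    """Log map: unit quaternion -> rotation vector (axis * angle, angle in [0, pi])."""
    q = np.asarray(q, dtype=float)
    q = np.where(q[..., 3:] < 0, -q, q)  # pick the sign with w >= 0 (shortest path)
    xyz, w = q[..., :3], q[..., 3:]
    s = np.linalg.norm(xyz, axis=-1, keepdims=True)
    angle = 2.0 * np.arctan2(s, w)
    # angle / s tends to 2 as s -> 0, so use that limit near the identity.
    scale = np.where(s > 1e-12, angle / np.maximum(s, 1e-12), 2.0)
    return xyz * scale


def rotvec_to_quat(v):
    """Exp map: rotation vector -> unit quaternion."""
    v = np.asarray(v, dtype=float)
    angle = np.linalg.norm(v, axis=-1, keepdims=True)
    half = 0.5 * angle
    # sin(angle / 2) / angle tends to 0.5 as angle -> 0.
    scale = np.where(angle > 1e-12, np.sin(half) / np.maximum(angle, 1e-12), 0.5)
    return np.concatenate([v * scale, np.cos(half)], axis=-1)


class SlerpSketch(NamedTuple):
    """Hypothetical NumPy analogue of Slerp's four fields (not the JAX class)."""
    times: np.ndarray      # key times t_0 < ... < t_{n-1}, shape (n,)
    timedelta: np.ndarray  # t_{i+1} - t_i, shape (n - 1,)
    rotations: np.ndarray  # key quaternion at the start of each interval, shape (n - 1, 4)
    rotvecs: np.ndarray    # rotation vector of R_i^{-1} R_{i+1}, shape (n - 1, 3)

    @classmethod
    def init(cls, times, quats):
        times = np.asarray(times, dtype=float)
        quats = np.asarray(quats, dtype=float)
        starts, ends = quats[:-1], quats[1:]
        rotvecs = quat_to_rotvec(quat_mul(quat_inv(starts), ends))
        return cls(times, np.diff(times), starts, rotvecs)

    def __call__(self, t):
        t = np.atleast_1d(np.asarray(t, dtype=float))
        # Interval index i with times[i] <= t <= times[i + 1]; clipping makes
        # out-of-range times extrapolate from the first or last interval.
        i = np.clip(np.searchsorted(self.times, t) - 1, 0, len(self.timedelta) - 1)
        alpha = (t - self.times[i]) / self.timedelta[i]
        # R(t) = R_i * exp(alpha * v_i): constant angular velocity on each interval.
        return quat_mul(self.rotations[i], rotvec_to_quat(self.rotvecs[i] * alpha[:, None]))


# Usage: halfway between the identity and a 90-degree turn about z.
key_quats = rotvec_to_quat(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, np.pi / 2]]))
sketch = SlerpSketch.init([0.0, 1.0], key_quats)
print(quat_to_rotvec(sketch(0.5)[0]))  # approximately [0, 0, pi / 4]
```

Precomputing one rotation vector per interval means each evaluation in this sketch needs only an interval lookup, one exponential map and one quaternion product. Letting out-of-range times extrapolate, rather than raising an error, is a choice made only for this sketch.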
- __init__()
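Methods
__init__()
count(value, /): Return the number of occurrences of value (inherited from tuple).
index(value[, start, stop]): Return the first index of value (inherited from tuple).
init(times, rotations): Create a Slerp from key times and key rotations.
Attributes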
rotations: Alias for field number 2
rotvecs: Alias for field number 3
timedelta: Alias for field number 1
times: Alias for field number 0