jax.extend.ffi.pycapsule

jax.extend.ffi.pycapsule(funcptr)

Wrap a ctypes function pointer in a PyCapsule.

The primary use of this function, and the reason it lives within the jax.ffi submodule, is to wrap functions from external compiled libraries so that they can be registered as XLA custom calls.

Example usage:

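import ctypes
import jax
from jax.lib import xla_client

# Placeholders: set these to match your library. PLATFORM is the target
# platform name (e.g. "cpu"); API_VERSION selects the custom call API that
# foo.so implements (0 for the untyped legacy API, 1 for the typed XLA FFI).
PLATFORM = "cpu"
API_VERSION = 1

# Load the compiled library; libfoo.bar is a ctypes function pointer.
libfoo = ctypes.cdll.LoadLibrary('./foo.so')

# Register the wrapped pointer with XLA under the target name "bar".
xla_client.register_custom_call_target(
    name="bar",
    fn=jax.ffi.pycapsule(libfoo.bar),
    platform=PLATFORM,
    api_version=API_VERSION
)

Parameters: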

funcptr – A function pointer loaded from a dynamic library using ctypes.

Returns:

An opaque PyCapsule object wrapping funcptr.