derivkit.derivatives.autodiff.jax_utils module

Utilities for JAX-based autodiff in DerivKit.

exception derivkit.derivatives.autodiff.jax_utils.AutodiffUnavailable

Bases: RuntimeError

Raised when JAX-based autodiff is unavailable.

derivkit.derivatives.autodiff.jax_utils.apply_array_nd(func: Callable, where: str, theta: jnp.ndarray) → jnp.ndarray

Maps a function over an ND array and enforces that each output is an array.

Parameters:
  • func – Function to apply.

  • where – Context string for error messages.

  • theta – ND JAX array of inputs.

Returns:

JAX array of array outputs.

derivkit.derivatives.autodiff.jax_utils.apply_scalar_1d(func: Callable[[float], Any], where: str, x: jnp.ndarray) → jnp.ndarray

Maps a function over a 1D array and enforces that each output is a scalar.

Parameters:
  • func – Function to apply.

  • where – Context string for error messages.

  • x – 1D JAX array of inputs.

Returns:

JAX array of scalar outputs.
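The behaviour described above can be illustrated with a minimal sketch. It is not DerivKit's implementation: the name `apply_scalar_1d_sketch` is hypothetical, and the sketch assumes that "maps over" means calling `func` once per element and stacking the results. A plain Python loop is used for clarity; it does not handle an empty input.

```python
from typing import Any, Callable

import jax.numpy as jnp


def apply_scalar_1d_sketch(
    func: Callable[[Any], Any], where: str, x: jnp.ndarray
) -> jnp.ndarray:
    """Hypothetical stand-in: apply func element-wise and require scalar outputs."""
    outputs = []
    for xi in x:  # iterating a 1D JAX array yields 0-d arrays
        yi = jnp.asarray(func(xi))
        if yi.ndim != 0:
            # `where` only labels the error message with the calling context.
            raise TypeError(
                f"{where}: expected scalar output, got shape {yi.shape}."
            )
        outputs.append(yi)
    return jnp.stack(outputs)


xs = jnp.array([0.0, 1.0, 2.0])
ys = apply_scalar_1d_sketch(lambda t: t**2 + 1.0, "example", xs)
```

In this sketch `ys` has the same length as `xs`, with one scalar output per input element.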

derivkit.derivatives.autodiff.jax_utils.apply_scalar_nd(func: Callable, where: str, theta: jnp.ndarray) → jnp.ndarray

Maps a function over an ND array and enforces that each output is a scalar.

Parameters:
  • func – Function to apply.

  • where – Context string for error messages.

  • theta – ND JAX array of inputs.

Returns:

JAX array of scalar outputs.

derivkit.derivatives.autodiff.jax_utils.require_jax() → None

Raises AutodiffUnavailable if JAX is not available.

Parameters:

None.

Returns:

None.

Raises:

AutodiffUnavailable – If JAX is not installed.
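A guard of this kind is a common pattern for optional dependencies. The sketch below shows one way it might be written using only the standard library. It is an assumption, not DerivKit's code: `require_package` is a hypothetical helper that takes a package name so the failure path can be demonstrated, whereas the documented `require_jax()` takes no arguments, and the exception class is redefined locally as a stand-in.

```python
import importlib.util


class AutodiffUnavailable(RuntimeError):
    """Local stand-in for the exception documented above."""


def require_package(name: str) -> None:
    """Hypothetical helper: raise AutodiffUnavailable if `name` is not importable."""
    # find_spec returns None for a top-level package that cannot be found.
    if importlib.util.find_spec(name) is None:
        raise AutodiffUnavailable(
            f"{name} is required for autodiff but is not installed."
        )


# Success path: a standard-library module is always available.
require_package("json")

# Failure path: a package name that should not exist.
try:
    require_package("surely_not_an_installed_package_xyz")
except AutodiffUnavailable as exc:
    message = str(exc)
```

Because the exception derives from RuntimeError, callers that already catch RuntimeError also catch it.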

derivkit.derivatives.autodiff.jax_utils.to_jax_array(y: Any, *, where: str) → jnp.ndarray

Ensures that the output is array-like (not scalar) and returns it as a JAX array.

Parameters:
  • y – Output to check.

  • where – Context string for error messages.

Returns:

Non-scalar JAX array with shape (m,) or higher-dimensional.

Raises:

TypeError – If the output is a scalar or cannot be converted to a JAX array.
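The check described above might look roughly like the following sketch. The name `to_array_sketch` is hypothetical and the error messages are invented for illustration. The sketch assumes that "scalar" means a 0-d result after conversion with `jnp.asarray`.

```python
from typing import Any

import jax.numpy as jnp


def to_array_sketch(y: Any, *, where: str) -> jnp.ndarray:
    """Hypothetical stand-in: convert y to a JAX array and reject 0-d results."""
    try:
        arr = jnp.asarray(y)
    except (TypeError, ValueError) as exc:
        raise TypeError(
            f"{where}: output could not be converted to a JAX array."
        ) from exc
    if arr.ndim == 0:
        raise TypeError(f"{where}: expected an array output, got a scalar.")
    return arr


vec = to_array_sketch([1.0, 2.0, 3.0], where="example")

# A plain float converts to a 0-d array, so this sketch rejects it.
try:
    to_array_sketch(1.0, where="example")
    scalar_rejected = False
except TypeError:
    scalar_rejected = True
```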

derivkit.derivatives.autodiff.jax_utils.to_jax_scalar(y: Any, *, where: str) → jnp.ndarray

Ensures that the output is scalar and returns it as a JAX array.

Parameters:
  • y – Output to check.

  • where – Context string for error messages.

Returns:

Scalar (0-d) JAX array with shape ().

Raises:

TypeError – If the output is not a scalar.
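As an illustration of the scalar check, here is a minimal sketch under stated assumptions. The name `to_scalar_sketch` is hypothetical, and the sketch treats only 0-d results as scalar. Whether the real function also accepts, for example, shape `(1,)` arrays is not stated in this reference.

```python
from typing import Any

import jax.numpy as jnp


def to_scalar_sketch(y: Any, *, where: str) -> jnp.ndarray:
    """Hypothetical stand-in: convert y to a 0-d JAX array or raise TypeError."""
    arr = jnp.asarray(y)
    if arr.ndim != 0:
        raise TypeError(
            f"{where}: expected a scalar output, got shape {arr.shape}."
        )
    return arr


value = to_scalar_sketch(3.0, where="example")

# A list of two numbers converts to shape (2,), which is not scalar.
try:
    to_scalar_sketch([1.0, 2.0], where="example")
    vector_rejected = False
except TypeError:
    vector_rejected = True
```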