derivkit.derivatives.autodiff.jax_utils module
Utilities for JAX-based autodiff in DerivKit.
- exception derivkit.derivatives.autodiff.jax_utils.AutodiffUnavailable
Bases: RuntimeError
Raised when JAX-based autodiff is unavailable.
- derivkit.derivatives.autodiff.jax_utils.apply_array_nd(func: Callable, where: str, theta: jnp.ndarray) → jnp.ndarray
Maps func over an N-D array of inputs and requires every output to be array-valued (not a scalar).
- Parameters:
func – Function to apply.
where – Context string for error messages.
theta – ND JAX array of inputs.
- Returns:
JAX array of array outputs.
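The snippet below is a minimal sketch of how this kind of mapping could work. It is not derivkit's implementation. It assumes that theta stacks input vectors along its leading axis and that the mapping uses `jax.vmap`. The name `apply_array_nd_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def apply_array_nd_sketch(func, where, theta):
    """Hypothetical stand-in for apply_array_nd (assumes a leading batch axis)."""
    theta = jnp.asarray(theta)

    def one(th):
        y = jnp.asarray(func(th))
        # Shapes are static under tracing, so this check runs at trace time.
        if y.ndim == 0:
            raise TypeError(f"{where}: expected an array-valued output, got a scalar")
        return y

    # Map over the leading axis: theta[i] -> func(theta[i]).
    return jax.vmap(one)(theta)


theta = jnp.array([[1.0, 2.0], [3.0, 4.0]])
# The gradient of sum(th**2) is 2*th, evaluated here at each row of theta.
grads = apply_array_nd_sketch(jax.grad(lambda th: jnp.sum(th**2)), "demo", theta)
```

Enforcing the output rank up front turns a shape mismatch deep inside a derivative computation into an immediate `TypeError` that names the calling context.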
- derivkit.derivatives.autodiff.jax_utils.apply_scalar_1d(func: Callable[[float], Any], where: str, x: jnp.ndarray) → jnp.ndarray
Maps func over a 1-D array of inputs and requires every output to be a scalar.
- Parameters:
func – Function to apply.
where – Context string for error messages.
x – 1D JAX array of inputs.
- Returns:
JAX array of scalar outputs.
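A minimal sketch of this 1-D mapping follows. It assumes that each element of x is passed to func separately through `jax.vmap`. The real helper may instead loop in Python or use `jax.lax.map`. The name `apply_scalar_1d_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def apply_scalar_1d_sketch(func, where, x):
    """Hypothetical stand-in for apply_scalar_1d."""
    x = jnp.asarray(x)
    if x.ndim != 1:
        raise ValueError(f"{where}: expected a 1-D input, got shape {x.shape}")

    def one(xi):
        y = jnp.asarray(func(xi))
        # Reject anything that is not 0-d; checked once while vmap traces `one`.
        if y.ndim != 0:
            raise TypeError(f"{where}: expected a scalar output, got shape {y.shape}")
        return y

    return jax.vmap(one)(x)


xs = jnp.array([0.0, 0.5, 1.0])
# Pointwise derivative of sin, which should match cos(xs).
dsin = apply_scalar_1d_sketch(jax.grad(jnp.sin), "demo", xs)
```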
- derivkit.derivatives.autodiff.jax_utils.apply_scalar_nd(func: Callable, where: str, theta: jnp.ndarray) → jnp.ndarray
Maps func over an N-D array of inputs and requires every output to be a scalar.
- Parameters:
func – Function to apply.
where – Context string for error messages.
theta – ND JAX array of inputs.
- Returns:
JAX array of scalar outputs.
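As a sketch only, the snippet below assumes that theta stacks parameter vectors along its leading axis and that func maps each vector to a scalar. Under that reading, the helper reduces to `jax.vmap` plus a rank check. The name `apply_scalar_nd_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def apply_scalar_nd_sketch(func, where, theta):
    """Hypothetical stand-in for apply_scalar_nd (assumes a leading batch axis)."""
    theta = jnp.asarray(theta)

    def one(th):
        y = jnp.asarray(func(th))
        # A vector-valued func (e.g. a gradient) is rejected here.
        if y.ndim != 0:
            raise TypeError(f"{where}: expected a scalar output, got shape {y.shape}")
        return y

    return jax.vmap(one)(theta)


theta = jnp.array([[1.0, 2.0], [3.0, 4.0]])
# Squared norm of each row: [1 + 4, 9 + 16].
norms = apply_scalar_nd_sketch(lambda th: jnp.sum(th**2), "demo", theta)
```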
- derivkit.derivatives.autodiff.jax_utils.require_jax() → None
Checks that JAX is available and raises an error if it is not.
- Parameters:
None.
- Returns:
None.
- Raises:
AutodiffUnavailable – If JAX is not installed.
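A minimal sketch of this guard, assuming that availability is detected with `importlib.util.find_spec`. The library may instead attempt the import directly. `AutodiffUnavailableSketch` and `require_jax_sketch` are hypothetical stand-ins that mirror the documented `RuntimeError` base class.

```python
import importlib.util


class AutodiffUnavailableSketch(RuntimeError):
    """Hypothetical stand-in for AutodiffUnavailable."""


def require_jax_sketch() -> None:
    """Raise if the 'jax' package cannot be found; otherwise return None."""
    if importlib.util.find_spec("jax") is None:
        raise AutodiffUnavailableSketch(
            "JAX-based autodiff requires the 'jax' package, which is not installed."
        )
```

Calling such a guard at the top of an autodiff entry point gives users one clear, catchable error instead of an `ImportError` raised partway through a computation.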
- derivkit.derivatives.autodiff.jax_utils.to_jax_array(y: Any, *, where: str) → jnp.ndarray
Ensures that the output is array-like (not a scalar) and returns it as a JAX array.
- Parameters:
y – Output to check.
where – Context string for error messages.
- Returns:
Non-scalar JAX array of shape (m,) or of higher rank.
- Raises:
TypeError – If the output is a scalar or cannot be converted to a JAX array.
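A minimal sketch of the documented contract follows. It assumes that conversion goes through `jnp.asarray` and that "scalar" means rank 0. The real helper may treat edge cases differently. The name `to_jax_array_sketch` is hypothetical.

```python
import jax.numpy as jnp


def to_jax_array_sketch(y, *, where):
    """Hypothetical stand-in for to_jax_array: return y as a JAX array of rank >= 1."""
    try:
        arr = jnp.asarray(y)
    except (TypeError, ValueError) as exc:
        raise TypeError(f"{where}: output could not be converted to a JAX array") from exc
    if arr.ndim == 0:
        raise TypeError(f"{where}: expected an array-valued output, got a scalar")
    return arr


vec = to_jax_array_sketch([1.0, 2.0, 3.0], where="demo")
```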
- derivkit.derivatives.autodiff.jax_utils.to_jax_scalar(y: Any, *, where: str) → jnp.ndarray
Ensures that the output is a scalar and returns it as a 0-d JAX array.
- Parameters:
y – Output to check.
where – Context string for error messages.
- Returns:
Scalar (0-d) JAX array with shape ().
- Raises:
TypeError – If the output is not a scalar.
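A minimal sketch of this check, assuming that only rank-0 values count as scalars. The real helper might also accept size-1 arrays. The name `to_jax_scalar_sketch` is hypothetical.

```python
import jax.numpy as jnp


def to_jax_scalar_sketch(y, *, where):
    """Hypothetical stand-in for to_jax_scalar: return y as a 0-d JAX array."""
    try:
        arr = jnp.asarray(y)
    except (TypeError, ValueError) as exc:
        raise TypeError(f"{where}: output could not be converted to a JAX array") from exc
    if arr.ndim != 0:
        raise TypeError(f"{where}: expected a scalar output, got shape {arr.shape}")
    return arr


s = to_jax_scalar_sketch(3.0, where="demo")
```

A check like this is useful before calling `jax.grad`, which only differentiates scalar-output functions.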