derivkit.derivatives.autodiff.jax_core module

JAX-based autodiff helpers for DerivKit.

This module does not register any DerivKit backend by default.

Use these functions directly, or see derivkit.autodiff.jax_autodiff.register_jax_autodiff_backend() for an opt-in integration.

Use only with JAX-differentiable functions. For arbitrary models, prefer the “adaptive” or “finite” methods.

Shape conventions (aligned with the derivkit.calculus builders):

  • autodiff_derivative: \(f:\mathbb{R}\to\mathbb{R}\) → returns float (scalar)

  • autodiff_gradient: \(f:\mathbb{R}^n\to\mathbb{R}\) → returns array of shape (n,)

  • autodiff_jacobian: \(f:\mathbb{R}^n\to\mathbb{R}^m\) (or tensor output) → returns array of shape (m, n), where \(m = \prod(\text{out\_shape})\)

  • autodiff_hessian: \(f:\mathbb{R}^n\to\mathbb{R}\) → returns array of shape (n, n)
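
A minimal shape-checking sketch of these conventions, assuming JAX is installed and the functions are imported from the documented module path (the toy functions scalar_fn and vector_fn are illustrative only, not part of the API):

    import numpy as np
    import jax.numpy as jnp

    from derivkit.derivatives.autodiff.jax_core import (
        autodiff_derivative,
        autodiff_gradient,
        autodiff_hessian,
        autodiff_jacobian,
    )

    x0 = np.array([1.0, 2.0, 3.0])  # n = 3

    def scalar_fn(theta):
        # R^3 -> R
        return jnp.sum(theta ** 2)

    def vector_fn(theta):
        # R^3 -> R^2, so m = 2
        return jnp.stack([theta[0], theta[1] * theta[2]])

    d = autodiff_derivative(jnp.sin, 0.5)   # float (scalar)
    g = autodiff_gradient(scalar_fn, x0)    # shape (3,)
    J = autodiff_jacobian(vector_fn, x0)    # shape (2, 3)
    H = autodiff_hessian(scalar_fn, x0)     # shape (3, 3)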

derivkit.derivatives.autodiff.jax_core.autodiff_derivative(func: Callable, x0: float, order: int = 1) → float

Calculates the k-th derivative (k = order) of a function f: R -> R via JAX autodiff.

Parameters:
  • func – Callable mapping float -> scalar.

  • x0 – Point at which to evaluate the derivative.

  • order – Derivative order (>=1); uses repeated grad for higher orders.

Returns:

Derivative value as a float.

Raises:
  • AutodiffUnavailable – If JAX is not available or the function is not differentiable.

  • ValueError – If order < 1.

  • TypeError – If func(x) is not scalar-valued.
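
A brief usage sketch, assuming JAX is installed; the third derivative of sin at a point is checked against its analytic value:

    import jax.numpy as jnp

    from derivkit.derivatives.autodiff.jax_core import autodiff_derivative

    # Third derivative of sin(x) at x0 = 0.3 via repeated grad.
    d3 = autodiff_derivative(jnp.sin, 0.3, order=3)

    # Analytic check: d^3/dx^3 sin(x) = -cos(x).
    assert abs(d3 - (-jnp.cos(0.3))) < 1e-6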

derivkit.derivatives.autodiff.jax_core.autodiff_gradient(func: Callable, x0) → ndarray

Computes the gradient of a scalar-valued function f: R^n -> R via JAX autodiff.

Parameters:
  • func – Function to be differentiated.

  • x0 – Point at which to evaluate the gradient; array-like, shape (n,) with n = input dimension.

Returns:

A gradient vector as a 1D numpy.ndarray with shape (n,).

Raises:
  • AutodiffUnavailable – If JAX is not available or the function is not differentiable.

  • TypeError – If func(theta) is not scalar-valued.
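
A short sketch for a scalar-valued objective, assuming JAX is installed; the toy objective below is illustrative only:

    import numpy as np
    import jax.numpy as jnp

    from derivkit.derivatives.autodiff.jax_core import autodiff_gradient

    def objective(theta):
        # Scalar-valued output; autodiff_gradient raises TypeError otherwise.
        return 0.5 * jnp.sum(theta ** 2)

    g = autodiff_gradient(objective, np.array([1.0, -2.0, 0.5]))
    # g is a numpy array of shape (3,), here equal to [1.0, -2.0, 0.5].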

derivkit.derivatives.autodiff.jax_core.autodiff_hessian(func: Callable, x0) → ndarray

Calculates the full Hessian of a scalar-valued function.

Parameters:
  • func – A function to be differentiated.

  • x0 – Point at which to evaluate the Hessian; array-like, shape (n,) with n = input dimension.

Returns:

A Hessian matrix as a 2D numpy.ndarray with shape (n, n).

Raises:
  • AutodiffUnavailable – If JAX is not available or the function is not differentiable.

  • TypeError – If func(theta) is not scalar-valued.
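
A short sketch using a quadratic form whose Hessian is known analytically, assuming JAX is installed; the matrix A is illustrative only:

    import numpy as np
    import jax.numpy as jnp

    from derivkit.derivatives.autodiff.jax_core import autodiff_hessian

    A = jnp.array([[2.0, 0.5],
                   [0.5, 1.0]])

    def quadratic(theta):
        # 0.5 * theta^T A theta; since A is symmetric, its Hessian is A.
        return 0.5 * theta @ A @ theta

    H = autodiff_hessian(quadratic, np.zeros(2))
    # H is a (2, 2) numpy array, approximately equal to A.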

derivkit.derivatives.autodiff.jax_core.autodiff_jacobian(func: Callable, x0, *, mode: str | None = None) → ndarray

Calculates the Jacobian of a vector-valued function via JAX autodiff.

Output convention matches DerivKit Jacobian builders: we flatten the function output to length m = prod(out_shape), and return a 2D Jacobian of shape (m, n), with n = input dimension.

Parameters:
  • func – Function to be differentiated.

  • x0 – Point at which to evaluate the Jacobian; array-like, shape (n,).

  • mode – Differentiation mode; None (auto), ‘fwd’, or ‘rev’. If None, chooses ‘rev’ if m <= n, else ‘fwd’. For more details about modes, see JAX documentation for jax.jacrev and jax.jacfwd.

Returns:

A Jacobian matrix as a 2D numpy.ndarray with shape (m, n).

Raises:
  • AutodiffUnavailable – If JAX is not available or the function is not differentiable.

  • ValueError – If mode is invalid.

  • TypeError – If func(theta) is scalar-valued.
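
A short sketch for a vector-valued model, assuming JAX is installed; the model function is illustrative only:

    import numpy as np
    import jax.numpy as jnp

    from derivkit.derivatives.autodiff.jax_core import autodiff_jacobian

    def model(theta):
        # R^2 -> R^3, so the Jacobian has shape (m, n) = (3, 2).
        return jnp.array([theta[0] * theta[1],
                          jnp.sin(theta[0]),
                          theta[1] ** 2])

    x0 = np.array([0.5, 2.0])
    J_auto = autodiff_jacobian(model, x0)             # m > n, so 'fwd' is chosen
    J_rev = autodiff_jacobian(model, x0, mode="rev")  # force reverse mode
    # Both are (3, 2) numpy arrays and agree to floating-point precision.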