derivkit.derivatives.autodiff.jax_core module
JAX-based autodiff helpers for DerivKit.
This module does not register any DerivKit backend by default.
Use these functions directly, or see
derivkit.autodiff.jax_autodiff.register_jax_autodiff_backend()
for an opt-in integration.
Use only with functions that JAX can trace and differentiate (for example, functions written with jax.numpy rather than NumPy). For arbitrary models, prefer the “adaptive” or “finite” methods.
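One way to see the difference is to call jax.grad on two versions of the same model. The sketch below uses plain JAX, not a DerivKit API. The helper is_jax_differentiable and the model functions are hypothetical names for illustration only.

```python
import jax
import jax.numpy as jnp
import numpy as np


def model_jax(x):
    # Built from jax.numpy operations, so JAX can trace and differentiate it.
    return jnp.sin(x) * x**2


def model_numpy(x):
    # Calls NumPy directly; a JAX tracer cannot pass through np.sin.
    return np.sin(x) * x**2


def is_jax_differentiable(func, x0):
    """Hypothetical helper: True if jax.grad(func)(x0) succeeds."""
    try:
        jax.grad(func)(x0)
    except Exception:  # JAX raises a tracer-conversion error for NumPy calls
        return False
    return True


print(is_jax_differentiable(model_jax, 1.0))    # True
print(is_jax_differentiable(model_numpy, 1.0))  # False
```

A model like model_numpy is a candidate for the “adaptive” or “finite” methods instead.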
Shape conventions (aligned with the derivkit.calculus builders):
- autodiff_derivative: \(f:\mathbb{R}\to\mathbb{R}\) → returns a float (scalar)
- autodiff_gradient: \(f:\mathbb{R}^n\to\mathbb{R}\) → returns an array of shape (n,)
- autodiff_jacobian: \(f:\mathbb{R}^n\to\mathbb{R}^m\) (or tensor output) → returns an array of shape (m, n), where \(m = \prod(\text{out\_shape})\)
- autodiff_hessian: \(f:\mathbb{R}^n\to\mathbb{R}\) → returns an array of shape (n, n)
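To make these shapes concrete, the sketch below uses plain JAX transforms (jax.grad, jax.hessian, jax.jacrev) rather than the DerivKit wrappers; the example functions are hypothetical. Note that jax.jacrev returns an array of shape out_shape + (n,), which is reshaped here to (m, n) to match the convention above.

```python
import jax
import jax.numpy as jnp
import numpy as np

x0 = jnp.array([1.0, 2.0, 3.0])  # n = 3 inputs


def scalar_fn(x):
    # f: R^n -> R
    return jnp.sum(x**2)


def tensor_fn(x):
    # f: R^n -> R^(2, 2); out_shape = (2, 2), so m = 4 after flattening
    return jnp.outer(x[:2], x[1:])


g_vec = np.asarray(jax.grad(scalar_fn)(x0))     # gradient, shape (n,)
h_mat = np.asarray(jax.hessian(scalar_fn)(x0))  # Hessian, shape (n, n)
jac_raw = jax.jacrev(tensor_fn)(x0)             # shape out_shape + (n,) = (2, 2, 3)
jac = np.asarray(jac_raw).reshape(-1, x0.size)  # flattened to (m, n) = (4, 3)

print(g_vec.shape, h_mat.shape, jac_raw.shape, jac.shape)
# (3,) (3, 3) (2, 2, 3) (4, 3)
```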
- derivkit.derivatives.autodiff.jax_core.autodiff_derivative(func: Callable, x0: float, order: int = 1) → float
Calculates the k-th derivative (k = order) of a function f: R -> R via JAX autodiff.
- Parameters:
func – Callable mapping float -> scalar.
x0 – Point at which to evaluate the derivative.
order – Derivative order (>=1); uses repeated grad for higher orders.
- Returns:
Derivative value as a float.
- Raises:
AutodiffUnavailable – If JAX is not available or function is not differentiable.
ValueError – If order < 1.
TypeError – If func(x) is not scalar-valued.
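The “repeated grad” strategy mentioned for order can be sketched with plain jax.grad. This is a minimal illustration assuming a scalar-in, scalar-out func; kth_derivative is a hypothetical name, not the DerivKit implementation, and it omits the AutodiffUnavailable handling.

```python
import jax
import jax.numpy as jnp


def kth_derivative(func, x0, order=1):
    """Hypothetical sketch: k-th derivative of f: R -> R via repeated jax.grad."""
    if order < 1:
        raise ValueError("order must be >= 1")
    deriv = func
    for _ in range(order):
        # Each jax.grad call differentiates the previous derivative once more.
        deriv = jax.grad(deriv)
    # jax.grad needs a floating-point input; float() also unwraps the 0-d result.
    return float(deriv(float(x0)))


# d^3/dx^3 sin(x) = -cos(x)
print(kth_derivative(jnp.sin, 0.5, order=3))  # approximately -0.8776
```

Nesting jax.grad is simple but re-traces the whole chain for each order; for high orders, JAX's jax.experimental.jet module is an alternative worth evaluating.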
- derivkit.derivatives.autodiff.jax_core.autodiff_gradient(func: Callable, x0) → ndarray
Computes the gradient of a scalar-valued function f: R^n -> R via JAX autodiff.
- Parameters:
func – Function to be differentiated.
x0 – Point at which to evaluate the gradient.
- Returns:
A gradient vector as a 1D numpy.ndarray with shape (n,).
- Raises:
AutodiffUnavailable – If JAX is not available or function is not differentiable.
TypeError – If func(theta) is not scalar-valued.
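A gradient with this shape contract can be sketched directly with jax.grad. The sketch assumes x0 is array-like and func returns a scalar; gradient_sketch is a hypothetical helper, not the DerivKit function. jax.grad itself raises TypeError for non-scalar outputs, which mirrors the documented behavior.

```python
import jax
import jax.numpy as jnp
import numpy as np


def gradient_sketch(func, x0):
    """Hypothetical sketch: gradient of f: R^n -> R as a NumPy array of shape (n,)."""
    # Converting via NumPy guarantees a floating dtype, which jax.grad requires.
    x = jnp.asarray(np.asarray(x0, dtype=float).ravel())
    # jax.grad raises TypeError if func returns a non-scalar output.
    g = jax.grad(func)(x)
    return np.asarray(g)


def quadratic(x):
    # f(x) = x . x, so grad f = 2x
    return jnp.dot(x, x)


print(gradient_sketch(quadratic, [1.0, 2.0, 3.0]))  # [2. 4. 6.]
```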
- derivkit.derivatives.autodiff.jax_core.autodiff_hessian(func: Callable, x0) → ndarray
Calculates the full Hessian of a scalar-valued function f: R^n -> R via JAX autodiff.
- Parameters:
func – A function to be differentiated.
x0 – Point at which to evaluate the Hessian; array-like, shape (n,) with n = input dimension.
- Returns:
A Hessian matrix as a 2D numpy.ndarray with shape (n, n).
- Raises:
AutodiffUnavailable – If JAX is not available or function is not differentiable.
TypeError – If func(theta) is not scalar-valued.
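For comparison, a Hessian with the (n, n) contract can be obtained from jax.hessian, which JAX documents as jacfwd applied to jacrev. hessian_sketch is a hypothetical helper under that assumption, not DerivKit's implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def hessian_sketch(func, x0):
    """Hypothetical sketch: Hessian of f: R^n -> R as an (n, n) NumPy array."""
    x = jnp.asarray(np.asarray(x0, dtype=float).ravel())
    # jax.hessian composes forward-over-reverse differentiation
    # (jacfwd of jacrev), which is typically efficient for scalar outputs.
    h = jax.hessian(func)(x)
    return np.asarray(h)


def f(x):
    # f(x, y) = x^2 * y + y^3, so H = [[2y, 2x], [2x, 6y]]
    return x[0] ** 2 * x[1] + x[1] ** 3


H = hessian_sketch(f, [1.0, 2.0])  # expected [[4, 2], [2, 12]] at (1, 2)
print(H)
```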
- derivkit.derivatives.autodiff.jax_core.autodiff_jacobian(func: Callable, x0, *, mode: str | None = None) → ndarray
Calculates the Jacobian of a vector-valued function via JAX autodiff.
Output convention matches DerivKit Jacobian builders: we flatten the function output to length m = prod(out_shape), and return a 2D Jacobian of shape (m, n), with n = input dimension.
- Parameters:
func – Function to be differentiated.
x0 – Point at which to evaluate the Jacobian; array-like, shape (n,).
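mode – Differentiation mode; None (auto), ‘fwd’, or ‘rev’. If None, chooses ‘rev’ if m <= n, else ‘fwd’, since reverse-mode cost scales with the number of outputs m and forward-mode cost with the number of inputs n. For more details about modes, see the JAX documentation for jax.jacrev and jax.jacfwd.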
- Returns:
A Jacobian matrix as a 2D numpy.ndarray with shape (m, n).
- Raises:
AutodiffUnavailable – If JAX is not available or function is not differentiable.
ValueError – If mode is invalid.
TypeError – If func(theta) is scalar-valued.
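The flattening convention and the automatic mode choice described above can be sketched with jax.jacrev, jax.jacfwd, and jax.eval_shape. This is a minimal illustration under those assumptions; jacobian_sketch and g are hypothetical names, and the sketch is not the DerivKit implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jacobian_sketch(func, x0, mode=None):
    """Hypothetical sketch: flattened (m, n) Jacobian with automatic mode choice."""
    x = jnp.asarray(np.asarray(x0, dtype=float).ravel())
    n = x.size
    # Trace func abstractly with jax.eval_shape to learn out_shape without computing values.
    out_shape = jax.eval_shape(func, x).shape
    if out_shape == ():
        raise TypeError("func is scalar-valued; use a gradient instead")
    m = int(np.prod(out_shape))
    if mode is None:
        # Reverse-mode cost scales with m, forward-mode cost with n.
        mode = "rev" if m <= n else "fwd"
    if mode == "rev":
        jac = jax.jacrev(func)(x)
    elif mode == "fwd":
        jac = jax.jacfwd(func)(x)
    else:
        raise ValueError(f"mode must be None, 'fwd', or 'rev'; got {mode!r}")
    # jac has shape out_shape + (n,); flatten the output axes to get (m, n).
    return np.asarray(jac).reshape(m, n)


def g(x):
    # R^2 -> R^3: m = 3 > n = 2, so mode=None picks "fwd"
    return jnp.stack([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])


J = jacobian_sketch(g, [1.0, 2.0])  # [[2, 1], [cos(1), 0], [0, 4]]
print(J.shape)  # (3, 2)
```

Either mode yields the same matrix; the choice only affects cost.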