derivkit.derivatives.autodiff.jax_autodiff module
JAX-based autodiff backend for DerivativeKit.
This backend is intentionally minimal: it only supports scalar derivatives of functions $f: \mathbb{R} \to \mathbb{R}$ via JAX autodiff, and it must be registered explicitly, as shown in the example.
Example:
Basic usage (opt-in registration):
>>> import jax.numpy as jnp
>>> from derivkit.derivative_kit import DerivativeKit
>>> from derivkit.derivatives.autodiff.jax_autodiff import register_jax_autodiff_backend
>>> register_jax_autodiff_backend()
>>> def func(x):
...     return jnp.sin(x) + 0.5 * x**2
...
>>> dk = DerivativeKit(func, 1.0)
>>> d1 = dk.differentiate(method="autodiff", order=1)  # first derivative at x0 = 1.0
>>> d2 = dk.differentiate(method="autodiff", order=2)  # second derivative at x0 = 1.0
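As a worked check on the example, the derivatives of the example function can be computed by hand; the values below are what the two calls should approximate (JAX uses float32 by default, so agreement is to roughly six or seven significant digits):

```latex
f(x) = \sin x + \tfrac{1}{2}x^2, \qquad
f'(x) = \cos x + x, \qquad
f''(x) = 1 - \sin x

f'(1) = \cos 1 + 1 \approx 1.5403, \qquad
f''(1) = 1 - \sin 1 \approx 0.1585
```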
Notes:
This backend is scalar-only. For gradients, Jacobians, and Hessians of functions with vector inputs or outputs, use the standalone helpers in derivkit.autodiff.jax_core (e.g. autodiff_gradient).
To enable this backend, install the JAX extra:
pip install "derivkit[jax]"
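The signatures of the jax_core helpers are not documented on this page, so the sketch below does not call them. Instead it illustrates, using plain JAX only, the kind of vector-input derivatives this scalar backend does not cover: `jax.grad` for gradients, `jax.hessian` for Hessians, and `jax.jacfwd` for Jacobians. The functions `g` and `h` are hypothetical examples.

```python
import jax
import jax.numpy as jnp


# Hypothetical vector-input, scalar-output function g: R^2 -> R.
def g(x):
    return jnp.sin(x[0]) * x[1] ** 2


# Hypothetical vector-input, vector-output function h: R^2 -> R^2.
def h(x):
    return jnp.stack([x[0] * x[1], jnp.exp(x[0])])


x0 = jnp.array([0.0, 2.0])

grad_g = jax.grad(g)(x0)     # gradient, shape (2,)
hess_g = jax.hessian(g)(x0)  # Hessian, shape (2, 2)
jac_h = jax.jacfwd(h)(x0)    # Jacobian, shape (2, 2): rows = outputs, columns = inputs
```

At this point the gradient is [4, 0], the Hessian is [[0, 4], [4, 0]], and the Jacobian is [[2, 0], [1, 0]].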
- class derivkit.derivatives.autodiff.jax_autodiff.AutodiffDerivative(function: Callable[[float], Any], x0: float)
Bases: object
DerivativeKit engine for JAX-based autodiff.
Supports scalar functions f: R -> R with JAX-differentiable bodies.
Initializes the JAX autodiff derivative engine.
- differentiate(*, order: int = 1, **_: Any) -> float
Computes the derivative of the given order at x0 via JAX autodiff.
- Parameters:
order – Derivative order (>=1).
- Returns:
Derivative value as a float.
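A minimal sketch of how an engine with this interface could compute a derivative of a given order, assuming it nests `jax.grad` `order` times and evaluates the result at `x0`. The class name `MiniAutodiffDerivative` is hypothetical, and this is an illustration rather than derivkit's actual implementation.

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp


class MiniAutodiffDerivative:
    """Hypothetical stand-in mirroring the documented AutodiffDerivative interface."""

    def __init__(self, function: Callable[[float], Any], x0: float):
        self.function = function
        self.x0 = x0

    def differentiate(self, *, order: int = 1, **_: Any) -> float:
        # Mirror the documented constraint: order must be an integer >= 1.
        if not isinstance(order, int) or order < 1:
            raise ValueError("order must be an integer >= 1")
        f = self.function
        # Nest jax.grad `order` times; each pass differentiates the previous derivative.
        for _step in range(order):
            f = jax.grad(f)
        # jax.grad needs a floating-point input; return a plain Python float.
        return float(f(float(self.x0)))


engine = MiniAutodiffDerivative(lambda x: jnp.sin(x) + 0.5 * x**2, 1.0)
d1 = engine.differentiate(order=1)  # approximately cos(1) + 1
d2 = engine.differentiate(order=2)  # approximately 1 - sin(1)
```

Nesting `jax.grad` is the simplest way to get higher scalar derivatives; with JAX's default float32 precision, results agree with the analytic values to about six or seven significant digits.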
- derivkit.derivatives.autodiff.jax_autodiff.register_jax_autodiff_backend(*, name: str = 'autodiff', aliases: tuple[str, ...] = ('jax', 'jax-autodiff', 'jax-diff', 'jd')) -> None
Registers the experimental JAX autodiff backend with DerivativeKit.
After calling this, you can use:
DerivativeKit(f, x0).differentiate(method=name, order=…)
- Parameters:
name – Name of the method to register.
aliases – Alternative names for the method.
- Returns:
None
- Raises:
AutodiffUnavailable – If JAX is not available.