derivkit.derivatives.autodiff.jax_autodiff module

JAX-based autodiff backend for DerivativeKit.

This backend is intentionally minimal: it supports only scalar derivatives of functions $f: \mathbb{R} \to \mathbb{R}$ via JAX autodiff, and it must be registered explicitly, as shown in the example below.

Example:

Basic usage (opt-in registration):

>>> import jax.numpy as jnp
>>> from derivkit.derivative_kit import DerivativeKit
>>> from derivkit.derivatives.autodiff.jax_autodiff import register_jax_autodiff_backend
>>> register_jax_autodiff_backend()
>>>
>>> def func(x):
...     return jnp.sin(x) + 0.5 * x**2
...
>>> dk = DerivativeKit(func, 1.0)
>>> dk.differentiate(method="autodiff", order=1)
>>> dk.differentiate(method="autodiff", order=2)
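To check the example by hand, the analytic derivatives of `func` at $x_0 = 1$ are given below. The autodiff results should agree up to floating-point precision; note that JAX computes in 32-bit floats unless 64-bit mode is enabled, so trailing digits may differ.

```latex
\begin{aligned}
f(x)   &= \sin x + \tfrac{1}{2}x^2, \\
f'(x)  &= \cos x + x,  & f'(1)  &= \cos 1 + 1 \approx 1.5403, \\
f''(x) &= 1 - \sin x,  & f''(1) &= 1 - \sin 1 \approx 0.1585.
\end{aligned}
```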

Notes:

  • This backend is scalar-only. For gradients/Jacobians/Hessians of functions with vector inputs/outputs, use the standalone helpers in derivkit.autodiff.jax_core (e.g. autodiff_gradient).

  • To enable this backend, install the JAX extra: pip install "derivkit[jax]".
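For the vector case mentioned in the first note, the sketch below uses JAX's own `jax.grad`, `jax.hessian`, and `jax.jacfwd` directly rather than the derivkit helpers in `jax_core`, whose exact signatures are not reproduced on this page. It is only meant to illustrate the kind of quantity those helpers presumably return; the functions `g` and `h` are made-up examples.

```python
import jax
import jax.numpy as jnp


def g(x):
    # Scalar output from a vector input: g(x) = sum_i x_i^2 + x_0 * x_1
    return jnp.sum(x**2) + x[0] * x[1]


def h(x):
    # Vector output from a vector input: h(x) = (x_0 * x_1, sin x_0)
    return jnp.stack([x[0] * x[1], jnp.sin(x[0])])


x = jnp.array([1.0, 2.0])

grad_g = jax.grad(g)(x)     # gradient, shape (2,): [2*x0 + x1, 2*x1 + x0]
hess_g = jax.hessian(g)(x)  # Hessian, shape (2, 2): [[2, 1], [1, 2]]
jac_h = jax.jacfwd(h)(x)    # Jacobian, shape (2, 2): [[x1, x0], [cos x0, 0]]
```

A scalar-only backend cannot express any of these three results, which is why the note points to separate helpers for vector inputs and outputs.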

class derivkit.derivatives.autodiff.jax_autodiff.AutodiffDerivative(function: Callable[[float], Any], x0: float)

Bases: object

DerivativeKit engine for JAX-based autodiff.

Supports scalar functions f: R -> R with JAX-differentiable bodies.

Initializes the JAX autodiff derivative engine.

differentiate(*, order: int = 1, **_: Any) → float

Computes the derivative of the given order at x0 via JAX autodiff.

Parameters:

order – Derivative order (>=1).

Returns:

Derivative value as a float.
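The engine's internals are not documented here. A common way to obtain a higher-order derivative of a scalar function with JAX is to nest `jax.grad` once per order, and the sketch below does exactly that. `nth_derivative` is a hypothetical helper written for illustration; it is not part of derivkit and is not claimed to be how `AutodiffDerivative.differentiate` is implemented.

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp


def nth_derivative(function: Callable[[Any], Any], x0: float, order: int = 1) -> float:
    """Hypothetical helper: derivative of the given order of a scalar f at x0."""
    if order < 1:
        raise ValueError(f"order must be >= 1, got {order!r}")
    deriv = function
    for _ in range(order):
        # Each pass wraps the previous function in one more jax.grad.
        # jax.grad requires a real scalar output, hence the f: R -> R restriction.
        deriv = jax.grad(deriv)
    # jax.grad needs a floating-point input; convert the JAX scalar back to float.
    return float(deriv(float(x0)))


def func(x):
    # Same function as in the module example above.
    return jnp.sin(x) + 0.5 * x**2


first = nth_derivative(func, 1.0, order=1)   # analytically cos(1) + 1
second = nth_derivative(func, 1.0, order=2)  # analytically 1 - sin(1)
```

Nesting `jax.grad` is simple and exact for smooth JAX-traceable bodies, but the cost grows with each extra order, so it suits the low orders typical of this backend.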

derivkit.derivatives.autodiff.jax_autodiff.register_jax_autodiff_backend(*, name: str = 'autodiff', aliases: tuple[str, ...] = ('jax', 'jax-autodiff', 'jax-diff', 'jd')) → None

Registers the experimental JAX autodiff backend with DerivativeKit.

After calling this, you can use:

DerivativeKit(f, x0).differentiate(method=name, order=…)

Parameters:
  • name – Name of the method to register.

  • aliases – Alternative names for the method.

Returns:

None

Raises:

AutodiffUnavailable – If JAX is not available.
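To illustrate what registering a method under a name plus aliases involves, here is a minimal sketch of a name-to-engine registry with an availability check for an optional dependency. Everything in it (`register_method`, `resolve_method`, `BackendUnavailable`, `ToyEngine`, the `requires` parameter, and case-insensitive lookup) is hypothetical and chosen for illustration; it does not describe DerivativeKit's actual internals.

```python
import importlib.util
from typing import Callable, Dict, Optional, Tuple


class BackendUnavailable(ImportError):
    """Hypothetical stand-in for an error such as AutodiffUnavailable."""


# Hypothetical registry mapping a method name or alias to an engine class.
_METHODS: Dict[str, Callable] = {}


def register_method(
    engine: Callable,
    *,
    name: str,
    aliases: Tuple[str, ...] = (),
    requires: Optional[str] = None,
) -> None:
    """Register `engine` under `name` and every alias (case-insensitive)."""
    # Fail early if the optional dependency (e.g. "jax") is not installed.
    if requires is not None and importlib.util.find_spec(requires) is None:
        raise BackendUnavailable(
            f"method {name!r} needs the optional package {requires!r}"
        )
    for key in (name, *aliases):
        _METHODS[key.lower()] = engine


def resolve_method(method: str) -> Callable:
    """Look up the engine registered for a method name or alias."""
    try:
        return _METHODS[method.lower()]
    except KeyError:
        known = ", ".join(sorted(_METHODS))
        raise ValueError(f"unknown method {method!r}; known: {known}") from None


class ToyEngine:
    """Placeholder engine used only to demonstrate the lookup."""


register_method(ToyEngine, name="autodiff", aliases=("jax", "jax-autodiff", "jd"))
engine_cls = resolve_method("JAX")  # any alias resolves to the same engine
```

Checking the dependency at registration time, rather than at first use, matches the documented behavior of raising when JAX is missing.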