pymargins.make_predict_with_fd_jvp

pymargins.make_predict_with_fd_jvp(predict_native: Callable, fd_step: float = 1e-06) → Callable

Wrap a non-JAX predict function as a JAX primitive whose JVP is computed by finite differences (FD).

Used by ModelAdapter implementations for black-box models whose predict function cannot be reimplemented in JAX. The returned function accepts JAX arrays and is fully compatible with jax.grad, jax.hessian, and jax.jvp. Internally, it uses central-difference FD to compute directional derivatives at the model boundary; downstream autodiff over the estimand structure remains exact.
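For reference, a central-difference directional derivative has the standard form below. The symbols are introduced here for illustration: f is the native predict function, v a tangent direction in parameter space, h plays the role of fd_step, and ε is float64 machine epsilon (about 2.2 × 10⁻¹⁶).

```latex
J_{\beta} f(\beta, X)\, v \;\approx\; \frac{f(\beta + h v,\, X) - f(\beta - h v,\, X)}{2h},
\qquad
\text{error} \;=\; \underbrace{O(h^{2})}_{\text{truncation}} \;+\; \underbrace{O(\varepsilon / h)}_{\text{rounding}}
```

Balancing the two error terms gives a best step on the order of ε^{1/3} ≈ 6 × 10⁻⁶ in float64, consistent with a default near 1e-6. If the model's own output is only accurate to some tolerance δ > ε, δ takes the place of ε and the best step grows accordingly.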

This is the cleanest way to integrate a non-differentiable model into the JAX-based inference pipeline. The custom JVP isolates the FD to the one operation that needs it; everything composed with this primitive benefits from exact autodiff.
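The custom-JVP mechanism described above can be sketched with jax.custom_jvp and jax.pure_callback. This is a minimal illustration under stated assumptions, not the pymargins implementation. The name make_fd_predict_sketch is hypothetical. predict_native is assumed to return a float64 vector of length X.shape[0]. Only beta is differentiated: X is held fixed, so dydx-style derivatives with respect to X are not covered.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # FD steps near 1e-6 need float64


def make_fd_predict_sketch(predict_native, fd_step=1e-6):
    """Hypothetical sketch: wrap a NumPy predict(beta, X) so JAX can
    differentiate it with respect to beta via central finite differences.

    Assumes predict_native returns shape (X.shape[0],) and that X is
    concrete, non-differentiated data.
    """

    def wrapped(beta, X):
        X_np = np.asarray(X, dtype=np.float64)
        n_obs, n_params = X_np.shape[0], beta.shape[0]

        def predict_np(b):
            # Host-side call into the black-box model.
            return np.asarray(predict_native(np.asarray(b), X_np), dtype=np.float64)

        def jacobian_np(b):
            # One column per parameter: (f(b + h e_j) - f(b - h e_j)) / (2h).
            b = np.asarray(b, dtype=np.float64)
            cols = []
            for j in range(b.size):
                e = np.zeros_like(b)
                e[j] = fd_step
                cols.append((predict_np(b + e) - predict_np(b - e)) / (2.0 * fd_step))
            return np.stack(cols, axis=-1)

        out_spec = jax.ShapeDtypeStruct((n_obs,), jnp.float64)
        jac_spec = jax.ShapeDtypeStruct((n_obs, n_params), jnp.float64)

        @jax.custom_jvp
        def f(b):
            return jax.pure_callback(predict_np, out_spec, b)

        @f.defjvp
        def f_jvp(primals, tangents):
            (b,), (db,) = primals, tangents
            y = jax.pure_callback(predict_np, out_spec, b)
            jac = jax.pure_callback(jacobian_np, jac_spec, b)
            # The tangent output is linear in db, so reverse mode can transpose it.
            return y, jac @ db

        return f(beta)

    return wrapped
```

Design note: rather than one directional difference per tangent, the sketch builds the full central-difference Jacobian (2 × n_params model calls) and returns jac @ db. Because that tangent output is an explicit linear function of db, JAX can transpose it, so jax.grad works as well as jax.jvp. Nested derivatives such as jax.hessian would need more machinery, for example a custom rule on the Jacobian callback itself, and are not attempted here.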

Parameters:
  • predict_native (callable (beta_np, X) -> array_np) – Native prediction function. Receives beta as a NumPy array of shape (n_params,) and an arbitrary X (typically a NumPy array or pandas DataFrame), and returns a NumPy array of predictions.

  • fd_step (float, default 1e-6) – FD step for directional derivatives. The default is appropriate for float64; if the model has internal numerical solvers with looser tolerances, this may need to be increased.
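One way to check whether fd_step is too small for a particular model is to compare the central-difference directional derivative at fd_step with the one at a larger step: if the two disagree, internal numerical noise is likely dominating. The helper names below (central_diff_jvp, fd_step_is_stable), the 10× factor, and the tolerance are illustrative assumptions, not part of the pymargins API.

```python
import numpy as np


def central_diff_jvp(f, beta, v, h):
    """Central-difference directional derivative of f at beta along v."""
    return (f(beta + h * v) - f(beta - h * v)) / (2.0 * h)


def fd_step_is_stable(f, beta, v, h=1e-6, factor=10.0, tol=1e-4):
    """Hypothetical diagnostic: True if the FD directional derivative barely
    changes when the step grows by `factor`. A large change suggests that the
    model's internal noise dominates at step h, so h should be increased."""
    d_small = central_diff_jvp(f, beta, v, h)
    d_large = central_diff_jvp(f, beta, v, factor * h)
    return bool(np.allclose(d_small, d_large, rtol=tol, atol=tol))


# Example: a smooth model versus one whose output is only accurate to ~1e-6
# (simulated here by rounding to 6 decimal places).
beta = np.array([0.3])
v = np.array([1.0])
smooth = lambda b: np.exp(b)
noisy = lambda b: np.round(np.exp(b), 6)
```

For the smooth model the two estimates agree closely; for the rounded model the step of 1e-6 gives roughly 1.5 while the step of 1e-5 gives roughly 1.35 (the true derivative is about 1.35), so the check flags it.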

Returns:

predict_wrapped – JAX-compatible wrapper. jax.grad, jax.hessian, and jax.jvp all work through this function. Derivatives can be taken with respect to both β and X (the latter is needed for dydx slopes).

Return type:

callable (beta_jax, X) -> array_jax

Notes

Offset handling: If the model uses an offset (e.g. exposure in Poisson models), bake the offset into predict_native before passing it to this factory. The wrapped function's signature is (beta, X) with no offset argument: JAX traces both arguments, and an offset would have to be a traced array (not a static nondiff_argnum) to support differentiation through it.
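A minimal sketch of baking an exposure offset into predict_native with a closure, assuming a Poisson model with a log link. The helper name make_poisson_predict_native is hypothetical, and the closure assumes the rows of every X it receives line up with the exposure vector.

```python
import numpy as np


def make_poisson_predict_native(exposure):
    """Hypothetical helper: return predict_native(beta, X) with a
    log-exposure offset fixed inside the closure."""
    log_offset = np.log(np.asarray(exposure, dtype=np.float64))

    def predict_native(beta, X):
        # Linear predictor plus the fixed offset, then the inverse log link.
        eta = np.asarray(X, dtype=np.float64) @ np.asarray(beta, dtype=np.float64)
        return np.exp(eta + log_offset)

    return predict_native


# Assumed usage: the factory only ever sees the two-argument (beta, X) closure.
# predict = pymargins.make_predict_with_fd_jvp(make_poisson_predict_native(exposure))
```

Because the offset lives inside the closure, it is treated as a constant: it is neither traced nor differentiated.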