pymargins.make_predict_with_fd_jvp
- pymargins.make_predict_with_fd_jvp(predict_native: Callable, fd_step: float = 1e-06) → Callable
Wrap a non-JAX predict function as a JAX primitive whose JVP (Jacobian-vector product) is computed by finite differences (FD).
Used by ModelAdapter implementations for black-box models whose predict function cannot be reimplemented in JAX. The returned function accepts JAX arrays and is fully compatible with jax.grad, jax.hessian, and jax.jvp. Internally, it uses central-difference FD to compute directional derivatives at the model boundary; downstream autodiff over the estimand structure remains exact.
This is the cleanest way to integrate a non-differentiable model into the JAX-based inference pipeline. The custom JVP isolates the FD to the one operation that needs it; everything composed with this primitive benefits from exact autodiff.
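The central-difference directional derivative described above can be sketched in plain NumPy. This is only an illustration of the technique, not the library's implementation: the helper name fd_directional_derivative, the logistic model, and the data are all assumptions made for the example, and the real wrapper may scale or apply the step differently.

```python
import numpy as np


def fd_directional_derivative(predict_native, beta, X, tangent, fd_step=1e-6):
    """Central-difference estimate of J(beta) @ tangent.

    J is the Jacobian of the predictions with respect to beta. Hypothetical
    helper for illustration; not part of pymargins.
    """
    beta = np.asarray(beta, dtype=np.float64)
    tangent = np.asarray(tangent, dtype=np.float64)
    upper = np.asarray(predict_native(beta + fd_step * tangent, X))
    lower = np.asarray(predict_native(beta - fd_step * tangent, X))
    return (upper - lower) / (2.0 * fd_step)


# Assumed stand-in for a black-box model: logistic predictions.
def predict_logistic(beta, X):
    return 1.0 / (1.0 + np.exp(-(X @ beta)))


X = np.array([[1.0, 0.5], [1.0, -1.2], [1.0, 2.0]])
beta = np.array([0.3, -0.7])
tangent = np.array([1.0, 0.0])  # direction: perturb the intercept only

fd_jvp = fd_directional_derivative(predict_logistic, beta, X, tangent)

# Closed-form JVP for the logistic model, used only to check the estimate.
p = predict_logistic(beta, X)
exact_jvp = p * (1.0 - p) * (X @ tangent)
max_abs_error = np.max(np.abs(fd_jvp - exact_jvp))
```

Each directional derivative costs two calls to the native predict function regardless of the number of parameters, which is why the approximation is confined to the model boundary.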
- Parameters:
predict_native (callable (beta_np, X) -> array_np) – Native prediction function. Receives NumPy beta of shape (n_params,) and an arbitrary X (typically a NumPy array or pandas DataFrame), returns a NumPy array of predictions.
fd_step (float, default 1e-6) – FD step for directional derivatives. The default is appropriate for float64; if the model has internal numerical solvers with looser tolerances, this may need to be increased.
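To illustrate why fd_step may need to be increased for models with loose internal tolerances, here is a small NumPy sketch. Rounding the predictions to 8 decimals is an assumed stand-in for a solver that only converges to about 1e-8; the function names and data are hypothetical and not part of pymargins.

```python
import numpy as np


def predict_noisy(beta, X):
    # Assumed black box: rounding mimics a solver accurate to ~1e-8.
    return np.round(np.exp(X @ beta), 8)


def central_difference(predict, beta, X, tangent, fd_step):
    # Central-difference directional derivative along `tangent`.
    upper = predict(beta + fd_step * tangent, X)
    lower = predict(beta - fd_step * tangent, X)
    return (upper - lower) / (2.0 * fd_step)


X = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
beta = np.array([0.1, 0.2])
tangent = np.array([1.0, 0.0])

# Exact directional derivative of exp(X @ beta) along `tangent`.
exact = np.exp(X @ beta) * (X @ tangent)

# Default-sized step: the ~1e-8 output noise is divided by 2e-6.
err_default = np.max(np.abs(central_difference(predict_noisy, beta, X, tangent, 1e-6) - exact))

# Larger step: the same noise is divided by 2e-3, at the cost of more truncation error.
err_loose = np.max(np.abs(central_difference(predict_noisy, beta, X, tangent, 1e-3) - exact))
```

In this setup the larger step gives the smaller error, because output noise of size eps contributes roughly eps / fd_step to the estimate while the truncation error grows only with fd_step squared.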
- Returns:
predict_wrapped – JAX-compatible wrapper. jax.grad, jax.hessian, and jax.jvp all work through this function. It can be differentiated with respect to both β and X (the latter is needed for dydx slopes).
- Return type:
callable (beta_jax, X) -> array_jax
Notes
Offset handling: If the model uses an offset (e.g. exposure in Poisson models), the offset must be baked into predict_native at the call site where this factory is invoked. The wrapped function signature is (beta, X); there is no offset argument because JAX traces both arguments and an offset would need to be a traced array (not a static nondiff_argnum) to support differentiation through it.
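One way to bake an offset into the native predict function is a closure, sketched below for a Poisson model with a log link. The factory name make_predict_with_offset, the exposure values, and the model form are assumptions for illustration only.

```python
import numpy as np


def make_predict_with_offset(offset):
    """Return a (beta, X) -> predictions function with the offset fixed.

    Hypothetical helper: the offset is captured by the closure, so the
    resulting function has the two-argument signature the factory expects.
    """
    fixed_offset = np.asarray(offset, dtype=np.float64)

    def predict_native(beta, X):
        # Poisson mean with log link: exp(X @ beta + offset).
        return np.exp(np.asarray(X) @ np.asarray(beta) + fixed_offset)

    return predict_native


X = np.array([[1.0, 0.2], [1.0, 1.5], [1.0, -0.4]])
beta = np.array([-1.0, 0.5])
exposure = np.array([10.0, 20.0, 5.0])

predict_native = make_predict_with_offset(np.log(exposure))
preds = predict_native(beta, X)

# The closure would then be passed to the factory, following the signature
# documented on this page:
# predict_wrapped = pymargins.make_predict_with_fd_jvp(predict_native)
```

Because the offset is fixed per row in this sketch, the closure is tied to inputs with the same number of rows as the offset it captured.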