pymargins.make_glm_jvp_wrapper

pymargins.make_glm_jvp_wrapper(family) → Callable

Wrap a GLM prediction with a custom JVP using the link’s analytical derivative.

For any GLM with mean function μ = g⁻¹(η), where η = Xβ, the chain rule gives the Jacobian of μ with respect to β as diag(dg⁻¹/dη evaluated at η) · X; that is, row i of X scaled by the inverse-link derivative at ηᵢ. This wrapper implements both the forward evaluation and the tangent using JAX-native operations for common links, making it compatible with jax.grad, jax.hessian, jax.jvp, and jax.vmap.
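The chain-rule identity above can be illustrated with a minimal sketch, assuming a logit link. The factory name make_logit_predict is hypothetical and is not part of pymargins; the sketch only shows the general jax.custom_jvp pattern, not the library's actual implementation.

```python
import jax
import jax.numpy as jnp


def make_logit_predict():
    """Hypothetical stand-in for the wrapper, specialised to the logit link."""

    @jax.custom_jvp
    def predict(beta, X):
        # Forward evaluation: mu = g^{-1}(X beta), here the logistic sigmoid.
        return jax.nn.sigmoid(X @ beta)

    @predict.defjvp
    def predict_jvp(primals, tangents):
        beta, X = primals
        beta_dot, X_dot = tangents
        mu = jax.nn.sigmoid(X @ beta)
        # Analytical derivative of the inverse logit: mu * (1 - mu).
        dmu_deta = mu * (1.0 - mu)
        # Tangent of the linear predictor eta = X beta (product rule).
        eta_dot = X @ beta_dot + X_dot @ beta
        return mu, dmu_deta * eta_dot

    return predict


# Illustrative data (made up for this sketch).
beta = jnp.array([0.5, -1.0])
X = jnp.array([[1.0, 0.2], [1.0, -0.3], [1.0, 1.5]])

predict = make_logit_predict()
jac = jax.jacobian(predict)(beta, X)  # Jacobian w.r.t. beta, shape (3, 2)

# Closed form from the text: each row of X scaled by dg^{-1}/d eta.
mu = jax.nn.sigmoid(X @ beta)
expected = (mu * (1.0 - mu))[:, None] * X
```

Because the tangent rule is written with ordinary JAX operations, it can itself be differentiated, which is what allows second-order transforms such as jax.hessian to pass through it.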

Parameters:

family (statsmodels.genmod.families.Family) – The fitted model’s family object. Must use a link supported by _jax_link_inverse and _jax_link_inverse_deriv.

Returns:

predict_wrapped – JAX-compatible prediction. Supports an optional offset added to the linear predictor before applying the link inverse. Note: offset is a keyword argument with a default of None. This works with jax.grad, jax.jvp, and jax.hessian today, but jax.jit or nondiff_argnums may require fixing the offset when the factory is called instead of passing it on each prediction call.

Return type:

callable (beta, X, offset=None) -> array
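One way to read the note on offset is to capture a fixed offset in a closure when the predictor is built, so that the function seen by jax.jit takes only (beta, X). The following is a sketch under that assumption, using a log link with a log-exposure offset; make_log_link_predict is a hypothetical name, not a pymargins function, and the offset is assumed to be a concrete array.

```python
import jax
import jax.numpy as jnp


def make_log_link_predict(offset=None):
    """Hypothetical factory: log link, with the offset captured by closure."""

    @jax.custom_jvp
    def predict(beta, X):
        eta = X @ beta
        if offset is not None:
            eta = eta + offset  # offset enters the linear predictor
        return jnp.exp(eta)     # inverse of the log link

    @predict.defjvp
    def predict_jvp(primals, tangents):
        beta, X = primals
        beta_dot, X_dot = tangents
        mu = predict(beta, X)
        # For the log link, d mu / d eta = mu; the offset has no tangent.
        eta_dot = X @ beta_dot + X_dot @ beta
        return mu, mu * eta_dot

    return predict


# Illustrative data (made up for this sketch).
X = jnp.array([[1.0, 0.0], [1.0, 1.0]])
beta = jnp.array([0.1, 0.3])
exposure = jnp.array([2.0, 5.0])

# The offset is fixed here, so the jitted function only takes (beta, X).
predict = jax.jit(make_log_link_predict(offset=jnp.log(exposure)))
mu = predict(beta, X)
grad_total = jax.grad(lambda b: predict(b, X).sum())(beta)
```

With this layout the offset is a constant of the traced function, so changing it means building a new predictor.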