Writing a custom adapter
Most users will not need this — pymargins ships adapters for
statsmodels, linearmodels, and lifelines. If you have a model class
none of those cover, the adapter interface is small.
The four adapter base classes, in increasing order of work:
| Base class | When to use |
|---|---|
| … | … |
| … | … |
| … | black-box predict, but … |
| … | refit-and-resample is the only viable path |
Minimal GLMAdapter example
A GLMAdapter subclass must implement five methods:

- detect(model): return True if this adapter handles the fitted object.
- variable_info(model): return a list of VariableInfo objects describing the covariates (their names and whether each is "continuous" or "categorical").
- design_matrix(model, data): build the design matrix from a pd.DataFrame.
- link_inverse(eta) and link_inverse_deriv(eta): the mean function and its derivative, written with jax.numpy operations.
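One reason an adapter supplies link_inverse_deriv and not just link_inverse: for a GLM with mean μ = g⁻¹(η) and η = Xβ, the chain rule gives ∂μ/∂x_j = (g⁻¹)′(η)·β_j for a continuous covariate x_j that enters η only through β_j. The plain-Python sketch below averages that quantity over the rows of X (the average marginal effect) for a logistic link. It illustrates the textbook formula under those assumptions, not pymargins' internal code; the function names and data are hypothetical.

```python
import math

def link_inverse(eta: float) -> float:
    """Logistic mean function, branching so math.exp never overflows."""
    if eta >= 0:
        return 1.0 / (1.0 + math.exp(-eta))
    z = math.exp(eta)
    return z / (1.0 + z)

def link_inverse_deriv(eta: float) -> float:
    """Derivative of the logistic mean function: p * (1 - p)."""
    p = link_inverse(eta)
    return p * (1.0 - p)

def average_marginal_effect(X, beta, j):
    """Mean over rows of d mu / d x_j = link_inverse_deriv(row . beta) * beta[j].

    Assumes covariate j enters the linear predictor only through beta[j]
    (no interactions or transformed copies of x_j in other columns).
    """
    total = 0.0
    for row in X:
        eta = sum(x * b for x, b in zip(row, beta))
        total += link_inverse_deriv(eta) * beta[j]
    return total / len(X)

# Hypothetical data: an intercept column plus one continuous covariate.
X = [[1.0, -1.0], [1.0, 0.0], [1.0, 1.0]]
beta = [0.0, 2.0]
print(average_marginal_effect(X, beta, j=1))
```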
```python
from pymargins import GLMAdapter, VariableInfo, register_adapter
import jax.numpy as jnp

class MyGLMAdapter(GLMAdapter):
    def detect(self, model):
        # MyModel stands for your own fitted-model class
        return isinstance(model, MyModel)

    def variable_info(self, model):
        return [VariableInfo(name=n, kind="continuous") for n in model.feature_names_]

    def design_matrix(self, model, data):
        return model.build_design(data)

    def link_inverse(self, eta):
        # Logistic mean function. The 1 / (1 + exp(-eta)) form saturates to 0 or 1
        # instead of producing nan (inf / inf) when eta is large and positive.
        return 1 / (1 + jnp.exp(-eta))

    def link_inverse_deriv(self, eta):
        # Derivative of the logistic mean function: p * (1 - p)
        p = self.link_inverse(eta)
        return p * (1 - p)

register_adapter(MyGLMAdapter())
```
After registration, Margins(...) will auto-detect your model class.
Adapters are tried in registration order; user-registered adapters are
tried before built-ins, so you can override default detection if your
fitted object subclasses one of the supported types but has different
predict semantics.
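The lookup rule in the paragraph above (first match wins, user adapters ahead of built-ins) can be sketched as a tiny registry. Everything here, including AdapterRegistry, find, and the model classes, is hypothetical and written only to illustrate the ordering; it is not pymargins' implementation.

```python
from dataclasses import dataclass, field
from typing import Any, List

class Adapter:
    """Hypothetical minimal adapter: only detection matters for lookup."""
    name = "base"

    def detect(self, model: Any) -> bool:
        raise NotImplementedError

@dataclass
class AdapterRegistry:
    """Hypothetical registry: user adapters are consulted before built-ins."""
    builtin_adapters: List[Adapter]
    user_adapters: List[Adapter] = field(default_factory=list)

    def register(self, adapter: Adapter) -> None:
        # Among user adapters, earlier registrations are tried first.
        self.user_adapters.append(adapter)

    def find(self, model: Any) -> Adapter:
        # First match wins, so a user adapter shadows any built-in that also matches.
        for adapter in self.user_adapters + self.builtin_adapters:
            if adapter.detect(model):
                return adapter
        raise TypeError(f"no adapter for {type(model).__name__}")

class BaseModel:
    """Stand-in for a supported fitted-model type."""

class MyVariant(BaseModel):
    """Subclasses a supported type but has different predict semantics."""

class BuiltinAdapter(Adapter):
    name = "builtin"

    def detect(self, model: Any) -> bool:
        return isinstance(model, BaseModel)

class MyVariantAdapter(Adapter):
    name = "mine"

    def detect(self, model: Any) -> bool:
        return isinstance(model, MyVariant)

registry = AdapterRegistry(builtin_adapters=[BuiltinAdapter()])
print(registry.find(MyVariant()).name)  # builtin: nothing user-registered yet
registry.register(MyVariantAdapter())
print(registry.find(MyVariant()).name)  # mine: user adapter is tried first
print(registry.find(BaseModel()).name)  # builtin: user adapter does not match
```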
WrappedFDAdapter: when you have a black-box predict
If your model already has a predict(data) method but you cannot
reimplement it in JAX, subclass WrappedFDAdapter. Instead of link
functions, you supply the design matrix X behind the linear predictor
η = Xβ and a predict method that calls your model as-is:
```python
from pymargins import WrappedFDAdapter, VariableInfo, register_adapter

class MyWrappedAdapter(WrappedFDAdapter):
    def detect(self, model):
        return isinstance(model, MyBlackBoxModel)

    def variable_info(self, model):
        return [VariableInfo(name=n, kind="continuous") for n in model.cols]

    def design_matrix(self, model, data):
        return data[model.cols].values

    def predict(self, model, data):
        # Black-box: can use numpy, sklearn, etc.
        return model.predict(data)

register_adapter(MyWrappedAdapter())
```
The engine wraps predict in a custom-JVP finite-difference primitive
and differentiates through it. Gradient quality is good; Hessian
quality is acceptable for κ. If you need exact Hessians, consider
reimplementing the predict path in JAX and upgrading to GLMAdapter.
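The engine's primitive itself is not reproduced here, but the idea behind a finite-difference JVP is small: forward-mode differentiation of a black box f only needs the directional derivative J(x)v, which a central difference approximates as (f(x + hv) − f(x − hv)) / (2h), with O(h²) truncation error and two extra predict calls per tangent. The stdlib sketch below shows that rule; fd_jvp and black_box are hypothetical names. In JAX, a rule like this could be attached to a function with jax.custom_jvp, but this is an illustration of the technique, not pymargins' code.

```python
import math
from typing import Callable, List

Vector = List[float]

def fd_jvp(f: Callable[[Vector], Vector], x: Vector, v: Vector, h: float = 1e-5) -> Vector:
    """Approximate the Jacobian-vector product J_f(x) v by a central difference along v."""
    x_plus = [xi + h * vi for xi, vi in zip(x, v)]
    x_minus = [xi - h * vi for xi, vi in zip(x, v)]
    # Two extra calls to the black box per tangent direction
    f_plus, f_minus = f(x_plus), f(x_minus)
    return [(a - b) / (2.0 * h) for a, b in zip(f_plus, f_minus)]

def black_box(x: Vector) -> Vector:
    """Hypothetical stand-in for an opaque predict(): plain Python, no autodiff."""
    return [math.exp(x[0]) * x[1], math.sin(x[1])]

x = [0.5, 1.0]
v = [1.0, 0.0]  # tangent direction: perturb x[0] only
approx = fd_jvp(black_box, x, v)
exact = [math.exp(x[0]) * x[1], 0.0]  # analytic J_f(x) v for comparison
print(approx, exact)
```

Differentiating such a rule a second time nests two differences, so rounding error grows roughly like ε/h² instead of ε/h (ε being machine precision). That is the usual reason finite-difference Hessians are less accurate than finite-difference gradients.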
Testing your adapter
Before registering, test the adapter standalone:
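```python
# my_fitted_model and df are your fitted model and its data
adapter = MyGLMAdapter()
assert adapter.detect(my_fitted_model)

X = adapter.design_matrix(my_fitted_model, df.iloc[:5])
# Holds only if the design matrix has one column per covariate
# (no intercept column and no dummy-expanded categoricals)
assert X.shape == (5, len(adapter.variable_info(my_fitted_model)))

# Prediction round-trip
beta = adapter.coefficients()
preds = adapter.predict(beta, X)
assert preds.shape == (5,)
```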
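It is also worth checking that link_inverse_deriv really is the derivative of link_inverse, since a mismatch silently distorts every derivative-based result. A minimal sketch, assuming both methods accept a scalar η: compare against a central difference on a grid of η values. The helper check_link_pair and the logistic stand-ins are hypothetical; with a real adapter you would call check_link_pair(adapter.link_inverse, adapter.link_inverse_deriv).

```python
import math
from typing import Callable, Iterable

def check_link_pair(
    link_inverse: Callable[[float], float],
    link_inverse_deriv: Callable[[float], float],
    etas: Iterable[float] = (-6.0, -2.0, -0.5, 0.0, 0.5, 2.0, 6.0),
    h: float = 1e-5,
    tol: float = 1e-6,
) -> None:
    """Raise ValueError if link_inverse_deriv disagrees with a central difference."""
    for eta in etas:
        # float() also converts 0-d JAX arrays returned by jnp-based methods
        numeric = (float(link_inverse(eta + h)) - float(link_inverse(eta - h))) / (2.0 * h)
        analytic = float(link_inverse_deriv(eta))
        if abs(numeric - analytic) > tol:
            raise ValueError(f"mismatch at eta={eta}: numeric {numeric}, analytic {analytic}")

# Plain-Python logistic pair standing in for the adapter methods
def logistic(eta: float) -> float:
    return 1.0 / (1.0 + math.exp(-eta))

def logistic_deriv(eta: float) -> float:
    p = logistic(eta)
    return p * (1.0 - p)

check_link_pair(logistic, logistic_deriv)
```

If the methods are JAX-backed, note that JAX computes in float32 by default, where a step of h = 1e-5 leaves only two or three significant digits in the difference; either enable 64-bit mode with jax.config.update("jax_enable_x64", True) or loosen h and tol.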
See "The adapter pattern" for the full contract and the rationale behind the four-base-class split.