Nonlinear estimands with evaluate

import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from pymargins import Margins

rng = np.random.default_rng(42)
n = 2000
df = pd.DataFrame({
    "age": rng.integers(20, 75, n),
    "treatment": rng.binomial(1, 0.40, n),
    "dose": rng.choice([0, 50, 100], n),
    "policy": rng.choice(["A", "B"], n),
})
lp = (-1.5 + 0.04 * df["age"] + 0.8 * df["treatment"]
      + 0.01 * df["dose"]
      + 0.3 * (df["policy"] == "B"))
df["y"] = rng.binomial(1, 1 / (1 + np.exp(-lp)))

fit = smf.glm("y ~ age + treatment + dose + C(policy)", data=df,
              family=sm.families.Binomial()).fit()
m = Margins.linear_scale(fit, at="overall")

Margins.evaluate is the escape hatch for estimands that cannot be written as a weighted sum of scenario predictions. Use it for reciprocals, custom utility functions, ratios of differences, or any other JAX-differentiable composition that is not linear in the predictions.

For simple differences, risk ratios, odds ratios, and difference-in-differences, contrasts is the better tool — it is faster, more transparent, and usually more accurate. See Linear contrasts with contrasts for those recipes and Contrasts vs evaluate — choosing the right tool for the full decision guide.

How evaluate works

evaluate takes a list of scenarios and a compose function. The engine:

  1. Computes the response-scale prediction for each scenario.

  2. Stacks them into a JAX array and passes them to compose.

  3. Applies phi_inv to lift the result onto the inference scale.

  4. Runs delta-method inference, or falls back to simulation/bootstrap if the curvature diagnostic κ exceeds its threshold or compose is not JAX-differentiable.

  5. Back-transforms CI endpoints with phi for reporting.

Mathematically:

result = φ( φ⁻¹( compose(p₁, p₂, …, p_k) ) )

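where pᵢ is the aggregated response-scale prediction for scenario i. The outer φ undoes the inner φ⁻¹, so the point estimate is simply compose(p₁, …, p_k). The round trip matters for inference: the standard error and CI are computed on the φ⁻¹ scale, and only the CI endpoints are mapped back through φ. On a linear_scale session φ is the identity; on a log_scale session φ = exp.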
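Steps 2 to 5 can be sketched outside pymargins with a few lines of JAX. This is a minimal illustration under stated assumptions, not the library's implementation: delta_method_ci is a hypothetical helper, the covariance of the aggregated predictions is assumed to be given, and the inputs are made-up numbers rather than values from the model above.

```python
import jax
import jax.numpy as jnp
from statistics import NormalDist


def delta_method_ci(compose, p, cov, phi=lambda x: x, phi_inv=lambda x: x,
                    level=0.95):
    """Hypothetical helper mirroring steps 2-5 for one estimand.

    p   : aggregated response-scale predictions, shape (k,)
    cov : their k x k covariance matrix (assumed to be known)
    """
    def g(q):
        # Steps 2-3: compose the predictions, then lift with phi_inv.
        return phi_inv(compose(q))

    theta = g(p)                         # estimate on the inference scale
    grad = jax.grad(g)(p)                # d theta / d p_i via autodiff
    se = jnp.sqrt(grad @ cov @ grad)     # step 4: delta-method standard error
    z = NormalDist().inv_cdf(0.5 + level / 2)
    lo, hi = theta - z * se, theta + z * se
    # Step 5: back-transform the estimate and CI endpoints with phi.
    return float(phi(theta)), float(se), (float(phi(lo)), float(phi(hi)))


# Illustrative inputs only (hypothetical, not from the fitted model above).
p = jnp.array([0.85, 0.74])
cov = jnp.array([[4e-4, 1e-4],
                 [1e-4, 3e-4]])

# Risk difference with an identity phi: symmetric CI.
rd = delta_method_ci(lambda q: q[0] - q[1], p, cov)
# Ratio lifted to the log scale and back-transformed with exp: asymmetric CI.
rr = delta_method_ci(lambda q: q[0] / q[1], p, cov, phi=jnp.exp, phi_inv=jnp.log)
print(rd)
print(rr)
```

Computing the SE on the φ⁻¹ scale and transforming only the endpoints is what makes the ratio interval asymmetric while the difference interval stays symmetric.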

Number needed to treat (NNT)

NNT is the reciprocal of the absolute risk reduction. Because it is a reciprocal, it cannot be written as a linear contrast and must go through evaluate:

from pymargins import Margins

m = Margins.linear_scale(fit, at="overall")

scenarios = [
    {"atexog": {"treatment": 1}, "label": "treated"},
    {"atexog": {"treatment": 0}, "label": "control"},
]

res = m.evaluate(
    scenarios=scenarios,
    compose=lambda p: 1.0 / (p[0] - p[1]),
)
print(res.summary())
==============================================================
           Margins Result (simulation, level=0.95)            
==============================================================
         estimate  std err  statistic  P>|z|  [95% Conf. Int.]
--------------------------------------------------------------
treated    8.8186   1.5392     8.8186  0.000   6.7749, 12.6234
==============================================================

n = 2000
WARNING — Fallback triggered: kappa=0.323>threshold=0.3
κ: 0.323

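Here κ = 0.323 already exceeds the 0.3 threshold, so pymargins fell back to simulation even though the risk difference is well away from zero. If the confidence interval for the risk difference includes zero, the denominator can change sign, NNT becomes unbounded, and κ will be larger still. In that case pymargins also falls back to simulation automatically, which is the safe choice for a reciprocal.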
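To see why simulation is the safer route for a reciprocal, here is a minimal numpy sketch comparing a delta-method interval for NNT = 1/RD with a simulation percentile interval. nnt_intervals is a hypothetical helper, and the risk differences and standard errors are illustrative inputs assumed to be approximately normal; this is not how pymargins computes κ or runs its simulation.

```python
import numpy as np


def nnt_intervals(rd, se_rd, n_draws=200_000, seed=0):
    """Delta-method vs simulation 95% intervals for NNT = 1 / RD.

    rd, se_rd: a risk difference and its standard error (illustrative
    inputs, assumed approximately normal).
    """
    nnt = 1.0 / rd
    se_delta = se_rd / rd**2              # |d(1/x)/dx| = 1/x**2
    delta_ci = (nnt - 1.96 * se_delta, nnt + 1.96 * se_delta)
    draws = np.random.default_rng(seed).normal(rd, se_rd, n_draws)
    sim_ci = tuple(np.percentile(1.0 / draws, [2.5, 97.5]))
    return nnt, delta_ci, sim_ci


# RD well away from zero: the simulated interval is skewed to the right.
print(nnt_intervals(0.113, 0.020))
# RD interval includes zero: simulated NNT draws take both signs.
print(nnt_intervals(0.020, 0.020))
```

In the second case the 2.5th and 97.5th percentiles have opposite signs. Read literally, that interval would include NNT values near zero, which correspond to risk differences far outside the plausible range, so the risk-difference CI is usually the more honest thing to report there.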

Raw ratio on the linear scale

A risk ratio is usually computed with contrasts on a log_scale session (log(p₁) − log(p₀), back-transformed with exp). That is the preferred path because the log-ratio is linear in the log-scale predictions, so the contrast itself adds no curvature to the delta method.

Use evaluate for the raw ratio p₁ / p₀ only when your field or journal explicitly requires inference on the ratio scale itself:

m = Margins.linear_scale(fit, at="overall")

scenarios = [
    {"atexog": {"treatment": 1}, "label": "treated"},
    {"atexog": {"treatment": 0}, "label": "control"},
]

res = m.evaluate(
    scenarios=scenarios,
    compose=lambda p: p[0] / p[1],
)
print(res.summary())
============================================================
             Margins Result (delta, level=0.95)             
============================================================
         estimate  std err        z  P>|z|  [95% Conf. Int.]
------------------------------------------------------------
treated    1.1549   0.0255  45.2943  0.000    1.1049, 1.2048
============================================================

n = 2000
κ: 0.049
Delta-vs-sim disagreement: 0.406%

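Because p₁ / p₀ is nonlinear in the predictions on the linear scale, κ is usually larger than for the log-scale contrast, and the delta-method CI is symmetric around the estimate, unlike the asymmetric interval obtained by exponentiating a log-scale CI. In this example κ = 0.049 is well below the 0.3 threshold and delta and simulation disagree by only 0.406%, so the delta-method result is kept.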
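A short numpy sketch makes the difference in CI shape concrete. ratio_cis is a hypothetical helper and the predictions and (co)variances are illustrative, not taken from the model above. To first order the two standard errors are tied together by SE_lin = RR × SE_log, so the main difference between the two routes is the shape of the interval.

```python
import numpy as np


def ratio_cis(p1, p0, var1, var0, cov10=0.0, z=1.96):
    """Delta-method CIs for RR = p1 / p0 on the linear and log scales.

    All inputs are illustrative: aggregated predictions and their
    (co)variances, not values from the fitted model above.
    """
    cov = np.array([[var1, cov10], [cov10, var0]])
    rr = p1 / p0
    g_lin = np.array([1 / p0, -p1 / p0**2])   # gradient of p1 / p0
    g_log = np.array([1 / p1, -1 / p0])       # gradient of log p1 - log p0
    se_lin = np.sqrt(g_lin @ cov @ g_lin)
    se_log = np.sqrt(g_log @ cov @ g_log)
    linear_ci = (rr - z * se_lin, rr + z * se_lin)                  # symmetric
    log_ci = (rr * np.exp(-z * se_log), rr * np.exp(z * se_log))    # asymmetric
    return rr, se_lin, se_log, linear_ci, log_ci


rr, se_lin, se_log, linear_ci, log_ci = ratio_cis(0.85, 0.74, 4e-4, 3e-4, 1e-4)
print(rr, linear_ci, log_ci)
```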

Ratio of differences (Emax-style)

When the estimand is a ratio in which the numerator and denominator are themselves differences, evaluate is required:

scenarios = [
    {"atexog": {"dose": 0}, "label": "placebo"},
    {"atexog": {"dose": 50}, "label": "low"},
    {"atexog": {"dose": 100}, "label": "high"},
]

# Emax-style parameter: (high − placebo) / (low − placebo)
res = m.evaluate(
    scenarios=scenarios,
    compose=lambda p: (p[2] - p[0]) / (p[1] - p[0]),
)
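The gradient of this estimand shows why it needs the same care as NNT. Writing a = high − placebo and b = low − placebo, every component carries a 1/b or 1/b² factor, so the delta-method SE grows quickly when the low dose barely moves the prediction. pymargins obtains the gradient by differentiating compose with JAX; the analytic version below is only an illustration, with hypothetical helper names and made-up predictions.

```python
import numpy as np


def ratio_of_differences(p):
    """(high - placebo) / (low - placebo) for p = [placebo, low, high]."""
    return (p[2] - p[0]) / (p[1] - p[0])


def ratio_of_differences_grad(p):
    """Analytic gradient, with a = high - placebo and b = low - placebo."""
    a, b = p[2] - p[0], p[1] - p[0]
    return np.array([(a - b) / b**2,   # d/d placebo
                     -a / b**2,        # d/d low
                     1.0 / b])         # d/d high


# Illustrative predictions only (hypothetical, not from the fitted model).
p = np.array([0.70, 0.78, 0.84])
print(ratio_of_differences(p))        # (0.84 - 0.70) / (0.78 - 0.70)
print(ratio_of_differences_grad(p))
```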

Custom utility / welfare function

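Suppose you have a utility function u(p) = p**0.5 (a concave transformation of a predicted probability) and you want the utility difference between two policy regimes. Because compose receives each scenario's aggregated prediction p̄, the estimand below is u(p̄_A) − u(p̄_B), the utility of each regime's aggregated prediction, not an average of per-observation utilities u(pᵢ). When p̄ is a sample average, Jensen's inequality gives u(p̄) ≥ average of u(pᵢ) for a concave u, so the two estimands generally differ: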

import jax.numpy as jnp

m = Margins.linear_scale(fit, at="overall")

scenarios = [
    {"atexog": {"policy": "A"}, "label": "regime_A"},
    {"atexog": {"policy": "B"}, "label": "regime_B"},
]

res = m.evaluate(
    scenarios=scenarios,
    compose=lambda p: jnp.sqrt(p[0]) - jnp.sqrt(p[1]),
)
print(res.summary())
=============================================================
              Margins Result (delta, level=0.95)             
=============================================================
          estimate  std err        z  P>|z|  [95% Conf. Int.]
-------------------------------------------------------------
regime_A   -0.0289   0.0100  -2.8761  0.004  -0.0486, -0.0092
=============================================================

n = 2000
κ: 0.035
Delta-vs-sim disagreement: 8.712%
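If the quantity you need is an average of per-observation utilities, the compose-based number above is a different estimand. A minimal numpy check with hypothetical per-observation probabilities (uniform draws, not predictions from the fitted model) shows the gap that Jensen's inequality predicts for the concave u(p) = √p.

```python
import numpy as np

# Hypothetical per-observation predicted probabilities under one regime.
rng = np.random.default_rng(0)
p_i = rng.uniform(0.4, 0.95, size=2000)

u_of_mean = np.sqrt(p_i.mean())   # u applied to the aggregate (what compose sees)
mean_of_u = np.sqrt(p_i).mean()   # average of per-observation utilities
print(u_of_mean, mean_of_u, u_of_mean - mean_of_u)
```

The gap is small here because √p is nearly linear over this range, but it grows with the spread of the pᵢ and the curvature of u.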

When evaluate auto-routes to simulation

If compose branches on the values of p with Python control flow (for example if p[0] > p[1]: ..., or a loop whose length depends on p), JAX may be unable to trace it, and the delta method is then impossible. pymargins catches the resulting error and silently reroutes to the session's simulation or bootstrap method. The result records the method actually used, so the audit trail is still complete.

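To avoid auto-routing, write compose with JAX primitives:

Instead of               Use
if x > 0: ...            jnp.where(x > 0, ..., ...)
x / y (unsafe divide)    jnp.where(y == 0, 0.0, x / jnp.where(y == 0, 1.0, y))
max(a, b)                jnp.maximum(a, b)

The division needs the inner jnp.where. With the single form jnp.where(y == 0, 0, x / y) the value is right, but the gradient at y == 0 is NaN, because JAX still differentiates the unselected x / y branch and multiplies its infinite derivative by zero.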
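A minimal JAX sketch of these rewrites: it shows a Python if failing under jax.jit, the equivalent jnp.where version tracing and differentiating cleanly, and the gradient of a single-where versus a double-where division at y == 0. branchy, branch_free, naive_div, and safe_div are illustrative names, not pymargins functions, and using jax.jit to trigger the tracing error is an assumption about how a non-traceable compose fails; pymargins' own tracing path is not shown.

```python
import jax
import jax.numpy as jnp


@jax.jit
def branchy(p):
    # Python `if` on a traced value: raises a concretization error under jit.
    if p[0] > p[1]:
        return p[0] - p[1]
    return p[1] - p[0]


@jax.jit
def branch_free(p):
    # Same function with jnp.where: traceable and differentiable.
    return jnp.where(p[0] > p[1], p[0] - p[1], p[1] - p[0])


def naive_div(x, y):
    # Right value, but x / y is still differentiated at y == 0.
    return jnp.where(y == 0, 0.0, x / y)


def safe_div(x, y):
    # Double where: swap in a harmless denominator before dividing.
    safe_y = jnp.where(y == 0, 1.0, y)
    return jnp.where(y == 0, 0.0, x / safe_y)


jit_failed = False
try:
    branchy(jnp.array([0.8, 0.7]))
except jax.errors.ConcretizationTypeError:
    jit_failed = True

p = jnp.array([0.8, 0.7])
print(jit_failed, branch_free(p), jax.grad(branch_free)(p))
print(jax.grad(naive_div, argnums=1)(1.0, 0.0))   # NaN gradient
print(jax.grad(safe_div, argnums=1)(1.0, 0.0))    # finite gradient
```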

See also