AME end-to-end example (genriesz)

This notebook demonstrates how to estimate an Average Marginal Effect (AME), i.e., an average derivative of the outcome regression function.

We draw \(X \sim N(0, I_3)\) and independent noise \(\varepsilon \sim N(0,1)\), and simulate

\[Y = \sin(X_0) + 0.5 X_1^2 + \varepsilon,\]

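so, writing \(\gamma(x) = \mathbb{E}[Y \mid X = x] = \sin(x_0) + 0.5 x_1^2\) for the outcome regression, the true AME for coordinate 0 is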

\[\mathbb{E}[\partial_{x_0} \gamma(X)] = \mathbb{E}[\cos(X_0)].\]

Because \(X_0 \sim N(0,1)\) in this simulation, \(\mathbb{E}[\cos(X_0)] = \exp(-1/2) \approx 0.6065\).
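For reference, the closed form and the Riesz representer of this functional both follow from standard Gaussian identities, with \(\gamma(x) = \sin(x_0) + 0.5 x_1^2\).

The first identity evaluates the characteristic function \(\mathbb{E}[e^{itZ}] = e^{-t^2/2}\) of \(Z \sim N(0,1)\) at \(t = 1\). The second applies Stein's lemma, \(\mathbb{E}[f'(Z)] = \mathbb{E}[Z f(Z)]\), to \(X_0\), holding \(X_1, X_2\) fixed; this is valid because the coordinates of \(X\) are independent \(N(0,1)\). It shows that \(\alpha(x) = x_0\) satisfies \(\mathbb{E}[\partial_{x_0}\gamma(X)] = \mathbb{E}[\alpha(X)\gamma(X)]\), so \(\alpha(x) = x_0\) is the Riesz representer of the AME in this design:

\[
\mathbb{E}[\cos(X_0)] = \operatorname{Re}\,\mathbb{E}\big[e^{iX_0}\big] = e^{-1/2},
\qquad
\mathbb{E}[\partial_{x_0}\gamma(X)] = \mathbb{E}[X_0\,\gamma(X)] = \mathbb{E}[X_0\sin(X_0)] + \tfrac{1}{2}\,\mathbb{E}[X_0]\,\mathbb{E}[X_1^2] = \mathbb{E}[\cos(X_0)].
\]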

[1]:
import numpy as np
from genriesz import (
    grr_ame,
    SquaredGenerator,
    UKLGenerator,
    BPGenerator,
    PolynomialBasis,
    RBFRandomFourierBasis,
)

rng = np.random.default_rng(0)

Synthetic data with known true AME

[2]:
n = 4000
d = 3

X = rng.normal(size=(n, d))
eps = rng.normal(scale=1.0, size=n)

Y = np.sin(X[:, 0]) + 0.5 * (X[:, 1] ** 2) + eps

true_ame0 = float(np.exp(-0.5))  # E[cos(N(0,1))]
print("Approx. true AME for coordinate 0:", true_ame0)

Approx. true AME for coordinate 0: 0.6065306597126334

Example 1: Polynomial basis

[3]:
# Polynomial basis in X up to degree 3, including an intercept (bias) term
basis = PolynomialBasis(degree=3, include_bias=True)

gen = SquaredGenerator(C=0.0).as_generator()

res = grr_ame(
    X=X,
    Y=Y,
    coordinate=0,
    basis=basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res.summary_text())

AME(coord=0) estimates (n=4000)
alpha=0.05 | null=0.0
diagnostics: alpha_abs_mean=0.7986024158798944, alpha_abs_p95=1.9652359398234305, alpha_abs_max=4.506401055576029

Estimator         Estimate            SE                           CI     p-value
---------------------------------------------------------------------------------
RA                0.568023    0.00659669       [ 0.555094,  0.580952]           0
RW                0.569997     0.0255014       [ 0.520015,  0.619979]           0
ARW               0.568096     0.0169592       [ 0.534856,  0.601335]           0
TMLE              0.568094     0.0169599       [ 0.534853,  0.601334]           0
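The sketch below is a sanity check on the estimators and the diagnostics. It computes oracle versions of three estimators, plugging in the true \(\gamma\) and the closed-form Riesz representer \(\alpha(x) = x_0\) instead of fitted models. The closed form holds here because \(X \sim N(0, I_3)\), by Stein's lemma.

The sketch rests on two assumptions:

- RA, RW, and ARW are the usual regression-adjustment (plug-in), Riesz-weighting, and augmented Riesz-weighting estimators. genriesz itself uses cross-fitted estimates, and TMLE is omitted here.
- The `alpha_abs_*` diagnostics summarize \(|\hat\alpha(X_i)|\). If so, they should be near \(\mathbb{E}|X_0| = \sqrt{2/\pi} \approx 0.798\) and near the 95th percentile of \(|N(0,1)|\), about 1.960. The reported values are 0.7986 and 1.965.

The helper names are hypothetical.

```python
import numpy as np

# Oracle components for the simulated design (no fitting involved; names are hypothetical).
def gamma_true(x):
    # gamma(x) = E[Y | X = x] = sin(x_0) + 0.5 * x_1^2
    return np.sin(x[:, 0]) + 0.5 * x[:, 1] ** 2

def dgamma_dx0_true(x):
    # Partial derivative of gamma with respect to x_0
    return np.cos(x[:, 0])

def alpha_true(x):
    # Closed-form Riesz representer for E[d gamma / d x_0] when X ~ N(0, I): alpha(x) = x_0
    return x[:, 0]

rng_oracle = np.random.default_rng(1)  # separate generator so the notebook's rng is untouched
n_mc = 200_000
x_mc = rng_oracle.normal(size=(n_mc, 3))
y_mc = gamma_true(x_mc) + rng_oracle.normal(size=n_mc)

alpha_mc = alpha_true(x_mc)
ra_oracle = np.mean(dgamma_dx0_true(x_mc))          # plug-in: average derivative of gamma
rw_oracle = np.mean(alpha_mc * y_mc)                # Riesz weighting: mean of alpha * Y
arw_oracle = np.mean(                               # augmented: plug-in + weighted residual
    dgamma_dx0_true(x_mc) + alpha_mc * (y_mc - gamma_true(x_mc))
)

print(f"true AME     : {np.exp(-0.5):.4f}")
print(f"oracle RA    : {ra_oracle:.4f}")
print(f"oracle RW    : {rw_oracle:.4f}")
print(f"oracle ARW   : {arw_oracle:.4f}")
print(f"mean |alpha| : {np.abs(alpha_mc).mean():.4f} (sqrt(2/pi) = {np.sqrt(2 / np.pi):.4f})")
print(f"p95 |alpha|  : {np.quantile(np.abs(alpha_mc), 0.95):.4f}")
```

With the true nuisances, all three estimators target \(e^{-1/2}\), and only Monte Carlo noise separates them. The gap to the estimates above therefore reflects estimation error in \(\hat\gamma\) and \(\hat\alpha\).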

Example 2: RKHS basis (RBF random Fourier features)

RBF random Fourier features are smooth (infinitely differentiable) functions of \(x\), so the feature derivative \(\partial \phi / \partial x_j\) needed for the AME is implemented analytically.
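The sketch below shows why this derivative has a closed form. It assumes the standard random-Fourier-feature construction for the RBF kernel: frequencies \(w_k \sim N(0, \sigma^{-2} I)\), phases \(b_k \sim U[0, 2\pi)\), and \(\phi_k(x) = \sqrt{2/D}\cos(w_k^\top x + b_k)\). Its derivative is \(\partial \phi_k / \partial x_j = -\sqrt{2/D}\, w_{kj} \sin(w_k^\top x + b_k)\).

genriesz's `RBFRandomFourierBasis` may differ in scaling and in how `standardize=True` rescales \(X\). The helpers `make_rff`, `rff_features`, and `rff_derivative` are hypothetical. The code checks the analytic derivative against a central finite difference.

```python
import numpy as np

def make_rff(d, n_features, sigma, seed=0):
    """Hypothetical helper: draw RFF frequencies W ~ N(0, sigma^-2 I) and phases b ~ U[0, 2*pi)."""
    rng_rff = np.random.default_rng(seed)
    W = rng_rff.normal(scale=1.0 / sigma, size=(d, n_features))
    b = rng_rff.uniform(0.0, 2.0 * np.pi, size=n_features)
    return W, b

def rff_features(X_in, W, b):
    """phi_k(x) = sqrt(2/D) * cos(w_k . x + b_k); returns shape (n, D)."""
    D = W.shape[1]
    return np.sqrt(2.0 / D) * np.cos(X_in @ W + b)

def rff_derivative(X_in, W, b, j):
    """Analytic d phi_k / d x_j = -sqrt(2/D) * W[j, k] * sin(w_k . x + b_k); shape (n, D)."""
    D = W.shape[1]
    return -np.sqrt(2.0 / D) * W[j] * np.sin(X_in @ W + b)

# Small demo on fresh points (does not touch the notebook's X or rng).
X_demo = np.random.default_rng(1).normal(size=(5, 3))
W, b = make_rff(d=3, n_features=500, sigma=1.0)

# Compare against a central finite difference along coordinate 0.
h = 1e-5
step = np.zeros_like(X_demo)
step[:, 0] = h
fd = (rff_features(X_demo + step, W, b) - rff_features(X_demo - step, W, b)) / (2 * h)
analytic = rff_derivative(X_demo, W, b, j=0)
max_err = np.max(np.abs(fd - analytic))
print(f"max |finite difference - analytic| = {max_err:.2e}")
```

Because each feature is a cosine, the derivative reuses the same projection \(X W + b\). This makes it as cheap as evaluating the features and avoids finite-difference step-size tuning.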

[ ]:
psi_rff = RBFRandomFourierBasis(
    n_features=500,
    sigma=1.0,
    standardize=True,
    random_state=0,
)

res_rff = grr_ame(
    X=X,
    Y=Y,
    coordinate=0,
    basis=psi_rff,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_rff.summary_text())

Note: AME requires a differentiable basis

AME estimation requires basis.derivative(X, coordinate). Piecewise-constant bases (KNNCatchmentBasis, RandomForestLeafBasis) have zero derivative almost everywhere, so their derivative carries no information about the marginal effect. TorchEmbeddingBasis without autograd has no derivative available. None of these bases implement derivative(), so they cannot be used with grr_ame. Use PolynomialBasis, RBFRandomFourierBasis, or another smooth basis instead.
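The sketch below shows why piecewise-constant bases are excluded. It uses a hypothetical one-hot bin-indicator basis in \(x_0\), standing in for leaf or nearest-neighbour cell memberships; it is not a genriesz class.

Any fitted regression \(\hat\gamma(x) = \phi(x)^\top \hat\beta\) built from such a basis is flat within each cell. Its derivative in \(x_0\) is therefore zero away from cell boundaries, and a plug-in AME would be 0 regardless of the data.

```python
import numpy as np

def bin_indicator_basis(x0, edges):
    """Hypothetical piecewise-constant basis: one-hot indicator of the bin containing x0."""
    idx = np.digitize(x0, edges)                     # bin index in 0..len(edges)
    out = np.zeros((x0.shape[0], len(edges) + 1))
    out[np.arange(x0.shape[0]), idx] = 1.0
    return out

edges = np.linspace(-2.0, 2.0, 9)                    # 9 edges -> 10 bins of width 0.5
x_mid = 0.5 * (edges[:-1] + edges[1:])               # interior bin midpoints, away from edges

# Any coefficient vector gives a piecewise-constant fit gamma_hat(x0) = phi(x0) @ beta
beta = np.random.default_rng(0).normal(size=len(edges) + 1)

def gamma_hat(x0):
    return bin_indicator_basis(x0, edges) @ beta

h = 1e-3
fd_deriv = (gamma_hat(x_mid + h) - gamma_hat(x_mid - h)) / (2 * h)
print(fd_deriv)                                      # all zeros
print("plug-in AME:", fd_deriv.mean())
```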

Generator sweep (SQ / UKL / BP)

Below we compare the SQ-Riesz, UKL-Riesz, and BP-Riesz generators under several regularization norms (L2, L1, and Lp) and strengths. For each setting we report the RA, RW, ARW, and TMLE estimates together with their errors against the known true AME.

[4]:
# A small grid over generators and regularization.
generator_grid = [
    ("SQ", SquaredGenerator(C=0.0).as_generator()),
    ("UKL (C=0)", UKLGenerator(C=0.0).as_generator()),
    ("BP (omega=0.1, C=0)", BPGenerator(C=0.0, omega=0.1).as_generator()),
    ("BP (omega=0.2, C=0)", BPGenerator(C=0.0, omega=0.2).as_generator()),
    ("BP (omega=0.5, C=0)", BPGenerator(C=0.0, omega=0.5).as_generator()),
]

penalty_grid = [
    {"penalty": "l2", "lam": 1e-4, "p_norm": None},
    {"penalty": "l2", "lam": 1e-3, "p_norm": None},
    {"penalty": "l1", "lam": 1e-4, "p_norm": None},
    {"penalty": "lp", "lam": 1e-3, "p_norm": 1.5},
]

rows = []
for gname, gen_i in generator_grid:
    for cfg in penalty_grid:
        res_i = grr_ame(
            X=X,
            Y=Y,
            coordinate=0,
            basis=basis,
            generator=gen_i,
            cross_fit=True,
            folds=3,               # fewer folds than above, to speed up the sweep
            random_state=0,
            estimators=("ra", "rw", "arw", "tmle"),
            outcome_models="shared",
            outcome_link="identity",  # Y is unbounded, so Gaussian TMLE is appropriate
            riesz_penalty=cfg["penalty"],
            riesz_lam=cfg["lam"],
            riesz_p_norm=cfg.get("p_norm"),
            max_iter=250,
            tol=1e-8,
        )

        row = {
            "generator": gname,
            "penalty": cfg["penalty"],
            "lam": cfg["lam"],
        }

        for k in ("ra", "rw", "arw", "tmle"):
            e = res_i.estimates[k]
            row[k] = e.estimate
            row[f"{k}_se"] = e.se
            row[f"{k}_err"] = e.estimate - true_ame0

        rows.append(row)

import pandas as pd

df = pd.DataFrame(rows)
# Sort by absolute ARW error (ARW is typically stable)
df = df.sort_values(by="arw_err", key=lambda s: np.abs(s))
display(df)
/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:6: UserWarning: UKLGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| > C + 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.
  ("UKL (C=0)", UKLGenerator(C=0.0).as_generator()),
/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:7: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.
  ("BP (omega=0.1, C=0)", BPGenerator(C=0.0, omega=0.1).as_generator()),
/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:8: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.
  ("BP (omega=0.2, C=0)", BPGenerator(C=0.0, omega=0.2).as_generator()),
/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:9: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.
  ("BP (omega=0.5, C=0)", BPGenerator(C=0.0, omega=0.5).as_generator()),
    generator            penalty  lam     ra       ra_se     ra_err     rw        rw_se     rw_err     arw       arw_se    arw_err    tmle      tmle_se   tmle_err
19  BP (omega=0.5, C=0)  l1       0.0010  0.56645  0.006685  -0.040081  0.583555  0.026705  -0.022976  0.582595  0.019655  -0.023935  0.569430  0.019682  -0.037101
17  BP (omega=0.5, C=0)  l2       0.0010  0.56645  0.006685  -0.040081  0.584187  0.026752  -0.022344  0.581640  0.019659  -0.024891  0.569258  0.019684  -0.037273
16  BP (omega=0.5, C=0)  l2       0.0001  0.56645  0.006685  -0.040081  0.583067  0.026758  -0.023463  0.581609  0.019660  -0.024921  0.569253  0.019686  -0.037278
18  BP (omega=0.5, C=0)  l1       0.0001  0.56645  0.006685  -0.040081  0.583012  0.026763  -0.023518  0.580149  0.019657  -0.026382  0.568983  0.019680  -0.037547
9   BP (omega=0.1, C=0)  l2       0.0010  0.56645  0.006685  -0.040081  0.578461  0.027031  -0.028070  0.578600  0.019622  -0.027931  0.568672  0.019629  -0.037858
5   UKL (C=0)            l2       0.0010  0.56645  0.006685  -0.040081  0.578919  0.027191  -0.027611  0.578271  0.019641  -0.028260  0.568613  0.019639  -0.037917
6   UKL (C=0)            l1       0.0001  0.56645  0.006685  -0.040081  0.578065  0.027179  -0.028466  0.578176  0.019640  -0.028355  0.568595  0.019637  -0.037935
10  BP (omega=0.1, C=0)  l1       0.0001  0.56645  0.006685  -0.040081  0.578894  0.027017  -0.027637  0.577886  0.019615  -0.028645  0.568540  0.019622  -0.037991
8   BP (omega=0.1, C=0)  l2       0.0001  0.56645  0.006685  -0.040081  0.577012  0.027054  -0.029518  0.577699  0.019618  -0.028832  0.568507  0.019625  -0.038024
13  BP (omega=0.2, C=0)  l2       0.0010  0.56645  0.006685  -0.040081  0.578107  0.026959  -0.028423  0.577668  0.019615  -0.028863  0.568504  0.019627  -0.038026
12  BP (omega=0.2, C=0)  l2       0.0001  0.56645  0.006685  -0.040081  0.576705  0.026951  -0.029826  0.577231  0.019613  -0.029299  0.568424  0.019624  -0.038107
4   UKL (C=0)            l2       0.0001  0.56645  0.006685  -0.040081  0.575639  0.027207  -0.030892  0.576981  0.019636  -0.029550  0.568375  0.019634  -0.038156
14  BP (omega=0.2, C=0)  l1       0.0001  0.56645  0.006685  -0.040081  0.576432  0.026943  -0.030099  0.576170  0.019608  -0.030361  0.568229  0.019618  -0.038301
7   UKL (C=0)            l1       0.0010  0.56645  0.006685  -0.040081  0.574841  0.027141  -0.031690  0.575717  0.019627  -0.030813  0.568142  0.019625  -0.038389
11  BP (omega=0.1, C=0)  l1       0.0010  0.56645  0.006685  -0.040081  0.575375  0.026993  -0.031156  0.575641  0.019606  -0.030890  0.568128  0.019610  -0.038403
15  BP (omega=0.2, C=0)  l1       0.0010  0.56645  0.006685  -0.040081  0.572222  0.026900  -0.034309  0.573946  0.019594  -0.032585  0.567820  0.019601  -0.038710
1   SQ                   l2       0.0010  0.56645  0.006685  -0.040081  0.575361  0.026199  -0.031170  0.569663  0.017072  -0.036868  0.569537  0.017102  -0.036993
0   SQ                   l2       0.0001  0.56645  0.006685  -0.040081  0.576791  0.026184  -0.029739  0.569646  0.017092  -0.036885  0.569517  0.017121  -0.037014
2   SQ                   l1       0.0001  0.56645  0.006685  -0.040081  0.576779  0.026169  -0.029752  0.569638  0.017091  -0.036893  0.569510  0.017121  -0.037021
3   SQ                   l1       0.0010  0.56645  0.006685  -0.040081  0.575307  0.026053  -0.031223  0.569633  0.017068  -0.036898  0.569511  0.017097  -0.037019