AME end-to-end example (genriesz)
This notebook demonstrates how to estimate an Average Marginal Effect (AME), i.e., an average derivative of the outcome regression function.
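We simulate
\[ Y = \sin(X_0) + 0.5\,X_1^2 + \varepsilon, \qquad X \sim N(0, I_3), \quad \varepsilon \sim N(0, 1), \]
so the true AME for coordinate 0 is
\[ \mathbb{E}\left[\frac{\partial}{\partial x_0}\,\mathbb{E}[Y \mid X]\right] = \mathbb{E}[\cos(X_0)]. \]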
If \(X_0 \sim N(0,1)\), then \(\mathbb{E}[\cos(X_0)] = \exp(-1/2)\).
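The closed form follows from the characteristic function of the standard normal, \(\mathbb{E}[e^{itX_0}] = e^{-t^2/2}\), whose real part at \(t = 1\) is \(\mathbb{E}[\cos(X_0)]\). As a quick sanity check, the sketch below compares a Monte Carlo average with the closed form. The sample size and seed are arbitrary choices for illustration and are not part of the notebook.

```python
import numpy as np

# Monte Carlo check of E[cos(X_0)] = exp(-1/2) for X_0 ~ N(0, 1).
# The sample size and seed are illustrative assumptions.
rng = np.random.default_rng(0)
x0 = rng.normal(size=1_000_000)

mc_estimate = float(np.mean(np.cos(x0)))
closed_form = float(np.exp(-0.5))

print("Monte Carlo :", mc_estimate)
print("Closed form :", closed_form)
```

With one million draws the Monte Carlo standard error is below 0.001, so the two numbers should agree to roughly three decimal places.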
[1]:
import numpy as np
from genriesz import (
    grr_ame,
    SquaredGenerator,
    UKLGenerator,
    BPGenerator,
    PolynomialBasis,
    RBFRandomFourierBasis,
)
rng = np.random.default_rng(0)
Synthetic data with known true AME
[2]:
n = 4000
d = 3
X = rng.normal(size=(n, d))
eps = rng.normal(scale=1.0, size=n)
Y = np.sin(X[:, 0]) + 0.5 * (X[:, 1] ** 2) + eps
true_ame0 = float(np.exp(-0.5)) # E[cos(N(0,1))]
print("Approx. true AME for coordinate 0:", true_ame0)
Approx. true AME for coordinate 0: 0.6065306597126334
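To make the "average derivative" interpretation concrete, the following sketch differentiates the noise-free regression function numerically and averages the result over a simulated sample. The helper `m`, the step size `h`, and the fresh random draw are assumptions introduced here for illustration. The sketch does not call genriesz.

```python
import numpy as np

def m(X):
    # Noise-free regression function E[Y | X] of the simulation above.
    return np.sin(X[:, 0]) + 0.5 * X[:, 1] ** 2

rng = np.random.default_rng(0)
X = rng.normal(size=(4000, 3))

# Central finite difference in coordinate 0 (h is an illustrative choice).
h = 1e-5
step = np.zeros(3)
step[0] = h
fd_ame0 = float(np.mean((m(X + step) - m(X - step)) / (2 * h)))

# Analytic counterpart: d/dx_0 m(x) = cos(x_0), averaged over the sample.
sample_ame0 = float(np.mean(np.cos(X[:, 0])))

print("Finite-difference AME:", fd_ame0)
print("Mean of cos(X_0)     :", sample_ame0)
```

The two averages should agree up to finite-difference error. Both differ from the population value \(\exp(-1/2)\) by sampling noise, since they are averages over a finite sample.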
Example 1: Polynomial basis
[3]:
# A simple polynomial basis on X
basis = PolynomialBasis(degree=3, include_bias=True)
gen = SquaredGenerator(C=0.0).as_generator()
res = grr_ame(
    X=X,
    Y=Y,
    coordinate=0,
    basis=basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res.summary_text())
AME(coord=0) estimates (n=4000)
alpha=0.05 | null=0.0
diagnostics: alpha_abs_mean=0.7986024158798944, alpha_abs_p95=1.9652359398234305, alpha_abs_max=4.506401055576029
Estimator Estimate SE CI p-value
---------------------------------------------------------------------------------
RA 0.568023 0.00659669 [ 0.555094, 0.580952] 0
RW 0.569997 0.0255014 [ 0.520015, 0.619979] 0
ARW 0.568096 0.0169592 [ 0.534856, 0.601335] 0
TMLE 0.568094 0.0169599 [ 0.534853, 0.601334] 0
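The printed intervals can be compared with the known true AME. The sketch below hard-codes the estimates and standard errors from the summary above and assumes the reported CI is a Wald interval, estimate ± z·SE at alpha = 0.05. That assumption reproduces the printed bounds to the displayed precision, but it is not confirmed by the library documentation here.

```python
from math import exp
from statistics import NormalDist

true_ame0 = exp(-0.5)
z = NormalDist().inv_cdf(0.975)  # two-sided 95% normal quantile

# (estimate, SE) pairs copied from the summary printed above.
reported = {
    "RA": (0.568023, 0.00659669),
    "RW": (0.569997, 0.0255014),
    "ARW": (0.568096, 0.0169592),
    "TMLE": (0.568094, 0.0169599),
}

intervals = {}
covers = {}
for name, (est, se) in reported.items():
    lo, hi = est - z * se, est + z * se  # assumed Wald interval
    intervals[name] = (lo, hi)
    covers[name] = lo <= true_ame0 <= hi
    print(f"{name:5s} [{lo:.6f}, {hi:.6f}] covers true AME: {covers[name]}")
```

In this single simulated run, only the RW interval contains the true value. One run says nothing about coverage rates, which would need repeated simulations to assess.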
Example 2: RKHS basis (RBF random Fourier features)
RBF random Fourier features are smooth and differentiable, so the AME derivative \(\partial \phi / \partial x_j\) is implemented analytically.
[ ]:
psi_rff = RBFRandomFourierBasis(
    n_features=500,
    sigma=1.0,
    standardize=True,
    random_state=0,
)
res_rff = grr_ame(
    X=X,
    Y=Y,
    coordinate=0,
    basis=psi_rff,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_rff.summary_text())
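To illustrate what an analytic derivative of random Fourier features looks like, here is a minimal generic sketch. It uses the textbook construction \(\phi_k(x) = \sqrt{2/D}\,\cos(w_k^\top x + b_k)\) with \(w_k \sim N(0, \sigma^{-2} I)\) and \(b_k \sim \mathrm{Unif}[0, 2\pi]\). This is an assumption made for illustration, not the internals of RBFRandomFourierBasis, which may differ (for example through standardization). The functions `phi` and `dphi` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_features, sigma = 3, 50, 1.0

# Textbook RFF parameters for an RBF kernel with bandwidth sigma (assumed).
W = rng.normal(scale=1.0 / sigma, size=(n_features, d))
b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)
scale = np.sqrt(2.0 / n_features)

def phi(X):
    # Feature map: (n, d) -> (n, n_features).
    return scale * np.cos(X @ W.T + b)

def dphi(X, coordinate):
    # Analytic partial derivative of every feature w.r.t. x_coordinate.
    return -scale * np.sin(X @ W.T + b) * W[:, coordinate]

# Compare against a central finite difference in coordinate 0.
X = rng.normal(size=(5, d))
h = 1e-6
step = np.zeros(d)
step[0] = h
fd = (phi(X + step) - phi(X - step)) / (2 * h)
max_abs_diff = float(np.max(np.abs(fd - dphi(X, 0))))
print("max |finite difference - analytic|:", max_abs_diff)
```

The analytic and numerical derivatives should agree up to finite-difference error, which is why a smooth basis of this kind can supply the derivative the AME functional needs.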
Note: bases requiring smooth derivatives
AME requires basis.derivative(X, coordinate). Bases that do not implement this method, namely the piecewise-constant KNNCatchmentBasis and RandomForestLeafBasis, and TorchEmbeddingBasis without autograd, cannot be used with grr_ame. Use PolynomialBasis or RBFRandomFourierBasis (or any other smooth basis) instead.
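The sketch below shows why a piecewise-constant basis carries no derivative information. It uses a hypothetical one-hot bin-indicator basis (`bin_indicator_basis`, not a genriesz class) and confirms that its finite-difference derivative is zero at essentially every sample point.

```python
import numpy as np

def bin_indicator_basis(X, edges):
    # Hypothetical piecewise-constant basis: one-hot indicator of the bin
    # that X[:, 0] falls into. np.digitize returns indices 0..len(edges).
    idx = np.digitize(X[:, 0], edges)
    return np.eye(len(edges) + 1)[idx]

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
edges = np.linspace(-2.0, 2.0, 9)

# Central finite difference in coordinate 0.
h = 1e-6
step = np.zeros(3)
step[0] = h
fd = (bin_indicator_basis(X + step, edges)
      - bin_indicator_basis(X - step, edges)) / (2 * h)

# Share of sample points where every basis function has zero derivative.
share_zero = float(np.mean(np.all(fd == 0.0, axis=1)))
print("share of points with zero derivative:", share_zero)
```

The derivative vanishes everywhere except exactly at bin boundaries, so an average derivative built from such a basis would carry no signal about the slope of the regression function.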
Generator sweep (SQ / UKL / BP)
Below we compare SQ-Riesz, UKL-Riesz, and BP-Riesz under multiple regularization norms and strengths. We report RA / RW / ARW / TMLE and the error against the known true AME.
[4]:
# A small grid over generators and regularization.
generator_grid = [
    ("SQ", SquaredGenerator(C=0.0).as_generator()),
    ("UKL (C=0)", UKLGenerator(C=0.0).as_generator()),
    ("BP (omega=0.1, C=0)", BPGenerator(C=0.0, omega=0.1).as_generator()),
    ("BP (omega=0.2, C=0)", BPGenerator(C=0.0, omega=0.2).as_generator()),
    ("BP (omega=0.5, C=0)", BPGenerator(C=0.0, omega=0.5).as_generator()),
]
penalty_grid = [
    {"penalty": "l2", "lam": 1e-4, "p_norm": None},
    {"penalty": "l2", "lam": 1e-3, "p_norm": None},
    {"penalty": "l1", "lam": 1e-4, "p_norm": None},
    {"penalty": "lp", "lam": 1e-3, "p_norm": 1.5},
]
rows = []
for gname, gen_i in generator_grid:
    for cfg in penalty_grid:
        res_i = grr_ame(
            X=X,
            Y=Y,
            coordinate=0,
            basis=basis,
            generator=gen_i,
            cross_fit=True,
            folds=3,  # smaller folds for the sweep
            random_state=0,
            estimators=("ra", "rw", "arw", "tmle"),
            outcome_models="shared",
            outcome_link="identity",  # Y is unbounded, so Gaussian TMLE is appropriate
            riesz_penalty=cfg["penalty"],
            riesz_lam=cfg["lam"],
            riesz_p_norm=cfg.get("p_norm"),
            max_iter=250,
            tol=1e-8,
        )
        row = {
            "generator": gname,
            "penalty": cfg["penalty"],
            "lam": cfg["lam"],
        }
        for k in ("ra", "rw", "arw", "tmle"):
            e = res_i.estimates[k]
            row[k] = e.estimate
            row[f"{k}_se"] = e.se
            row[f"{k}_err"] = e.estimate - true_ame0
        rows.append(row)
import pandas as pd
df = pd.DataFrame(rows)
# Sort by absolute ARW error (ARW is typically stable)
df = df.sort_values(by="arw_err", key=lambda s: np.abs(s))
display(df)
/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:6: UserWarning: UKLGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| > C + 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.
("UKL (C=0)", UKLGenerator(C=0.0).as_generator()),
/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:7: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.
("BP (omega=0.1, C=0)", BPGenerator(C=0.0, omega=0.1).as_generator()),
/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:8: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.
("BP (omega=0.2, C=0)", BPGenerator(C=0.0, omega=0.2).as_generator()),
/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:9: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.
("BP (omega=0.5, C=0)", BPGenerator(C=0.0, omega=0.5).as_generator()),
| | generator | penalty | lam | ra | ra_se | ra_err | rw | rw_se | rw_err | arw | arw_se | arw_err | tmle | tmle_se | tmle_err |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 19 | BP (omega=0.5, C=0) | l1 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.583555 | 0.026705 | -0.022976 | 0.582595 | 0.019655 | -0.023935 | 0.569430 | 0.019682 | -0.037101 |
| 17 | BP (omega=0.5, C=0) | l2 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.584187 | 0.026752 | -0.022344 | 0.581640 | 0.019659 | -0.024891 | 0.569258 | 0.019684 | -0.037273 |
| 16 | BP (omega=0.5, C=0) | l2 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.583067 | 0.026758 | -0.023463 | 0.581609 | 0.019660 | -0.024921 | 0.569253 | 0.019686 | -0.037278 |
| 18 | BP (omega=0.5, C=0) | l1 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.583012 | 0.026763 | -0.023518 | 0.580149 | 0.019657 | -0.026382 | 0.568983 | 0.019680 | -0.037547 |
| 9 | BP (omega=0.1, C=0) | l2 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.578461 | 0.027031 | -0.028070 | 0.578600 | 0.019622 | -0.027931 | 0.568672 | 0.019629 | -0.037858 |
| 5 | UKL (C=0) | l2 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.578919 | 0.027191 | -0.027611 | 0.578271 | 0.019641 | -0.028260 | 0.568613 | 0.019639 | -0.037917 |
| 6 | UKL (C=0) | l1 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.578065 | 0.027179 | -0.028466 | 0.578176 | 0.019640 | -0.028355 | 0.568595 | 0.019637 | -0.037935 |
| 10 | BP (omega=0.1, C=0) | l1 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.578894 | 0.027017 | -0.027637 | 0.577886 | 0.019615 | -0.028645 | 0.568540 | 0.019622 | -0.037991 |
| 8 | BP (omega=0.1, C=0) | l2 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.577012 | 0.027054 | -0.029518 | 0.577699 | 0.019618 | -0.028832 | 0.568507 | 0.019625 | -0.038024 |
| 13 | BP (omega=0.2, C=0) | l2 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.578107 | 0.026959 | -0.028423 | 0.577668 | 0.019615 | -0.028863 | 0.568504 | 0.019627 | -0.038026 |
| 12 | BP (omega=0.2, C=0) | l2 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.576705 | 0.026951 | -0.029826 | 0.577231 | 0.019613 | -0.029299 | 0.568424 | 0.019624 | -0.038107 |
| 4 | UKL (C=0) | l2 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.575639 | 0.027207 | -0.030892 | 0.576981 | 0.019636 | -0.029550 | 0.568375 | 0.019634 | -0.038156 |
| 14 | BP (omega=0.2, C=0) | l1 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.576432 | 0.026943 | -0.030099 | 0.576170 | 0.019608 | -0.030361 | 0.568229 | 0.019618 | -0.038301 |
| 7 | UKL (C=0) | l1 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.574841 | 0.027141 | -0.031690 | 0.575717 | 0.019627 | -0.030813 | 0.568142 | 0.019625 | -0.038389 |
| 11 | BP (omega=0.1, C=0) | l1 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.575375 | 0.026993 | -0.031156 | 0.575641 | 0.019606 | -0.030890 | 0.568128 | 0.019610 | -0.038403 |
| 15 | BP (omega=0.2, C=0) | l1 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.572222 | 0.026900 | -0.034309 | 0.573946 | 0.019594 | -0.032585 | 0.567820 | 0.019601 | -0.038710 |
| 1 | SQ | l2 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.575361 | 0.026199 | -0.031170 | 0.569663 | 0.017072 | -0.036868 | 0.569537 | 0.017102 | -0.036993 |
| 0 | SQ | l2 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.576791 | 0.026184 | -0.029739 | 0.569646 | 0.017092 | -0.036885 | 0.569517 | 0.017121 | -0.037014 |
| 2 | SQ | l1 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.576779 | 0.026169 | -0.029752 | 0.569638 | 0.017091 | -0.036893 | 0.569510 | 0.017121 | -0.037021 |
| 3 | SQ | l1 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.575307 | 0.026053 | -0.031223 | 0.569633 | 0.017068 | -0.036898 | 0.569511 | 0.017097 | -0.037019 |