ATE end-to-end examples (genriesz)
This notebook demonstrates how to estimate the Average Treatment Effect (ATE) with genriesz.
We assume the regressor has the form \(X = (D, Z)\), where \(D \in \{0, 1\}\) is a binary treatment indicator and \(Z\) is a vector of covariates; \(Y\) is the observed outcome.
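For reference, these are the standard identities behind the estimators below. They assume overlap, \(0 < e(Z) < 1\), where \(e(Z) = P(D = 1 \mid Z)\) is the propensity score. The ATE is a linear functional of the outcome regression \(g_0(D, Z) = \mathbb{E}[Y \mid D, Z]\), and its Riesz representer \(\alpha_0\) turns that functional into a weighted mean:

\[
\theta_0 = \mathbb{E}\big[g_0(1, Z) - g_0(0, Z)\big] = \mathbb{E}\big[\alpha_0(X)\, g_0(X)\big] = \mathbb{E}\big[\alpha_0(X)\, Y\big],
\qquad
\alpha_0(X) = \frac{D}{e(Z)} - \frac{1 - D}{1 - e(Z)}.
\]

RA plugs in an estimate of \(g_0\), RW plugs in an estimate of \(\alpha_0\), and ARW combines both through \(\mathbb{E}\big[g(1, Z) - g(0, Z) + \alpha(X)\,(Y - g(X))\big]\).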
We will compute (optionally with cross-fitting):
RA: regression adjustment (plug-in)
RW: Riesz weighting (weighting only)
ARW: augmented Riesz weighting
TMLE: targeted maximum likelihood estimation (one-step fluctuation)
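To make the four estimators concrete, here is a sketch of their textbook forms computed with *oracle* nuisances (the true \(g_0\) and \(\alpha_0\)) on a simplified version of the data-generating process used below. This is an illustration under stated assumptions, not genriesz internals: genriesz learns \(g\) and \(\alpha\) from data (optionally with cross-fitting), and its TMLE step may use a different fluctuation or link. The helper names `g` and `alpha` are hypothetical.

```python
import numpy as np

# Simplified copy of the notebook's DGP. The notebook draws separate noise for
# Y0 and Y1; a single noise draw gives the same distribution for the observed Y.
rng = np.random.default_rng(1)
n, tau = 20_000, 1.0
Z = rng.normal(size=(n, 5))
e = 1.0 / (1.0 + np.exp(-(0.7 * Z[:, 0] - 0.3 * Z[:, 1])))  # propensity e(Z)
D = rng.binomial(1, e).astype(float)
mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2
Y = mu0 + tau * D + rng.normal(size=n)


def g(d):
    """Oracle outcome regression g(d, Z) = E[Y | D=d, Z] (hypothetical helper)."""
    return mu0 + tau * d


def alpha(d):
    """Oracle ATE Riesz representer alpha(d, Z) = d/e(Z) - (1-d)/(1-e(Z))."""
    return d / e - (1.0 - d) / (1.0 - e)


a_obs = alpha(D)   # representer evaluated at the observed treatment
resid = Y - g(D)   # outcome residual

ra = np.mean(g(1.0) - g(0.0))        # RA: plug-in of the outcome model
rw = np.mean(a_obs * Y)              # RW: Riesz-weighted mean of Y
arw = ra + np.mean(a_obs * resid)    # ARW: RA plus a weighted residual correction

# TMLE-style one-step fluctuation (identity link): regress the residual on alpha,
# shift g along alpha by eps, then recompute the plug-in.
eps = np.sum(a_obs * resid) / np.sum(a_obs ** 2)
tmle = np.mean((g(1.0) + eps * alpha(1.0)) - (g(0.0) + eps * alpha(0.0)))

print(f"RA={ra:.3f}  RW={rw:.3f}  ARW={arw:.3f}  TMLE={tmle:.3f}")
```

With oracle nuisances RA recovers \(\tau\) by construction; RW, ARW, and TMLE differ from \(\tau\) only by sampling noise.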
We also show how to swap the basis:
polynomial features,
RKHS basis via RBF random Fourier features,
nearest-neighbor matching (KNN catchment basis),
random forest leaf features (optional),
neural network embeddings (optional).
[1]:
import numpy as np
from genriesz import (
grr_ate,
SquaredGenerator,
UKLGenerator,
PolynomialBasis,
TreatmentInteractionBasis,
RBFRandomFourierBasis,
KNNCatchmentBasis,
)
rng = np.random.default_rng(0)
Synthetic data
[2]:
# Data-generating process
n = 3000
d_z = 5
Z = rng.normal(size=(n, d_z))
# Treatment assignment: e(Z) = sigmoid(a'Z)
logits = 0.7 * Z[:, 0] - 0.3 * Z[:, 1]
e = 1.0 / (1.0 + np.exp(-logits))
D = rng.binomial(1, e, size=n).astype(float)
# Potential outcomes (constant effect for simplicity)
tau = 1.0
mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2
Y0 = mu0 + rng.normal(scale=1.0, size=n)
Y1 = mu0 + tau + rng.normal(scale=1.0, size=n)
Y = D * Y1 + (1.0 - D) * Y0
# Regressor matrix X = [D, Z...]
X = np.column_stack([D, Z])
print("X shape:", X.shape, "Y shape:", Y.shape)
X shape: (3000, 6) Y shape: (3000,)
Example 1: Polynomial basis + treatment interactions
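The interaction \(\phi(X) = [D\,\psi(Z), (1-D)\,\psi(Z)]\) (the construction described in the KNN cell's comments) combined with squared-loss Riesz regression can be sketched in plain NumPy. This is a minimal sketch under assumptions, not genriesz's implementation: the hypothetical `poly2` feature map, the penalty scaling, and the closed-form solve are all illustrative. Minimizing \(\frac{1}{n}\sum_i [\alpha(X_i)^2 - 2(\alpha(1, Z_i) - \alpha(0, Z_i))] + \lambda\|\beta\|^2\) over \(\alpha = \phi^\top\beta\) gives \(\beta = (G + \lambda I)^{-1} M\) with \(G = \Phi^\top\Phi/n\) and \(M\) the mean of \(\phi(1, Z) - \phi(0, Z)\).

```python
import numpy as np
from itertools import combinations_with_replacement


def poly2(Z):
    """Hypothetical degree-2 polynomial features of Z, with a bias column."""
    p = Z.shape[1]
    cols = [np.ones(len(Z))] + [Z[:, j] for j in range(p)]
    cols += [Z[:, i] * Z[:, j] for i, j in combinations_with_replacement(range(p), 2)]
    return np.column_stack(cols)


def interact(D, psi):
    """phi(D, Z) = [D * psi(Z), (1 - D) * psi(Z)]."""
    return np.hstack([D[:, None] * psi, (1.0 - D)[:, None] * psi])


def sq_riesz_fit(D, Z, lam=1e-3):
    """Closed-form squared-loss Riesz regression for the ATE with an L2 penalty."""
    n = len(D)
    psi = poly2(Z)
    phi = interact(D, psi)
    phi1 = interact(np.ones(n), psi)   # features with D set to 1
    phi0 = interact(np.zeros(n), psi)  # features with D set to 0
    G = phi.T @ phi / n
    M = (phi1 - phi0).mean(axis=0)
    beta = np.linalg.solve(G + lam * np.eye(G.shape[0]), M)
    return phi @ beta, psi


rng = np.random.default_rng(0)
Z = rng.normal(size=(2000, 5))
D = rng.binomial(1, 1 / (1 + np.exp(-0.7 * Z[:, 0]))).astype(float)
alpha_hat, psi = sq_riesz_fit(D, Z, lam=0.0)

# With lam=0 the first-order conditions are exact balance: the alpha-weighted
# treated mean of psi(Z) reproduces the full-sample mean of psi(Z).
lhs = (D * alpha_hat) @ psi / len(D)
print(np.max(np.abs(lhs - psi.mean(axis=0))))
```

The balance property follows from the first-order condition \(G\beta = M\); with \(\lambda > 0\) it holds only approximately.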
[3]:
# Basis on Z, then interact with D: phi(X) = [D * psi(Z), (1 - D) * psi(Z)]
psi = PolynomialBasis(degree=2, include_bias=True)
phi = TreatmentInteractionBasis(base_basis=psi)
# Generator: Squared loss (always safe / no domain constraints)
gen = SquaredGenerator(C=0.0).as_generator()
res_poly = grr_ate(
    X=X,
    Y=Y,
    basis=phi,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_poly.summary_text())
ATE estimates (n=3000)
alpha=0.05 | null=0.0
diagnostics: alpha_abs_mean=2.0164751561831986, alpha_abs_p95=3.844795773570575, alpha_abs_max=9.334972319665178, max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.01662007897774767, ess_treated=1188.4096325793182, ess_control=1260.0346792242883
Estimator Estimate SE CI p-value
---------------------------------------------------------------------------------
RA 1.03046 0.00359346 [ 1.02342, 1.03751] 0
RW 1.03302 0.0569279 [ 0.921445, 1.1446] 0
ARW 1.02916 0.0419132 [ 0.947017, 1.11131] 0
TMLE 1.02921 0.0419128 [ 0.947063, 1.11136] 0
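The diagnostics line reports weight magnitudes, covariate balance (max |SMD|), and effective sample sizes per arm. The exact definitions genriesz uses are not shown here, so the sketch below assumes common textbook versions: SMD as the difference in (weighted) arm means divided by the pooled unweighted standard deviation, and Kish's effective sample size \((\sum w)^2 / \sum w^2\). The weights are \(|\alpha_0|\) computed from the true propensity.

```python
import numpy as np


def smd(x, D, w=None):
    """Standardized mean difference of covariate x between treated and control.

    Weights (if given) are applied within each arm; the denominator is the pooled
    unweighted SD, one common convention assumed here.
    """
    w = np.ones_like(x) if w is None else w
    t, c = D == 1, D == 0
    m_t = np.average(x[t], weights=w[t])
    m_c = np.average(x[c], weights=w[c])
    pooled_sd = np.sqrt(0.5 * (x[t].var(ddof=1) + x[c].var(ddof=1)))
    return (m_t - m_c) / pooled_sd


def kish_ess(w):
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    return w.sum() ** 2 / np.sum(w ** 2)


rng = np.random.default_rng(0)
Z = rng.normal(size=(5000, 5))
e = 1 / (1 + np.exp(-(0.7 * Z[:, 0] - 0.3 * Z[:, 1])))
D = rng.binomial(1, e).astype(float)
w = np.where(D == 1, 1 / e, 1 / (1 - e))  # |alpha_0| with the true propensity

before = max(abs(smd(Z[:, j], D)) for j in range(5))
after = max(abs(smd(Z[:, j], D, w)) for j in range(5))
print(f"max|SMD| unweighted={before:.3f}, weighted={after:.3f}")
print(f"ESS treated={kish_ess(w[D == 1]):.0f} of {int(D.sum())}")
```

Kish's ESS never exceeds the arm size, and it shrinks as the weights become more uneven.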
Example 2: RKHS basis (RBF random Fourier features)
This approximates an RBF kernel feature map using random Fourier features, then interacts the features with treatment.
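The standard random Fourier feature construction maps \(x\) to \(z(x) = \sqrt{2/m}\,\cos(W^\top x + b)\), with the entries of \(W\) drawn from \(N(0, \sigma^{-2})\) and \(b\) uniform on \([0, 2\pi]\), so that \(z(x)^\top z(y) \approx \exp(-\|x - y\|^2 / (2\sigma^2))\). The sketch below checks that approximation numerically. How `RBFRandomFourierBasis` parameterizes `sigma` and what `standardize=True` does are not verified here; the helper `rff` is hypothetical.

```python
import numpy as np


def rff(X, n_features, sigma, seed=0):
    """Random Fourier features whose inner products approximate an RBF kernel."""
    rng = np.random.default_rng(seed)
    W = rng.normal(scale=1.0 / sigma, size=(X.shape[1], n_features))
    b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)
    return np.sqrt(2.0 / n_features) * np.cos(X @ W + b)


rng = np.random.default_rng(1)
x, y = rng.normal(size=(2, 5))
sigma = 2.0
exact = np.exp(-np.sum((x - y) ** 2) / (2 * sigma ** 2))
feats = rff(np.vstack([x, y]), n_features=20_000, sigma=sigma)
approx = feats[0] @ feats[1]
print(f"exact={exact:.4f} approx={approx:.4f}")
```

The approximation error shrinks roughly like \(1/\sqrt{m}\), so `n_features` trades accuracy against the dimension of the Riesz model.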
[4]:
psi_rff = RBFRandomFourierBasis(
    n_features=500,
    sigma=1.0,
    standardize=True,
    random_state=0,
)
phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)
res_rff = grr_ate(
    X=X,
    Y=Y,
    basis=phi_rff,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_rff.summary_text())
ATE estimates (n=3000)
alpha=0.05 | null=0.0
diagnostics: alpha_abs_mean=2.077446519847265, alpha_abs_p95=3.353495372876411, alpha_abs_max=4.936266573928845, max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.27862992844919376, ess_treated=1363.470230442599, ess_control=1370.9575124421833
Estimator Estimate SE CI p-value
---------------------------------------------------------------------------------
RA 1.15729 0.00717955 [ 1.14321, 1.17136] 0
RW 1.22541 0.0540494 [ 1.11947, 1.33134] 0
ARW 1.12627 0.0423592 [ 1.04325, 1.20929] 0
TMLE 1.12732 0.0423287 [ 1.04436, 1.21028] 0
Example 3: KNN catchment basis (nearest-neighbor matching)
Nearest-neighbor matching can be expressed with a catchment-area basis that assigns each point to its \(k\) nearest training centers. Matching is then a special case of squared-loss Riesz regression, so the basis can be passed to grr_ate like any other basis.
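A minimal NumPy sketch of the catchment features \(\phi_j(z) = 1\{j \in \mathrm{NN}_k(z)\}\) from the code comments below: each query point gets a 1 in the columns of its \(k\) nearest centers. The brute-force Euclidean distances, dense storage, and the helper name `knn_catchment` are assumptions; `KNNCatchmentBasis` may differ in metric, tie-breaking, and storage.

```python
import numpy as np


def knn_catchment(Z_query, Z_centers, k):
    """Dense 0/1 matrix: entry (i, j) is 1 if center j is among the k nearest
    (Euclidean) centers to query point i."""
    d2 = ((Z_query[:, None, :] - Z_centers[None, :, :]) ** 2).sum(axis=-1)
    nn = np.argpartition(d2, k - 1, axis=1)[:, :k]  # indices of the k nearest centers
    Phi = np.zeros((len(Z_query), len(Z_centers)))
    np.put_along_axis(Phi, nn, 1.0, axis=1)
    return Phi


rng = np.random.default_rng(0)
centers = rng.normal(size=(200, 5))  # e.g. the training fold under cross-fitting
queries = rng.normal(size=(50, 5))   # e.g. the held-out fold
Phi = knn_catchment(queries, centers, k=5)
print(Phi.shape, Phi.sum(axis=1)[:3])  # every row has exactly k ones

# Column sums count how often each center is used as a match (its "catchment").
usage = Phi.sum(axis=0)
```

Because the feature dimension equals the number of centers, this basis grows with the training fold size, which is why the cell below relies on a dual solve.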
[ ]:
# KNN catchment basis as a nearest-neighbor Riesz basis.
#
# phi_j(z) = 1{j in NN_k(z)} assigns each point to its k nearest training centers.
# TreatmentInteractionBasis then creates [D*psi(Z), (1-D)*psi(Z)],
# which gives the standard NN-matching Riesz representer as a linear model.
#
# With cross-fitting the training fold becomes the centers, so the feature
# dimension p = n_train >> n_test. The dual (Woodbury) solve handles this
# in O(n_test^3) instead of O(n_train^3).
basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)
phi_knn = TreatmentInteractionBasis(base_basis=basis_knn)
res_knn = grr_ate(
    X=X,
    Y=Y,
    basis=phi_knn,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_knn.summary_text())
Example 4: Random forest leaf basis (optional)
If you have scikit-learn installed, you can use a random forest as a feature map via leaf indicators. This keeps GRR convex (linear in parameters) while giving a flexible nonparametric basis.
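The leaf-indicator idea can be sketched directly with scikit-learn: `RandomForestRegressor.apply` returns each sample's leaf node id in every tree, and one-hot encoding those ids per tree gives a 0/1 matrix with exactly one active leaf per tree. This is an illustration, not `RandomForestLeafBasis`; its column layout and bias handling here are assumptions, and `leaf_features` is a hypothetical helper. Because the forest is frozen before the Riesz step, the model stays linear in its coefficients. Note that the cell below fits the forest once on the full sample (with \(D\) included in \(X\)) before grr_ate performs cross-fitting.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 6))
Y = X[:, 1] + 0.5 * X[:, 2] ** 2 + rng.normal(size=500)
rf = RandomForestRegressor(n_estimators=20, max_depth=4, random_state=0).fit(X, Y)


def leaf_features(forest, X, include_bias=True):
    """0/1 leaf indicators: one block per tree, one column per leaf of that tree."""
    leaves = forest.apply(X)  # (n_samples, n_estimators) leaf node ids
    blocks = [np.ones((len(X), 1))] if include_bias else []
    for t, tree in enumerate(forest.estimators_):
        # Leaf nodes are the nodes without children (children_left == -1).
        leaf_ids = np.flatnonzero(tree.tree_.children_left == -1)
        blocks.append((leaves[:, [t]] == leaf_ids[None, :]).astype(float))
    return np.hstack(blocks)


Phi = leaf_features(rf, X)
print(Phi.shape)
```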
[6]:
from sklearn.ensemble import RandomForestRegressor
from genriesz.sklearn_basis import RandomForestLeafBasis
rf = RandomForestRegressor(
    n_estimators=200,
    max_depth=6,
    random_state=0,
)
leaf_basis = RandomForestLeafBasis(rf, include_bias=True).fit(X, Y)
res_rf = grr_ate(
    X=X,
    Y=Y,
    basis=leaf_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_rf.summary_text())
ATE estimates (n=3000)
alpha=0.05 | null=0.0
diagnostics: max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.786895849078183, ess_treated=350.8795988622065, ess_control=507.56068177539834
Estimator Estimate SE CI p-value
---------------------------------------------------------------------------------
RA 1.03187 0.0199079 [ 0.992854, 1.07089] 0
RW 1.69371 0.252221 [ 1.19937, 2.18806] 1.88e-11
ARW 1.25382 0.215261 [ 0.831919, 1.67573] 5.72e-09
TMLE 1.063 0.217126 [ 0.63744, 1.48856] 9.79e-07
Example 5: Neural network embedding basis (optional)
If you have PyTorch installed, you can use a neural network as a basis function. A recommended procedure is:
train an embedding network on a separate task,
freeze the trained network so that it defines a fixed basis function,
use its outputs as features in GRR.
Below we show the mechanics with a small MLP (training is optional).
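The point of freezing is that, once the network's weights are fixed, GRR only fits a linear model on top of its outputs. The NumPy sketch below mimics that under simplifying assumptions: a hypothetical two-layer ReLU network `FrozenMLPEmbedding` with fixed (here random, i.e. untrained) weights, whose outputs plus a bias column become the basis. It is not `TorchEmbeddingBasis`; in practice the weights would come from training on a separate task.

```python
import numpy as np


class FrozenMLPEmbedding:
    """Hypothetical stand-in for an embedding network with frozen weights."""

    def __init__(self, input_dim, hidden_dim=64, output_dim=32, seed=0):
        rng = np.random.default_rng(seed)
        # He-style initialization; in practice these weights would be learned elsewhere.
        self.W1 = rng.normal(scale=np.sqrt(2.0 / input_dim), size=(input_dim, hidden_dim))
        self.b1 = np.zeros(hidden_dim)
        self.W2 = rng.normal(scale=np.sqrt(2.0 / hidden_dim), size=(hidden_dim, output_dim))

    def __call__(self, X, include_bias=True):
        H = np.maximum(X @ self.W1 + self.b1, 0.0)  # ReLU hidden layer
        E = H @ self.W2                               # linear output layer = embedding
        return np.hstack([np.ones((len(X), 1)), E]) if include_bias else E


X = np.random.default_rng(1).normal(size=(100, 6))
embed = FrozenMLPEmbedding(input_dim=6)
Phi = embed(X)
print(Phi.shape)  # (100, 33): bias column + 32 embedding dimensions
```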
[10]:
import torch
from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis
torch.manual_seed(0)
net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)
# (Optional) Train net here on a separate task.
# For a lightweight demo, we skip training and just use the random initialization.
nn_basis = TorchEmbeddingBasis(net, include_bias=True, device="cpu")
res_nn = grr_ate(
    X=X,
    Y=Y,
    basis=nn_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_nn.summary_text())
ATE estimates (n=3000)
alpha=0.05 | null=0.0
diagnostics: max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.10074303607722811, ess_treated=1234.797658219504, ess_control=1289.906386252168
Estimator Estimate SE CI p-value
---------------------------------------------------------------------------------
RA 1.0526 0.00233879 [ 1.04801, 1.05718] 0
RW 1.02999 0.0497571 [ 0.932472, 1.12752] 0
ARW 1.02254 0.0372324 [ 0.949568, 1.09552] 0
TMLE 1.02144 0.0372205 [ 0.948489, 1.09439] 0
Generator comparison (SQ / UKL / BP)
This section runs SQ-Riesz, UKL-Riesz, and BP-Riesz for the same polynomial interaction basis, and compares multiple regularization norms and strengths.
We use a branch function for UKL/BP that forces:
the positive branch for treated units (\(D=1\)),
the negative branch for control units (\(D=0\)).
All four estimators (RA / RW / ARW / TMLE) are reported.
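The penalty grid varies `riesz_penalty`, `riesz_lam`, and `riesz_p_norm` (the last configuration uses `"lp"` with `p_norm=1.5`). As an assumption about the parameterization (genriesz may scale the penalty or exclude the intercept differently), the sketch below uses the common forms \(\lambda\|\beta\|_2^2\), \(\lambda\|\beta\|_1\), and \(\lambda\sum_j|\beta_j|^p\) to show how the three families weigh the same coefficient vector. The function `riesz_penalty` here is hypothetical.

```python
import numpy as np


def riesz_penalty(beta, kind, lam, p_norm=None):
    """Hypothetical penalty forms: l2 -> lam*sum(b^2), l1 -> lam*sum|b|, lp -> lam*sum|b|^p."""
    a = np.abs(beta)
    if kind == "l2":
        return lam * np.sum(a ** 2)
    if kind == "l1":
        return lam * np.sum(a)
    if kind == "lp":
        if p_norm is None or p_norm < 1:
            raise ValueError("lp penalty needs p_norm >= 1 to stay convex")
        return lam * np.sum(a ** p_norm)
    raise ValueError(f"unknown penalty: {kind}")


beta = np.array([3.0, -0.5, 0.0, 0.1])
for kind, lam, p in [("l2", 1e-3, None), ("l1", 1e-3, None), ("lp", 1e-3, 1.5)]:
    print(kind, riesz_penalty(beta, kind, lam, p))
```

For \(1 < p < 2\), the \(\ell_p\) penalty grows more slowly than \(\ell_2\) for coefficients larger than 1 in magnitude and, unlike \(\ell_1\), is differentiable at zero, so it does not typically set coefficients exactly to zero.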
[6]:
import pandas as pd
from genriesz import BPGenerator
def branch(x):
    """Branch selector for UKL/BP: 1 (positive branch) if treated, i.e. x[0] == D == 1, else 0."""
    return int(x[0] == 1.0)
generator_grid = [
    ("SQ", SquaredGenerator(C=0.0).as_generator()),
    ("UKL (C=1)", UKLGenerator(C=1.0, branch_fn=branch).as_generator()),
    ("BP (omega=0.1, C=1)", BPGenerator(C=1.0, omega=0.1, branch_fn=branch).as_generator()),
    ("BP (omega=0.2, C=1)", BPGenerator(C=1.0, omega=0.2, branch_fn=branch).as_generator()),
    ("BP (omega=0.5, C=1)", BPGenerator(C=1.0, omega=0.5, branch_fn=branch).as_generator()),
]
penalty_grid = [
    {"penalty": "l2", "lam": 1e-4, "p_norm": None},
    {"penalty": "l2", "lam": 1e-3, "p_norm": None},
    {"penalty": "l1", "lam": 1e-4, "p_norm": None},
    {"penalty": "lp", "lam": 1e-3, "p_norm": 1.5},
]
rows = []
for gname, gen_i in generator_grid:
    for cfg in penalty_grid:
        res_i = grr_ate(
            X=X,
            Y=Y,
            basis=phi,
            generator=gen_i,
            cross_fit=True,
            folds=3,
            random_state=0,
            estimators=("ra", "rw", "arw", "tmle"),
            outcome_models="shared",
            outcome_link="identity",
            riesz_penalty=cfg["penalty"],
            riesz_lam=cfg["lam"],
            riesz_p_norm=cfg["p_norm"],
            max_iter=250,
            tol=1e-8,
        )
        row = {
            "generator": gname,
            "penalty": cfg["penalty"],
            "lam": cfg["lam"],
            "p_norm": cfg["p_norm"],
        }
        for k in ("ra", "rw", "arw", "tmle"):
            est = res_i.estimates[k]  # named `est` so the propensity array `e` is not overwritten
            row[k] = est.estimate
            row[f"{k}_se"] = est.se
            row[f"{k}_err"] = est.estimate - tau  # tau is the true constant effect in this DGP
        rows.append(row)
# One row per (generator, penalty) configuration, sorted by absolute ARW error
df = pd.DataFrame(rows)
df = df.sort_values(by="arw_err", key=lambda s: np.abs(s))
display(df)
| | generator | penalty | lam | p_norm | ra | ra_se | ra_err | rw | rw_se | rw_err | arw | arw_se | arw_err | tmle | tmle_se | tmle_err |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 11 | BP (omega=0.1, C=1) | l1 | 0.0010 | 1.5 | 1.027967 | 0.003809 | 0.027967 | 1.019822 | 0.059940 | 0.019822 | 1.039039 | 0.043709 | 0.039039 | 1.038188 | 0.043721 | 0.038188 |
| 15 | BP (omega=0.2, C=1) | l1 | 0.0010 | 1.5 | 1.027967 | 0.003809 | 0.027967 | 1.023228 | 0.059248 | 0.023228 | 1.039091 | 0.043380 | 0.039091 | 1.038355 | 0.043391 | 0.038355 |
| 14 | BP (omega=0.2, C=1) | l1 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.021144 | 0.059442 | 0.021144 | 1.039292 | 0.043518 | 0.039292 | 1.038505 | 0.043530 | 0.038505 |
| 13 | BP (omega=0.2, C=1) | l2 | 0.0010 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.022101 | 0.059411 | 0.022101 | 1.039302 | 0.043490 | 0.039302 | 1.038521 | 0.043502 | 0.038521 |
| 12 | BP (omega=0.2, C=1) | l2 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.021037 | 0.059462 | 0.021037 | 1.039308 | 0.043530 | 0.039308 | 1.038517 | 0.043542 | 0.038517 |
| 7 | UKL (C=1) | l1 | 0.0010 | 1.5 | 1.027967 | 0.003809 | 0.027967 | 1.015917 | 0.060709 | 0.015917 | 1.039322 | 0.044135 | 0.039322 | 1.038294 | 0.044140 | 0.038294 |
| 19 | BP (omega=0.5, C=1) | l1 | 0.0010 | 1.5 | 1.027967 | 0.003809 | 0.027967 | 1.029973 | 0.057821 | 0.029973 | 1.039372 | 0.042636 | 0.039372 | 1.038849 | 0.042643 | 0.038849 |
| 18 | BP (omega=0.5, C=1) | l1 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.027756 | 0.058034 | 0.027756 | 1.039449 | 0.042746 | 0.039449 | 1.038895 | 0.042754 | 0.038895 |
| 10 | BP (omega=0.1, C=1) | l1 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.018222 | 0.060047 | 0.018222 | 1.039456 | 0.043853 | 0.039456 | 1.038531 | 0.043864 | 0.038531 |
| 16 | BP (omega=0.5, C=1) | l2 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.027687 | 0.058050 | 0.027687 | 1.039457 | 0.042753 | 0.039457 | 1.038900 | 0.042761 | 0.038900 |
| 17 | BP (omega=0.5, C=1) | l2 | 0.0010 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.029078 | 0.057985 | 0.029078 | 1.039464 | 0.042703 | 0.039464 | 1.038918 | 0.042711 | 0.038918 |
| 9 | BP (omega=0.1, C=1) | l2 | 0.0010 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.019118 | 0.060007 | 0.019118 | 1.039485 | 0.043828 | 0.039485 | 1.038564 | 0.043839 | 0.038564 |
| 8 | BP (omega=0.1, C=1) | l2 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.018151 | 0.060053 | 0.018151 | 1.039499 | 0.043864 | 0.039499 | 1.038568 | 0.043875 | 0.038568 |
| 6 | UKL (C=1) | l1 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.014454 | 0.060770 | 0.014454 | 1.039918 | 0.044285 | 0.039918 | 1.038785 | 0.044288 | 0.038785 |
| 5 | UKL (C=1) | l2 | 0.0010 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.015266 | 0.060731 | 0.015266 | 1.039954 | 0.044264 | 0.039954 | 1.038824 | 0.044267 | 0.038824 |
| 4 | UKL (C=1) | l2 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.014382 | 0.060769 | 0.014382 | 1.039978 | 0.044297 | 0.039978 | 1.038835 | 0.044299 | 0.038835 |
| 3 | SQ | l1 | 0.0010 | 1.5 | 1.027967 | 0.003809 | 0.027967 | 1.033380 | 0.056470 | 0.033380 | 1.041214 | 0.041673 | 0.041214 | 1.040825 | 0.041676 | 0.040825 |
| 2 | SQ | l1 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.033103 | 0.056626 | 0.033103 | 1.041513 | 0.041770 | 0.041513 | 1.041085 | 0.041772 | 0.041085 |
| 0 | SQ | l2 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.032923 | 0.056626 | 0.032923 | 1.041555 | 0.041760 | 0.041555 | 1.041128 | 0.041763 | 0.041128 |
| 1 | SQ | l2 | 0.0010 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.031386 | 0.056477 | 0.031386 | 1.041641 | 0.041583 | 0.041641 | 1.041266 | 0.041585 | 0.041266 |