ATE end-to-end examples (genriesz)

This notebook demonstrates how to estimate the Average Treatment Effect (ATE) with genriesz.

We assume the data have the following form:

  • \(X = (D, Z)\) is the regressor, where \(D\) is a binary treatment indicator (\(0\) or \(1\)) and \(Z\) is a vector of covariates,

  • \(Y\) is the observed outcome.

We will compute (optionally with cross-fitting):

  • RA: regression adjustment (plug-in)

  • RW: Riesz weighting (weighting only)

  • ARW: augmented Riesz weighting

  • TMLE: targeted maximum likelihood estimation (one-step fluctuation)
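For orientation, the first three estimators can be written in a common textbook form. The notation is an assumption of this sketch rather than something taken from the genriesz source: \(\hat\mu(d, z)\) denotes a fitted outcome regression, \(\hat\alpha(X)\) a fitted Riesz representer, and the averages run over the \(n\) observations (over held-out folds when cross-fitting).

\[\hat\theta_{\mathrm{RA}} = \frac{1}{n}\sum_{i=1}^{n}\bigl\{\hat\mu(1, Z_i) - \hat\mu(0, Z_i)\bigr\},\]

\[\hat\theta_{\mathrm{RW}} = \frac{1}{n}\sum_{i=1}^{n}\hat\alpha(X_i)\,Y_i,\]

\[\hat\theta_{\mathrm{ARW}} = \frac{1}{n}\sum_{i=1}^{n}\bigl\{\hat\mu(1, Z_i) - \hat\mu(0, Z_i) + \hat\alpha(X_i)\,\bigl(Y_i - \hat\mu(D_i, Z_i)\bigr)\bigr\}.\]

For the ATE, the population Riesz representer is \(\alpha_0(X) = D / e(Z) - (1 - D) / (1 - e(Z))\) with \(e(Z) = P(D = 1 \mid Z)\). TMLE in its usual form first updates \(\hat\mu\) along \(\hat\alpha\) and then takes the plug-in average; the exact fluctuation used by the package is not shown here.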

We also show how to swap the basis:

  • polynomial features,

  • RKHS RBF random features,

  • nearest-neighbor matching (KNN catchment basis),

  • random forest leaf features (optional),

  • neural network embeddings (optional).

[1]:
import numpy as np

from genriesz import (
    grr_ate,
    SquaredGenerator,
    UKLGenerator,
    PolynomialBasis,
    TreatmentInteractionBasis,
    RBFRandomFourierBasis,
    KNNCatchmentBasis,
)

rng = np.random.default_rng(0)

Synthetic data

[2]:
# Data-generating process
n = 3000
d_z = 5

Z = rng.normal(size=(n, d_z))

# Treatment assignment: e(Z) = sigmoid(a'Z)
logits = 0.7 * Z[:, 0] - 0.3 * Z[:, 1]
e = 1.0 / (1.0 + np.exp(-logits))
D = rng.binomial(1, e, size=n).astype(float)

# Potential outcomes (constant effect for simplicity)
tau = 1.0
mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2
Y0 = mu0 + rng.normal(scale=1.0, size=n)
Y1 = mu0 + tau + rng.normal(scale=1.0, size=n)

Y = D * Y1 + (1.0 - D) * Y0

# Regressor matrix X = [D, Z...]
X = np.column_stack([D, Z])

print("X shape:", X.shape, "Y shape:", Y.shape)

X shape: (3000, 6) Y shape: (3000,)
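Before running any estimator, it can help to see how much confounding this design contains. The sketch below is not part of genriesz: it regenerates the same data-generating process with NumPy only and compares a naive difference in means with an oracle weighted estimate that uses the true propensity \(e(Z)\), which is known here only because the data are simulated. The variable names `naive` and `oracle_ipw` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_z, tau = 3000, 5, 1.0

# Same data-generating process as the cell above.
Z = rng.normal(size=(n, d_z))
e = 1.0 / (1.0 + np.exp(-(0.7 * Z[:, 0] - 0.3 * Z[:, 1])))
D = rng.binomial(1, e, size=n).astype(float)
mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2
Y0 = mu0 + rng.normal(scale=1.0, size=n)
Y1 = mu0 + tau + rng.normal(scale=1.0, size=n)
Y = D * Y1 + (1.0 - D) * Y0

# Naive contrast: ignores that treated units tend to have larger Z[:, 0],
# which also raises the outcome through mu0.
naive = Y[D == 1].mean() - Y[D == 0].mean()

# Oracle normalized (Hajek-type) weighting with the true propensity.
# In real data e(Z) is unknown and has to be estimated.
w1 = D / e
w0 = (1.0 - D) / (1.0 - e)
oracle_ipw = np.sum(w1 * Y) / np.sum(w1) - np.sum(w0 * Y) / np.sum(w0)

print(f"true tau:                  {tau:.3f}")
print(f"naive difference in means: {naive:.3f}")
print(f"oracle weighted estimate:  {oracle_ipw:.3f}")
```

The naive contrast should land visibly above `tau`, while the oracle weighted estimate should be close to it; the examples that follow aim to recover that correction without access to the true propensity.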

Example 1: Polynomial basis + treatment interactions

[3]:
# Basis on Z, then interact with D (ATE-friendly)
psi = PolynomialBasis(degree=2, include_bias=True)
phi = TreatmentInteractionBasis(base_basis=psi)

# Generator: Squared loss (always safe / no domain constraints)
gen = SquaredGenerator(C=0.0).as_generator()

res_poly = grr_ate(
    X=X,
    Y=Y,
    basis=phi,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_poly.summary_text())

ATE estimates (n=3000)
alpha=0.05 | null=0.0
diagnostics: alpha_abs_mean=2.0164751561831986, alpha_abs_p95=3.844795773570575, alpha_abs_max=9.334972319665178, max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.01662007897774767, ess_treated=1188.4096325793182, ess_control=1260.0346792242883

Estimator         Estimate            SE                           CI     p-value
---------------------------------------------------------------------------------
RA                 1.03046    0.00359346         [ 1.02342,  1.03751]           0
RW                 1.03302     0.0569279         [ 0.921445,  1.1446]           0
ARW                1.02916     0.0419132        [ 0.947017,  1.11131]           0
TMLE               1.02921     0.0419128        [ 0.947063,  1.11136]           0
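To make the idea behind the RW row concrete, here is a minimal NumPy sketch of ridge-penalized squared-loss Riesz regression on a treatment-interacted polynomial basis. It is an illustration under stated assumptions, not the genriesz implementation: the helpers `poly2` and `phi` are hypothetical, there is no cross-fitting, the outcome noise is simplified to a single draw, and the closed-form solve stands in for whatever optimizer the package uses.

```python
from itertools import combinations_with_replacement

import numpy as np

rng = np.random.default_rng(0)
n, tau = 3000, 1.0
Z = rng.normal(size=(n, 5))
e = 1.0 / (1.0 + np.exp(-(0.7 * Z[:, 0] - 0.3 * Z[:, 1])))
D = rng.binomial(1, e).astype(float)
Y = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2 + tau * D + rng.normal(size=n)


def poly2(Z):
    """Degree-2 polynomial features: bias, linear terms, all pairwise products."""
    pairs = combinations_with_replacement(range(Z.shape[1]), 2)
    cols = [np.ones(len(Z))] + list(Z.T) + [Z[:, j] * Z[:, k] for j, k in pairs]
    return np.column_stack(cols)


def phi(D, Z):
    """Treatment-interacted features [D * psi(Z), (1 - D) * psi(Z)]."""
    P = poly2(Z)
    return np.hstack([D[:, None] * P, (1.0 - D)[:, None] * P])


Phi = phi(D, Z)
# ATE functional applied to each feature: phi(1, Z) - phi(0, Z).
m_Phi = phi(np.ones(n), Z) - phi(np.zeros(n), Z)

# Minimize  mean((Phi b)^2) - 2 * mean(m_Phi b) + lam * ||b||^2  over b.
# Setting the gradient to zero gives a linear system.
lam = 1e-3
G = Phi.T @ Phi / n + lam * np.eye(Phi.shape[1])
beta = np.linalg.solve(G, m_Phi.mean(axis=0))
alpha_hat = Phi @ beta  # fitted Riesz representer at each observation

# Riesz weighting: average of alpha_hat * Y.
ate_rw = np.mean(alpha_hat * Y)

# Balance check: the weights reproduce the feature means of the functional,
# up to a term of size lam * beta.
imbalance = np.max(np.abs(Phi.T @ alpha_hat / n - m_Phi.mean(axis=0)))

print(f"RW estimate: {ate_rw:.3f}, max feature imbalance: {imbalance:.2e}")
```

The fitted weights are positive on average for treated units and negative for controls, which mirrors the sign pattern of the population representer \(D / e(Z) - (1 - D) / (1 - e(Z))\).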

Example 2: RKHS basis (RBF random Fourier features)

This approximates an RBF kernel feature map using random Fourier features, then interacts the features with treatment.

[4]:
psi_rff = RBFRandomFourierBasis(
    n_features=500,
    sigma=1.0,
    standardize=True,
    random_state=0,
)
phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)

res_rff = grr_ate(
    X=X,
    Y=Y,
    basis=phi_rff,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_rff.summary_text())

ATE estimates (n=3000)
alpha=0.05 | null=0.0
diagnostics: alpha_abs_mean=2.077446519847265, alpha_abs_p95=3.353495372876411, alpha_abs_max=4.936266573928845, max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.27862992844919376, ess_treated=1363.470230442599, ess_control=1370.9575124421833

Estimator         Estimate            SE                           CI     p-value
---------------------------------------------------------------------------------
RA                 1.15729    0.00717955         [ 1.14321,  1.17136]           0
RW                 1.22541     0.0540494         [ 1.11947,  1.33134]           0
ARW                1.12627     0.0423592         [ 1.04325,  1.20929]           0
TMLE               1.12732     0.0423287         [ 1.04436,  1.21028]           0
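The random Fourier feature idea used in this example can be checked directly. The sketch below follows the standard construction (Gaussian frequencies, uniform phases, cosine features) and compares the implied kernel with the exact RBF kernel. It assumes the bandwidth convention \(k(z, z') = \exp(-\lVert z - z' \rVert^2 / (2\sigma^2))\); how `sigma` and `standardize` in `RBFRandomFourierBasis` map onto this convention is not confirmed by the notebook. The function name `rff` is illustrative, and 2000 features are used instead of 500 to tighten the approximation.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_features, sigma = 5, 2000, 1.0

# Random frequencies W ~ N(0, sigma^-2 I) and phases b ~ Uniform[0, 2*pi).
W = rng.normal(scale=1.0 / sigma, size=(d, n_features))
b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)


def rff(Z):
    """Map Z of shape (n, d) to random Fourier features of shape (n, n_features)."""
    return np.sqrt(2.0 / n_features) * np.cos(Z @ W + b)


Z = rng.normal(size=(20, d))
K_approx = rff(Z) @ rff(Z).T

# Exact RBF kernel for comparison.
sq_dist = ((Z[:, None, :] - Z[None, :, :]) ** 2).sum(axis=-1)
K_exact = np.exp(-sq_dist / (2.0 * sigma**2))

max_err = np.abs(K_approx - K_exact).max()
print(f"max abs kernel error with {n_features} features: {max_err:.3f}")
```

The approximation error shrinks at roughly a \(1/\sqrt{m}\) rate in the number of features \(m\), so the feature count trades off kernel fidelity against the dimension of the linear problem.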

Example 3: KNN catchment basis (nearest-neighbor matching)

Nearest-neighbor matching can be expressed using a catchment-area basis

\[\phi_j(z) = \mathbf{1}\{j \in \mathrm{NN}_k(z)\},\]

which assigns each point to its \(k\) nearest training centers. With this basis, nearest-neighbor matching is a special case of squared-loss Riesz regression, so we can pass it to grr_ate exactly like any other basis.

[ ]:
# KNN catchment basis as a nearest-neighbor Riesz basis.
#
# phi_j(z) = 1{j in NN_k(z)} assigns each point to its k nearest training centers.
# TreatmentInteractionBasis then creates [D*psi(Z), (1-D)*psi(Z)],
# which gives the standard NN-matching Riesz representer as a linear model.
#
# With cross-fitting the training fold becomes the centers, so the feature
# dimension p = n_train >> n_test.  The dual (Woodbury) solve handles this
# in O(n_test^3) instead of O(n_train^3).

basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)
phi_knn   = TreatmentInteractionBasis(base_basis=basis_knn)

res_knn = grr_ate(
    X=X,
    Y=Y,
    basis=phi_knn,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_knn.summary_text())
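The catchment features themselves are easy to build by hand. The following NumPy sketch constructs the indicator matrix \(\phi_j(z) = \mathbf{1}\{j \in \mathrm{NN}_k(z)\}\) for a set of query points and a set of centers. It is only an illustration of the definition above; the function name `knn_catchment` is hypothetical and the brute-force distance computation is not meant to reflect how `KNNCatchmentBasis` is implemented.

```python
import numpy as np


def knn_catchment(Z_query, Z_centers, k):
    """Return F with F[i, j] = 1 if center j is among the k nearest centers of query i."""
    sq_dist = ((Z_query[:, None, :] - Z_centers[None, :, :]) ** 2).sum(axis=-1)
    # Column indices of the k smallest distances in each row (order within the k is arbitrary).
    nearest = np.argpartition(sq_dist, k - 1, axis=1)[:, :k]
    F = np.zeros(sq_dist.shape)
    np.put_along_axis(F, nearest, 1.0, axis=1)
    return F


rng = np.random.default_rng(0)
Z_centers = rng.normal(size=(300, 5))  # e.g. a training fold
Z_query = rng.normal(size=(200, 5))    # e.g. the held-out fold

F = knn_catchment(Z_query, Z_centers, k=5)

# Column sums count how often each center is used as a neighbor.
catchment_counts = F.sum(axis=0)
print("feature matrix shape:", F.shape)
print("largest catchment count:", int(catchment_counts.max()))
```

Every row of `F` contains exactly `k` ones, and the feature dimension equals the number of centers, which is why the dimension grows with the size of the training fold.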

Example 4: Random forest leaf basis (optional)

If you have scikit-learn installed, you can use a random forest as a feature map via leaf indicators. This keeps generalized Riesz regression (GRR) convex (linear in parameters) while giving a flexible nonparametric basis.

[6]:
from sklearn.ensemble import RandomForestRegressor
from genriesz.sklearn_basis import RandomForestLeafBasis

rf = RandomForestRegressor(
    n_estimators=200,
    max_depth=6,
    random_state=0,
)

leaf_basis = RandomForestLeafBasis(rf, include_bias=True).fit(X, Y)

res_rf = grr_ate(
    X=X,
    Y=Y,
    basis=leaf_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_rf.summary_text())

ATE estimates (n=3000)
alpha=0.05 | null=0.0
diagnostics: max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.786895849078183, ess_treated=350.8795988622065, ess_control=507.56068177539834

Estimator         Estimate            SE                           CI     p-value
---------------------------------------------------------------------------------
RA                 1.03187     0.0199079        [ 0.992854,  1.07089]           0
RW                 1.69371      0.252221         [ 1.19937,  2.18806]    1.88e-11
ARW                1.25382      0.215261        [ 0.831919,  1.67573]    5.72e-09
TMLE                 1.063      0.217126         [ 0.63744,  1.48856]    9.79e-07
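Leaf-indicator features can be reproduced with scikit-learn alone, which may help when inspecting what such a basis looks like. The sketch below uses `RandomForestRegressor.apply` to get the leaf index of each sample in each tree and one-hot encodes those indices. It is an assumption that `RandomForestLeafBasis` does something comparable internally; the notebook does not show it, and the small forest and toy data here are for illustration only.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import OneHotEncoder

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
Y = X[:, 0] + rng.normal(size=500)

n_estimators = 20
rf = RandomForestRegressor(
    n_estimators=n_estimators, max_depth=3, random_state=0
).fit(X, Y)

# leaves[i, t] is the index of the leaf that sample i reaches in tree t.
leaves = rf.apply(X)

# One indicator column per (tree, leaf) pair; leaves unseen at fit time are ignored.
encoder = OneHotEncoder(handle_unknown="ignore").fit(leaves)
F = encoder.transform(leaves).toarray()

print("leaf index matrix:", leaves.shape)
print("leaf indicator features:", F.shape)
```

Each sample activates exactly one leaf per tree, so every row of `F` sums to the number of trees. Any model that is linear in these features stays linear in its parameters even though the forest itself is nonlinear.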

Example 5: Neural network embedding basis (optional)

If you have PyTorch installed, you can use a neural network as a basis function. A recommended procedure is:

  1. train an embedding network on a separate task,

  2. wrap the trained network as a basis function (TorchEmbeddingBasis),

  3. use its outputs as features in GRR.

Below we show the mechanics with a small MLP (training is optional).

[10]:
import torch
from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis

torch.manual_seed(0)

net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)
# (Optional) Train net here on a separate task.
# For a lightweight demo, we skip training and just use the random initialization.
nn_basis = TorchEmbeddingBasis(net, include_bias=True, device="cpu")

res_nn = grr_ate(
    X=X,
    Y=Y,
    basis=nn_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_nn.summary_text())

ATE estimates (n=3000)
alpha=0.05 | null=0.0
diagnostics: max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.10074303607722811, ess_treated=1234.797658219504, ess_control=1289.906386252168

Estimator         Estimate            SE                           CI     p-value
---------------------------------------------------------------------------------
RA                  1.0526    0.00233879         [ 1.04801,  1.05718]           0
RW                 1.02999     0.0497571        [ 0.932472,  1.12752]           0
ARW                1.02254     0.0372324        [ 0.949568,  1.09552]           0
TMLE               1.02144     0.0372205        [ 0.948489,  1.09439]           0

Generator (SQ / UKL / BP)

This section runs SQ-Riesz, UKL-Riesz, and BP-Riesz for the same polynomial interaction basis, and compares multiple regularization norms and strengths.

We use a branch function for UKL/BP that forces:

  • positive branch for treated units (D=1),

  • negative branch for control units (D=0).

All four estimators (RA / RW / ARW / TMLE) are reported.
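The branch rule described above can be checked on a couple of rows. This sketch mirrors the rule used in this section, written as a named function, and assumes (as the text suggests) that a return value of 1 selects the positive branch and 0 the negative branch. The array `X_demo` is made up for illustration.

```python
import numpy as np


def branch(x):
    """Return 1 for a treated row (first entry D == 1) and 0 for a control row."""
    return int(x[0] == 1.0)


X_demo = np.array([
    [1.0, 0.3, -1.2],  # treated unit: D = 1
    [0.0, 0.3, -1.2],  # control unit: D = 0, same covariates
])

branches = [branch(row) for row in X_demo]
print(branches)  # [1, 0]
```

Because the rule looks only at the first column, it relies on the regressor layout \(X = (D, Z)\) with the treatment indicator stored first.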

[6]:
from genriesz import BPGenerator

branch = lambda x: int(x[0] == 1.0)

generator_grid = [
    ("SQ", SquaredGenerator(C=0.0).as_generator()),
    ("UKL (C=1)", UKLGenerator(C=1.0, branch_fn=branch).as_generator()),
    ("BP (omega=0.1, C=1)", BPGenerator(C=1.0, omega=0.1, branch_fn=branch).as_generator()),
    ("BP (omega=0.2, C=1)", BPGenerator(C=1.0, omega=0.2, branch_fn=branch).as_generator()),
    ("BP (omega=0.5, C=1)", BPGenerator(C=1.0, omega=0.5, branch_fn=branch).as_generator()),
]

penalty_grid = [
    {"penalty": "l2", "lam": 1e-4, "p_norm": None},
    {"penalty": "l2", "lam": 1e-3, "p_norm": None},
    {"penalty": "l1", "lam": 1e-4, "p_norm": None},
    {"penalty": "lp", "lam": 1e-3, "p_norm": 1.5},
]

rows = []
for gname, gen_i in generator_grid:
    for cfg in penalty_grid:
        res_i = grr_ate(
            X=X,
            Y=Y,
            basis=phi,
            generator=gen_i,
            cross_fit=True,
            folds=3,
            random_state=0,
            estimators=("ra", "rw", "arw", "tmle"),
            outcome_models="shared",
            outcome_link="identity",
            riesz_penalty=cfg["penalty"],
            riesz_lam=cfg["lam"],
            riesz_p_norm=cfg["p_norm"],
            max_iter=250,
            tol=1e-8,
        )

        row = {
            "generator": gname,
            "penalty": cfg["penalty"],
            "lam": cfg["lam"],
            "p_norm": cfg["p_norm"],
        }
        for k in ("ra", "rw", "arw", "tmle"):
            est = res_i.estimates[k]  # named `est` so the propensity array `e` is not overwritten
            row[k] = est.estimate
            row[f"{k}_se"] = est.se
            row[f"{k}_err"] = est.estimate - tau  # tau is the true constant effect in this DGP
        rows.append(row)

import pandas as pd

df = pd.DataFrame(rows)
df = df.sort_values(by="arw_err", key=lambda s: np.abs(s))
display(df)

    generator            penalty     lam  p_norm        ra     ra_se    ra_err        rw     rw_se    rw_err       arw    arw_se   arw_err      tmle   tmle_se  tmle_err
------------------------------------------------------------------------------------------------------------------------------------------------------------------------
11  BP (omega=0.1, C=1)  l1       0.0010     1.5  1.027967  0.003809  0.027967  1.019822  0.059940  0.019822  1.039039  0.043709  0.039039  1.038188  0.043721  0.038188
15  BP (omega=0.2, C=1)  l1       0.0010     1.5  1.027967  0.003809  0.027967  1.023228  0.059248  0.023228  1.039091  0.043380  0.039091  1.038355  0.043391  0.038355
14  BP (omega=0.2, C=1)  l1       0.0001     NaN  1.027967  0.003809  0.027967  1.021144  0.059442  0.021144  1.039292  0.043518  0.039292  1.038505  0.043530  0.038505
13  BP (omega=0.2, C=1)  l2       0.0010     NaN  1.027967  0.003809  0.027967  1.022101  0.059411  0.022101  1.039302  0.043490  0.039302  1.038521  0.043502  0.038521
12  BP (omega=0.2, C=1)  l2       0.0001     NaN  1.027967  0.003809  0.027967  1.021037  0.059462  0.021037  1.039308  0.043530  0.039308  1.038517  0.043542  0.038517
 7  UKL (C=1)            l1       0.0010     1.5  1.027967  0.003809  0.027967  1.015917  0.060709  0.015917  1.039322  0.044135  0.039322  1.038294  0.044140  0.038294
19  BP (omega=0.5, C=1)  l1       0.0010     1.5  1.027967  0.003809  0.027967  1.029973  0.057821  0.029973  1.039372  0.042636  0.039372  1.038849  0.042643  0.038849
18  BP (omega=0.5, C=1)  l1       0.0001     NaN  1.027967  0.003809  0.027967  1.027756  0.058034  0.027756  1.039449  0.042746  0.039449  1.038895  0.042754  0.038895
10  BP (omega=0.1, C=1)  l1       0.0001     NaN  1.027967  0.003809  0.027967  1.018222  0.060047  0.018222  1.039456  0.043853  0.039456  1.038531  0.043864  0.038531
16  BP (omega=0.5, C=1)  l2       0.0001     NaN  1.027967  0.003809  0.027967  1.027687  0.058050  0.027687  1.039457  0.042753  0.039457  1.038900  0.042761  0.038900
17  BP (omega=0.5, C=1)  l2       0.0010     NaN  1.027967  0.003809  0.027967  1.029078  0.057985  0.029078  1.039464  0.042703  0.039464  1.038918  0.042711  0.038918
 9  BP (omega=0.1, C=1)  l2       0.0010     NaN  1.027967  0.003809  0.027967  1.019118  0.060007  0.019118  1.039485  0.043828  0.039485  1.038564  0.043839  0.038564
 8  BP (omega=0.1, C=1)  l2       0.0001     NaN  1.027967  0.003809  0.027967  1.018151  0.060053  0.018151  1.039499  0.043864  0.039499  1.038568  0.043875  0.038568
 6  UKL (C=1)            l1       0.0001     NaN  1.027967  0.003809  0.027967  1.014454  0.060770  0.014454  1.039918  0.044285  0.039918  1.038785  0.044288  0.038785
 5  UKL (C=1)            l2       0.0010     NaN  1.027967  0.003809  0.027967  1.015266  0.060731  0.015266  1.039954  0.044264  0.039954  1.038824  0.044267  0.038824
 4  UKL (C=1)            l2       0.0001     NaN  1.027967  0.003809  0.027967  1.014382  0.060769  0.014382  1.039978  0.044297  0.039978  1.038835  0.044299  0.038835
 3  SQ                   l1       0.0010     1.5  1.027967  0.003809  0.027967  1.033380  0.056470  0.033380  1.041214  0.041673  0.041214  1.040825  0.041676  0.040825
 2  SQ                   l1       0.0001     NaN  1.027967  0.003809  0.027967  1.033103  0.056626  0.033103  1.041513  0.041770  0.041513  1.041085  0.041772  0.041085
 0  SQ                   l2       0.0001     NaN  1.027967  0.003809  0.027967  1.032923  0.056626  0.032923  1.041555  0.041760  0.041555  1.041128  0.041763  0.041128
 1  SQ                   l2       0.0010     NaN  1.027967  0.003809  0.027967  1.031386  0.056477  0.031386  1.041641  0.041583  0.041641  1.041266  0.041585  0.041266
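One way to read the table above is to compare each error with its reported standard error. The sketch below does this in plain Python for the SQ / l2 / lam=1e-3 row, using the estimates and standard errors copied from the table and a normal 95% critical value. The choice of row and the dictionary name `reported` are illustrative.

```python
tau = 1.0      # true constant effect in this DGP
z_crit = 1.96  # two-sided 95% normal critical value

# (estimate, standard error) for the SQ / l2 / lam=1e-3 row of the table above.
reported = {
    "ra": (1.027967, 0.003809),
    "rw": (1.031386, 0.056477),
    "arw": (1.041641, 0.041583),
    "tmle": (1.041266, 0.041585),
}

covers = {}
for name, (est, se) in reported.items():
    lower, upper = est - z_crit * se, est + z_crit * se
    covers[name] = lower <= tau <= upper
    ratio = (est - tau) / se
    print(f"{name:>4}: err/se = {ratio:5.2f}, 95% interval covers tau: {covers[name]}")
```

In this row the RW, ARW, and TMLE intervals contain `tau`, while the RA interval does not, because its reported standard error is much smaller than its error. A single simulated dataset cannot establish coverage rates; repeating the experiment over many seeds would be needed for that.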