ATT simulation with (approximate) true value (genriesz)

This notebook demonstrates estimation of the average treatment effect on the treated (ATT):

\[\theta = \mathbb{E}[Y(1)-Y(0) \mid D=1].\]

We generate a synthetic population with heterogeneous treatment effects, so in general ATT ≠ ATE. We approximate the “true” ATT by Monte Carlo on a large simulated population and compare it with estimators based on generalized Riesz regression (GRR).

We assume the regressor matrix has the form \(X = [D, Z]\), where the first column \(D\) is a binary treatment indicator and the remaining columns \(Z\) are covariates.

[1]:
import numpy as np

from genriesz import (
    grr_att,
    SquaredGenerator,
    UKLGenerator,
    BPGenerator,
    PolynomialBasis,
    TreatmentInteractionBasis,
    RBFRandomFourierBasis,
    KNNCatchmentBasis,
)

rng = np.random.default_rng(0)

Data generating process

[2]:
def draw_population(n: int, d_z: int, seed: int = 0):
    rng = np.random.default_rng(seed)
    Z = rng.normal(size=(n, d_z))

    logits = 0.7 * Z[:, 0] - 0.3 * Z[:, 1]
    e = 1.0 / (1.0 + np.exp(-logits))
    D = rng.binomial(1, e, size=n).astype(int)

    # Heterogeneous treatment effect
    tau = 1.0 + 0.5 * Z[:, 0]
    mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2

    Y0 = mu0 + rng.normal(scale=1.0, size=n)
    Y1 = mu0 + tau + rng.normal(scale=1.0, size=n)
    Y = D * Y1 + (1 - D) * Y0

    X = np.column_stack([D.astype(float), Z])
    return X, Y, Y0, Y1, D, tau

# Large population for an approximate truth
X_pop, Y_pop, Y0_pop, Y1_pop, D_pop, tau_pop = draw_population(n=200_000, d_z=5, seed=1)
true_att = float(np.mean((Y1_pop - Y0_pop)[D_pop == 1]))
true_ate = float(np.mean(Y1_pop - Y0_pop))

print("Approx. true ATT (Monte Carlo):", true_att)
print("Approx. true ATE (Monte Carlo):", true_ate)

Approx. true ATT (Monte Carlo): 1.158115237841178
Approx. true ATE (Monte Carlo): 1.0038058584371061
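
In this data generating process the unit-level effect is \(\tau(Z) = 1 + 0.5 Z_0\), so ATE \(= 1\) and ATT \(= 1 + 0.5\,\mathbb{E}[Z_0 \mid D=1]\). The ATT exceeds 1 because units with larger \(Z_0\) are more likely to be treated. As a cross-check, the sketch below averages \(\tau\) directly over treated units, which avoids the outcome noise carried by \(Y(1) - Y(0)\). It re-simulates only the two covariates that enter the propensity score and the effect, so it is an independent draw and will not reproduce the printed digits exactly.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 200_000

# Only Z[:, 0] and Z[:, 1] enter the propensity score and the effect.
Z = rng.normal(size=(n, 2))
e = 1.0 / (1.0 + np.exp(-(0.7 * Z[:, 0] - 0.3 * Z[:, 1])))
D = rng.binomial(1, e)
tau = 1.0 + 0.5 * Z[:, 0]

treated = D == 1
att_noise_free = float(tau[treated].mean())
ate_noise_free = float(tau.mean())
# Monte Carlo standard error of the treated-group average of tau.
att_mc_se = float(tau[treated].std(ddof=1) / np.sqrt(treated.sum()))

print(f"ATT ~ {att_noise_free:.4f} (MC s.e. {att_mc_se:.4f})")
print(f"ATE ~ {ate_noise_free:.4f}")
```

The Monte Carlo standard error indicates how many digits of the approximate truth are meaningful when judging estimator errors.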

Example 1: Polynomial basis + treatment interactions

[3]:
# Sample a dataset from the same DGP
X, Y, Y0, Y1, D, tau = draw_population(n=5000, d_z=5, seed=0)

# Basis on Z, then interact with D (works well for treatment-effect functionals)
psi = PolynomialBasis(degree=2, include_bias=True)
phi = TreatmentInteractionBasis(base_basis=psi)

gen = SquaredGenerator(C=0.0).as_generator()

res = grr_att(
    X=X,
    Y=Y,
    basis=phi,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res.summary_text())

ATT estimates (n=5000)
alpha=0.05 | null=0.0
diagnostics: max_abs_smd_unweighted=0.6705657462527745, max_abs_smd_weighted=0.005400089251320891, ess_treated=2500.896209806344, ess_control=1558.1240998523388

Estimator         Estimate            SE                           CI     p-value
---------------------------------------------------------------------------------
RA                 1.09227     0.0179759          [ 1.05703,  1.1275]           0
RW                 1.08225     0.0522539        [ 0.979834,  1.18467]           0
ARW                1.09178     0.0377112         [ 1.01786,  1.16569]           0
TMLE               1.09178     0.0377068         [ 1.01788,  1.16569]           0
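
The diagnostics line reports the maximum absolute standardized mean difference (SMD) before and after weighting, and effective sample sizes (ESS). The notebook does not spell out the formulas, so the following is only a sketch using common textbook definitions (SMD scaled by the pooled unweighted standard deviation, Kish's ESS); genriesz may compute them differently. The helper names `abs_smd` and `kish_ess` are hypothetical, and the weights are oracle ATT weights from a toy one-covariate simulation.

```python
import numpy as np

def abs_smd(z, d, w=None):
    """Absolute standardized mean difference of covariate z between arms."""
    w = np.ones_like(z, dtype=float) if w is None else np.asarray(w, dtype=float)
    m1 = np.average(z[d == 1], weights=w[d == 1])
    m0 = np.average(z[d == 0], weights=w[d == 0])
    # Pooled unweighted standard deviation as the common scale.
    s = np.sqrt(0.5 * (z[d == 1].var(ddof=1) + z[d == 0].var(ddof=1)))
    return abs(m1 - m0) / s

def kish_ess(w):
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    w = np.asarray(w, dtype=float)
    return w.sum() ** 2 / (w ** 2).sum()

rng = np.random.default_rng(0)
n = 5000
z = rng.normal(size=n)
e = 1.0 / (1.0 + np.exp(-0.7 * z))
d = rng.binomial(1, e)

# Oracle ATT weights: 1 for treated units, odds e/(1-e) for controls.
w_att = np.where(d == 1, 1.0, e / (1.0 - e))

smd_raw = abs_smd(z, d)
smd_weighted = abs_smd(z, d, w_att)
ess_control = kish_ess(w_att[d == 0])
print(smd_raw, smd_weighted, ess_control)
```

A weighted SMD close to zero suggests the weights balance the covariate. An ESS well below the raw group size means a few units carry most of the weight.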

Example 2: RKHS basis (RBF random Fourier features)

This approximates an RBF kernel feature map using random Fourier features, then interacts the features with treatment.

[ ]:
psi_rff = RBFRandomFourierBasis(
    n_features=500,
    sigma=1.0,
    standardize=True,
    random_state=0,
)
phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)

res_rff = grr_att(
    X=X,
    Y=Y,
    basis=phi_rff,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_rff.summary_text())
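
To illustrate the approximation behind this example, here is a minimal NumPy sketch of the standard random Fourier feature construction for the kernel \(k(x, y) = \exp(-\lVert x - y\rVert^2 / (2\sigma^2))\). It is not the implementation of `RBFRandomFourierBasis`: the function name `rff_features` is hypothetical, and the library's standardization step and its exact parametrization of `sigma` are not reproduced.

```python
import numpy as np

def rff_features(Z, n_features, sigma, rng):
    """Random Fourier features whose inner products approximate an RBF kernel."""
    d = Z.shape[1]
    # Frequencies drawn from the kernel's spectral density, N(0, sigma^-2 I).
    W = rng.normal(scale=1.0 / sigma, size=(d, n_features))
    b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)
    return np.sqrt(2.0 / n_features) * np.cos(Z @ W + b)

rng = np.random.default_rng(0)
sigma = 1.0
Z = rng.normal(size=(200, 3))
Phi = rff_features(Z, n_features=5000, sigma=sigma, rng=rng)

# Compare the feature inner products with the exact kernel matrix.
sq_dists = ((Z[:, None, :] - Z[None, :, :]) ** 2).sum(axis=-1)
K_exact = np.exp(-sq_dists / (2.0 * sigma ** 2))
K_approx = Phi @ Phi.T
max_abs_err = float(np.abs(K_exact - K_approx).max())
print("max |K_exact - K_approx| =", max_abs_err)
```

The approximation error shrinks roughly like one over the square root of the number of features, so `n_features` trades accuracy against the size of the linear problem.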

Example 3: KNN catchment basis (nearest-neighbor matching)

Nearest-neighbor matching can be expressed as a special case of squared-loss Riesz regression. TreatmentInteractionBasis creates the features [D·ψ(Z), (1-D)·ψ(Z)], so the standard NN-matching Riesz representer is recovered as a model that is linear in these features.

[ ]:
basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)
phi_knn = TreatmentInteractionBasis(base_basis=basis_knn)

res_knn = grr_att(
    X=X,
    Y=Y,
    basis=phi_knn,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_knn.summary_text())
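
The interaction structure [D·ψ(Z), (1-D)·ψ(Z)] described in this example can be sketched in a few lines of NumPy. This is an illustration of the feature layout only, assuming D sits in the first column of X; `interact_with_treatment` and `quadratic_psi` are hypothetical helpers, not part of genriesz.

```python
import numpy as np

def interact_with_treatment(X, psi):
    """Map X = [D, Z] to the feature block [D * psi(Z), (1 - D) * psi(Z)]."""
    D = X[:, [0]]          # keep 2-D so it broadcasts over feature columns
    P = psi(X[:, 1:])
    return np.hstack([D * P, (1.0 - D) * P])

def quadratic_psi(Z):
    """Hypothetical base basis: bias, linear terms, squared terms."""
    return np.hstack([np.ones((Z.shape[0], 1)), Z, Z ** 2])

X_demo = np.array([
    [1.0, 0.5, -1.0],   # treated unit
    [0.0, 2.0, 3.0],    # control unit
])
Phi = interact_with_treatment(X_demo, quadratic_psi)
print(Phi.shape)
print(Phi)
```

Each unit activates only the block for its own arm, so a linear model in these features fits separate coefficient vectors for treated and control units.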

Example 4: Random forest leaf basis (optional)

If you have scikit-learn installed, you can use a random forest as a feature map via leaf indicators. This keeps GRR convex while giving a flexible nonparametric basis.

[ ]:
from sklearn.ensemble import RandomForestRegressor
from genriesz.sklearn_basis import RandomForestLeafBasis

rf = RandomForestRegressor(
    n_estimators=200,
    max_depth=6,
    random_state=0,
)

leaf_basis = RandomForestLeafBasis(rf).fit(X, Y)

res_rf = grr_att(
    X=X,
    Y=Y,
    basis=leaf_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_rf.summary_text())
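
For readers unfamiliar with leaf-indicator features, the sketch below shows the generic construction with scikit-learn: `apply` returns the leaf index each sample reaches in each tree, and one-hot encoding those indices yields a sparse binary feature map. How `RandomForestLeafBasis` implements this internally is not shown in this notebook, so treat the code as an assumption about the idea rather than the library's code.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import OneHotEncoder

rng = np.random.default_rng(0)
Z = rng.normal(size=(500, 3))
y = Z[:, 0] + 0.25 * Z[:, 1] ** 2 + rng.normal(size=500)

n_trees = 20
rf = RandomForestRegressor(n_estimators=n_trees, max_depth=3, random_state=0)
rf.fit(Z, y)

# Leaf index reached by each sample in each tree: shape (n_samples, n_trees).
leaf_ids = rf.apply(Z)

# One indicator column per (tree, leaf) pair.
encoder = OneHotEncoder(handle_unknown="ignore").fit(leaf_ids)
Phi = encoder.transform(leaf_ids).toarray()
print(Phi.shape)
```

Once the forest is fitted, the feature map is fixed, so a model that is linear in these indicators can still be fitted by convex optimization.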

Example 5: Neural network embedding basis (optional)

If you have PyTorch installed, you can use a small MLP as a basis function. To keep the demo lightweight, the network below is left at its random initialization (it is not trained), so its outputs act as fixed random features.

[ ]:
import torch
from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis

torch.manual_seed(0)

net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)
nn_basis = TorchEmbeddingBasis(net, include_bias=True, device="cpu")

res_nn = grr_att(
    X=X,
    Y=Y,
    basis=nn_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_nn.summary_text())
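
To make concrete what a randomly initialized network contributes as a basis, here is a NumPy stand-in for a one-hidden-layer embedding with a bias column appended. It is a sketch only: the ReLU activation, the weight scaling, and the function name `random_mlp_embedding` are assumptions, not the architecture of `MLPEmbeddingNet`.

```python
import numpy as np

def random_mlp_embedding(X, hidden_dim=64, output_dim=32, seed=0):
    """Fixed random one-hidden-layer embedding with a leading bias column."""
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    # Untrained weights; scaled so activations stay of order one.
    W1 = rng.normal(scale=1.0 / np.sqrt(d), size=(d, hidden_dim))
    W2 = rng.normal(scale=1.0 / np.sqrt(hidden_dim), size=(hidden_dim, output_dim))
    H = np.maximum(X @ W1, 0.0)   # ReLU hidden layer (assumed activation)
    E = H @ W2
    return np.hstack([np.ones((X.shape[0], 1)), E])

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(5, 4))
features = random_mlp_embedding(X_demo)
print(features.shape)
```

Because the weights never change, the embedding is a deterministic feature map for a given seed, and only the linear coefficients on top of it are estimated.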

Generator / regularization sweep (SQ / UKL / BP)

We compare SQ-Riesz / UKL-Riesz / BP-Riesz under multiple regularization norms and strengths. For UKL/BP, we use a branch function that forces:

  • positive branch for treated units (\(D=1\)),

  • negative branch for control units (\(D=0\)),

which matches the sign structure of common treatment-effect Riesz representers.

We report RA / RW / ARW / TMLE and compare errors to the Monte Carlo “true” ATT.
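
The sign structure can be checked against the population Riesz representer for the ATT, \(\alpha(D, Z) = \frac{D}{p} - \frac{(1-D)\, e(Z)}{(1 - e(Z))\, p}\) with \(p = P(D=1)\) and \(e(Z) = P(D=1 \mid Z)\), which is positive for treated units and negative for controls. The sketch below evaluates this oracle representer on a fresh simulation with the same propensity and outcome equations as the notebook (restricted to the two covariates that matter) and confirms that averaging \(\alpha Y\) approximates the ATT. It uses the true propensity score, which is unavailable in practice.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
Z = rng.normal(size=(n, 2))
e = 1.0 / (1.0 + np.exp(-(0.7 * Z[:, 0] - 0.3 * Z[:, 1])))
D = rng.binomial(1, e)

tau = 1.0 + 0.5 * Z[:, 0]
mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2
Y = mu0 + D * tau + rng.normal(size=n)

# Oracle ATT Riesz representer: + for treated, - for controls.
p_treated = D.mean()
alpha = D / p_treated - (1 - D) * e / ((1.0 - e) * p_treated)

att_rw = float(np.mean(alpha * Y))       # pure weighting estimate
att_direct = float(tau[D == 1].mean())   # noise-free reference
print(att_rw, att_direct)
```

Generators with a restricted domain need to know which sign to expect for each unit, which is the role the branch function plays in the sweep.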

[6]:
# Branch: + for treated, - for control (D is the first column of X).
branch = lambda x: int(x[0] == 1.0)

generator_grid = [
    ("SQ", SquaredGenerator(C=0.0).as_generator()),
    ("UKL (C=1)", UKLGenerator(C=1.0, branch_fn=branch).as_generator()),
    ("BP (omega=0.1, C=1)", BPGenerator(C=1.0, omega=0.1, branch_fn=branch).as_generator()),
    ("BP (omega=0.2, C=1)", BPGenerator(C=1.0, omega=0.2, branch_fn=branch).as_generator()),
    ("BP (omega=0.5, C=1)", BPGenerator(C=1.0, omega=0.5, branch_fn=branch).as_generator()),
]

penalty_grid = [
    {"penalty": "l2", "lam": 1e-4, "p_norm": None},
    {"penalty": "l2", "lam": 1e-3, "p_norm": None},
    {"penalty": "l1", "lam": 1e-4, "p_norm": None},
    {"penalty": "lp", "lam": 1e-3, "p_norm": 1.5},
]

rows = []
for gname, gen_i in generator_grid:
    for cfg in penalty_grid:
        res_i = grr_att(
            X=X,
            Y=Y,
            basis=phi,
            generator=gen_i,
            cross_fit=True,
            folds=3,
            random_state=0,
            estimators=("ra", "rw", "arw", "tmle"),
            outcome_models="shared",
            outcome_link="identity",  # Y is unbounded
            riesz_penalty=cfg["penalty"],
            riesz_lam=cfg["lam"],
            riesz_p_norm=cfg.get("p_norm"),
            max_iter=250,
            tol=1e-8,
        )

        row = {
            "generator": gname,
            "penalty": cfg["penalty"],
            "lam": cfg["lam"],
        }

        for k in ("ra", "rw", "arw", "tmle"):
            e = res_i.estimates[k]
            row[f"{k}"] = e.estimate
            row[f"{k}_se"] = e.se
            row[f"{k}_err"] = e.estimate - true_att

        rows.append(row)

import pandas as pd

df = pd.DataFrame(rows)
df = df.sort_values(by="arw_err", key=lambda s: np.abs(s))
display(df)

    generator            penalty     lam        ra     ra_se     ra_err              rw         rw_se          rw_err             arw        arw_se         arw_err          tmle       tmle_se       tmle_err
11  BP (omega=0.1, C=1)  l1       0.0010  1.091298  0.018015  -0.066818    1.055435e+00  6.848838e-02   -1.026806e-01    1.131193e+00  4.746753e-02   -2.692224e-02  9.669026e+50  9.669025e+50   9.669026e+50
 8  BP (omega=0.1, C=1)  l2       0.0001  1.091298  0.018015  -0.066818    9.667499e-01  1.490316e-01   -1.913653e-01    1.121730e+00  1.250208e-01   -3.638561e-02  1.093675e+00  1.223439e-01  -6.444051e-02
12  BP (omega=0.2, C=1)  l2       0.0001  1.091298  0.018015  -0.066818    9.323105e-01  1.239197e-01   -2.258047e-01    1.103421e+00  9.391140e-02   -5.469465e-02  1.091852e+00  9.016083e-02  -6.626373e-02
16  BP (omega=0.5, C=1)  l2       0.0001  1.091298  0.018015  -0.066818    9.934959e-01  8.863373e-02   -1.646193e-01    1.098363e+00  7.062806e-02   -5.975265e-02  1.093084e+00  7.047579e-02  -6.503108e-02
17  BP (omega=0.5, C=1)  l2       0.0010  1.091298  0.018015  -0.066818    1.115423e+00  5.939799e-02   -4.269203e-02    1.090700e+00  4.538244e-02   -6.741474e-02  1.090902e+00  4.537922e-02  -6.721304e-02
 5  UKL (C=1)            l2       0.0010  1.091298  0.018015  -0.066818    9.929128e-01  8.730633e-02   -1.652024e-01    1.082664e+00  6.354118e-02   -7.545118e-02  1.088462e+00  6.352715e-02  -6.965350e-02
 3  SQ                   l1       0.0010  1.091298  0.018015  -0.066818    1.074762e+00  5.261753e-02   -8.335342e-02    1.077929e+00  3.807590e-02   -8.018587e-02  1.078362e+00  3.795229e-02  -7.975364e-02
13  BP (omega=0.2, C=1)  l2       0.0010  1.091298  0.018015  -0.066818    1.079259e+00  6.341705e-02   -7.885605e-02    1.077821e+00  4.803699e-02   -8.029404e-02  1.083387e+00  4.796557e-02  -7.472860e-02
 2  SQ                   l1       0.0001  1.091298  0.018015  -0.066818    1.075152e+00  5.269233e-02   -8.296282e-02    1.077776e+00  3.812085e-02   -8.033969e-02  1.078236e+00  3.799601e-02  -7.987879e-02
 1  SQ                   l2       0.0010  1.091298  0.018015  -0.066818    1.073219e+00  5.256672e-02   -8.489596e-02    1.077763e+00  3.800344e-02   -8.035199e-02  1.078169e+00  3.787773e-02  -7.994649e-02
 0  SQ                   l2       0.0001  1.091298  0.018015  -0.066818    1.075010e+00  5.268673e-02   -8.310482e-02    1.077762e+00  3.811353e-02   -8.035305e-02  1.078220e+00  3.798852e-02  -7.989492e-02
 9  BP (omega=0.1, C=1)  l2       0.0010  1.091298  0.018015  -0.066818    1.133463e+00  6.554599e-02   -2.465228e-02    1.077047e+00  4.953755e-02   -8.106850e-02  1.083433e+00  4.940757e-02  -7.468217e-02
 4  UKL (C=1)            l2       0.0001  1.091298  0.018015  -0.066818   -4.913791e-01  8.986323e-01   -1.649494e+00    1.272292e+00  6.737185e-01    1.141767e-01  1.094470e+00  6.736889e-01  -6.364497e-02
18  BP (omega=0.5, C=1)  l1       0.0001  1.091298  0.018015  -0.066818   -1.307005e+17  9.786285e+16   -1.307005e+17   -8.159928e+16  6.899664e+16   -8.159928e+16  1.088943e+00  1.309872e+16  -6.917228e-02
19  BP (omega=0.5, C=1)  l1       0.0010  1.091298  0.018015  -0.066818   -1.316292e+18  1.119353e+18   -1.316292e+18    7.034867e+17  7.937043e+17    7.034867e+17  1.091431e+00  1.548131e+17  -6.668454e-02
14  BP (omega=0.2, C=1)  l1       0.0001  1.091298  0.018015  -0.066818   -5.591832e+28  4.945093e+28   -5.591832e+28    4.957689e+27  8.465837e+27    4.957689e+27  1.091323e+00  4.258426e+27  -6.679187e-02
15  BP (omega=0.2, C=1)  l1       0.0010  1.091298  0.018015  -0.066818   -1.359638e+35  1.350624e+35   -1.359638e+35   -5.928117e+34  6.163495e+34   -5.928117e+34  1.091295e+00  3.255808e+33  -6.681974e-02
10  BP (omega=0.1, C=1)  l1       0.0001  1.091298  0.018015  -0.066818    5.327378e+53  5.327359e+53    5.327378e+53    6.551305e+53  6.551264e+53    6.551305e+53  1.567020e+01  5.749404e+48   1.451209e+01
 7  UKL (C=1)            l1       0.0010  1.091298  0.018015  -0.066818  -1.987935e+301           NaN  -1.987935e+301  -2.759111e+300           NaN  -2.759111e+300  1.091298e+00           NaN  -6.681756e-02
 6  UKL (C=1)            l1       0.0001  1.091298  0.018015  -0.066818  -2.664446e+301           NaN  -2.664446e+301  -1.057640e+301           NaN  -1.057640e+301  1.091298e+00           NaN  -6.681756e-02
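
Several configurations in the table above report estimates many orders of magnitude away from the true ATT, or NaN standard errors. One plausible reading, which the notebook itself does not state, is that the Riesz fit was numerically unstable for those settings. A simple way to separate such rows before ranking by error is sketched below on three rows copied from the table (ARW columns only). The cutoff `MAX_PLAUSIBLE_SE` is a hypothetical choice for this illustration, not a library default.

```python
import numpy as np
import pandas as pd

# Three rows copied from the sweep table (indices 1, 18, and 7).
demo = pd.DataFrame({
    "generator": ["SQ", "BP (omega=0.5, C=1)", "UKL (C=1)"],
    "penalty": ["l2", "l1", "l1"],
    "lam": [1e-3, 1e-4, 1e-3],
    "arw": [1.077763e+00, -8.159928e+16, -2.759111e+300],
    "arw_se": [3.800344e-02, 6.899664e+16, np.nan],
})

MAX_PLAUSIBLE_SE = 1.0  # hypothetical cutoff for this illustration

# A row is kept only if both numbers are finite and the SE is moderate.
finite = np.isfinite(demo["arw"]) & np.isfinite(demo["arw_se"])
stable = finite & (demo["arw_se"] < MAX_PLAUSIBLE_SE)

print(demo.loc[stable, ["generator", "penalty", "lam", "arw", "arw_se"]])
```

Applying the same mask to `df` before sorting by `arw_err` keeps numerically unstable fits out of the ranking.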