ATT simulation with (approximate) true value (genriesz)
This notebook demonstrates ATT (average treatment effect on the treated) estimation with genriesz.
We generate a synthetic population with heterogeneous treatment effects, so in general ATT ≠ ATE. We compute an approximate “true” ATT by Monte Carlo from a large simulated population, and compare it to GRR-based estimators.
We assume the regressor has the form \(X = [D, Z]\), where the first column \(D\) is a binary treatment indicator and \(Z\) holds the covariates.
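For reference, the estimands and the RA / RW / ARW estimators reported below can be written in their standard forms. This sketch assumes genriesz follows these textbook definitions; its exact normalizations may differ. Here \(\hat\pi\) is the treated share, \(\hat\gamma(d, z)\) is an outcome regression, and \(\hat\alpha\) is a fitted Riesz representer:

```latex
\begin{aligned}
\theta_{\mathrm{ATT}} &= \mathbb{E}[Y(1) - Y(0) \mid D = 1], \qquad
\theta_{\mathrm{ATE}} = \mathbb{E}[Y(1) - Y(0)], \\
\alpha_0(D, Z) &= \frac{D}{\pi} - \frac{(1 - D)\, e(Z)}{\pi\,(1 - e(Z))},
  \qquad \pi = P(D = 1),\ e(Z) = P(D = 1 \mid Z), \\
\mathbb{E}[\alpha_0(X)\,\gamma(X)] &= \frac{1}{\pi}\,\mathbb{E}\bigl[D\,\{\gamma(1, Z) - \gamma(0, Z)\}\bigr]
  \quad \text{for every } \gamma, \\
\hat\theta_{\mathrm{RA}} &= \frac{1}{n}\sum_{i=1}^{n} \frac{D_i}{\hat\pi}\,\{\hat\gamma(1, Z_i) - \hat\gamma(0, Z_i)\}, \\
\hat\theta_{\mathrm{RW}} &= \frac{1}{n}\sum_{i=1}^{n} \hat\alpha(X_i)\, Y_i, \\
\hat\theta_{\mathrm{ARW}} &= \hat\theta_{\mathrm{RA}} + \frac{1}{n}\sum_{i=1}^{n} \hat\alpha(X_i)\,\{Y_i - \hat\gamma(X_i)\}.
\end{aligned}
```

In its usual form, TMLE instead updates \(\hat\gamma\) along \(\hat\alpha\) before plugging in.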
[1]:
import numpy as np
from genriesz import (
grr_att,
SquaredGenerator,
UKLGenerator,
BPGenerator,
PolynomialBasis,
TreatmentInteractionBasis,
RBFRandomFourierBasis,
KNNCatchmentBasis,
)
rng = np.random.default_rng(0)
Data generating process
[2]:
def draw_population(n: int, d_z: int, seed: int = 0):
    rng = np.random.default_rng(seed)
    Z = rng.normal(size=(n, d_z))
    # Propensity score e(Z) depends on the first two covariates
    logits = 0.7 * Z[:, 0] - 0.3 * Z[:, 1]
    e = 1.0 / (1.0 + np.exp(-logits))
    D = rng.binomial(1, e, size=n).astype(int)
    # Heterogeneous treatment effect
    tau = 1.0 + 0.5 * Z[:, 0]
    mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2
    # Potential outcomes with independent unit-variance noise
    Y0 = mu0 + rng.normal(scale=1.0, size=n)
    Y1 = mu0 + tau + rng.normal(scale=1.0, size=n)
    Y = D * Y1 + (1 - D) * Y0
    # Regressor matrix X = [D, Z]
    X = np.column_stack([D.astype(float), Z])
    return X, Y, Y0, Y1, D, tau
# Large population for an approximate truth
X_pop, Y_pop, Y0_pop, Y1_pop, D_pop, tau_pop = draw_population(n=200_000, d_z=5, seed=1)
true_att = float(np.mean((Y1_pop - Y0_pop)[D_pop == 1]))
true_ate = float(np.mean(Y1_pop - Y0_pop))
print("Approx. true ATT (Monte Carlo):", true_att)
print("Approx. true ATE (Monte Carlo):", true_ate)
Approx. true ATT (Monte Carlo): 1.158115237841178
Approx. true ATE (Monte Carlo): 1.0038058584371061
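Because draw_population also returns the individual effect tau, the truth can be approximated without the outcome noise. The sketch below re-draws Z and D with the same seed and the same first two RNG calls as draw_population. It then uses two routes to the ATT:

- the tau average over the drawn treated units;
- the propensity-weighted identity \(\mathrm{ATT} = \mathbb{E}[e(Z)\tau(Z)] / \mathbb{E}[e(Z)]\).

The helper name att_ate_from_tau is hypothetical. The sketch also shows why ATT exceeds ATE here: the logit increases in \(Z_0\), so treated units have larger \(Z_0\) on average, and \(\tau\) increases in \(Z_0\).

```python
import numpy as np


def att_ate_from_tau(n: int = 200_000, d_z: int = 5, seed: int = 1):
    """Noise-free Monte Carlo ATT/ATE for the draw_population DGP (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # Same first two RNG calls as draw_population, so Z and D match its draws.
    Z = rng.normal(size=(n, d_z))
    logits = 0.7 * Z[:, 0] - 0.3 * Z[:, 1]
    e = 1.0 / (1.0 + np.exp(-logits))
    D = rng.binomial(1, e, size=n)
    tau = 1.0 + 0.5 * Z[:, 0]
    att_sample = float(tau[D == 1].mean())              # average effect among drawn treated units
    att_weighted = float(np.sum(e * tau) / np.sum(e))   # E[e(Z) tau(Z)] / E[e(Z)]
    ate = float(tau.mean())
    return att_sample, att_weighted, ate


att_sample, att_weighted, ate = att_ate_from_tau()
print(f"ATT (tau over treated):    {att_sample:.4f}")
print(f"ATT (propensity-weighted): {att_weighted:.4f}")
print(f"ATE (tau):                 {ate:.4f}")
```

With the same seed, att_sample should be close to the Monte Carlo ATT printed above. The two differ only by the averaged outcome noise.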
Example 1: Polynomial basis + treatment interactions
[3]:
# Sample a dataset from the same DGP
X, Y, Y0, Y1, D, tau = draw_population(n=5000, d_z=5, seed=0)
# Basis on Z, then interact with D (works well for treatment-effect functionals)
psi = PolynomialBasis(degree=2, include_bias=True)
phi = TreatmentInteractionBasis(base_basis=psi)
gen = SquaredGenerator(C=0.0).as_generator()
res = grr_att(
    X=X,
    Y=Y,
    basis=phi,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res.summary_text())
ATT estimates (n=5000)
alpha=0.05 | null=0.0
diagnostics: max_abs_smd_unweighted=0.6705657462527745, max_abs_smd_weighted=0.005400089251320891, ess_treated=2500.896209806344, ess_control=1558.1240998523388
Estimator Estimate SE CI p-value
---------------------------------------------------------------------------------
RA 1.09227 0.0179759 [ 1.05703, 1.1275] 0
RW 1.08225 0.0522539 [ 0.979834, 1.18467] 0
ARW 1.09178 0.0377112 [ 1.01786, 1.16569] 0
TMLE 1.09178 0.0377068 [ 1.01788, 1.16569] 0
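The diagnostics line reports standardized mean differences (SMD) before and after weighting, plus effective sample sizes (ESS). The sketch below uses common textbook definitions, which may differ in detail from genriesz's:

- Kish ESS: \((\sum w)^2/\sum w^2\);
- SMD: the difference in weighted group means divided by the pooled unweighted standard deviation.

The function names are hypothetical. For the ATT, treated units get weight 1 and controls get the odds \(e(Z)/(1-e(Z))\). Here the odds use the true propensity from the DGP rather than a fitted representer.

```python
import numpy as np


def kish_ess(w):
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    w = np.asarray(w, dtype=float)
    return float(w.sum() ** 2 / np.sum(w ** 2))


def smd(x, d, w=None):
    """Standardized mean difference of covariate x between D=1 and D=0.

    Group means are weighted by w (if given); the scale is the pooled unweighted SD.
    """
    x = np.asarray(x, dtype=float)
    d = np.asarray(d)
    w = np.ones_like(x) if w is None else np.asarray(w, dtype=float)
    t, c = d == 1, d == 0
    mean_t = np.average(x[t], weights=w[t])
    mean_c = np.average(x[c], weights=w[c])
    pooled_sd = np.sqrt((x[t].var() + x[c].var()) / 2.0)
    return float((mean_t - mean_c) / pooled_sd)


# Same propensity model as draw_population (only Z and D are needed here).
rng = np.random.default_rng(0)
Z = rng.normal(size=(200_000, 5))
e = 1.0 / (1.0 + np.exp(-(0.7 * Z[:, 0] - 0.3 * Z[:, 1])))
D = rng.binomial(1, e)
w_att = np.where(D == 1, 1.0, e / (1.0 - e))  # oracle ATT weights

smd_raw = smd(Z[:, 0], D)
smd_weighted = smd(Z[:, 0], D, w_att)
ess_control = kish_ess(w_att[D == 0])
print(f"SMD(Z0) unweighted={smd_raw:.3f}, weighted={smd_weighted:.3f}, ESS control={ess_control:.0f}")
```

The unweighted SMD of \(Z_0\) is large because treatment depends on \(Z_0\). The odds weights pull it toward zero at the cost of a control ESS below the number of controls.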
Example 2: RKHS basis (RBF random Fourier features)
This approximates an RBF kernel feature map using random Fourier features, then interacts the features with treatment.
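The standard random Fourier feature construction can be sketched in a few lines:

1. Draw frequencies \(\omega \sim N(0, \sigma^{-2} I)\) and phases \(b \sim U[0, 2\pi)\).
2. Map \(x \mapsto \sqrt{2/m}\cos(\omega^\top x + b)\).

Inner products of these features approximate the RBF kernel \(\exp(-\lVert x-y\rVert^2 / (2\sigma^2))\). This is a standalone numpy illustration with hypothetical helper names. How RBFRandomFourierBasis parameterizes sigma and standardization is an assumption here, not taken from its source.

```python
import numpy as np


def rff_features(X, n_features=500, sigma=1.0, seed=0):
    """Random Fourier features whose inner products approximate an RBF kernel."""
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    omega = rng.normal(scale=1.0 / sigma, size=(d, n_features))  # spectral samples
    b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)           # random phases
    return np.sqrt(2.0 / n_features) * np.cos(X @ omega + b)


def rbf_kernel(X, sigma=1.0):
    """Exact RBF Gram matrix exp(-||x - y||^2 / (2 sigma^2))."""
    sq = np.sum(X ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    return np.exp(-d2 / (2.0 * sigma ** 2))


X_demo = np.random.default_rng(1).normal(size=(5, 3))
Phi = rff_features(X_demo, n_features=20_000, sigma=1.5)
K_approx = Phi @ Phi.T
K_exact = rbf_kernel(X_demo, sigma=1.5)
print("max abs error:", np.abs(K_approx - K_exact).max())
```

The approximation error shrinks roughly like \(1/\sqrt{m}\) in the number of features \(m\). That is why the notebook uses several hundred features rather than a handful.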
[ ]:
psi_rff = RBFRandomFourierBasis(
    n_features=500,
    sigma=1.0,
    standardize=True,
    random_state=0,
)
phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)
res_rff = grr_att(
    X=X,
    Y=Y,
    basis=phi_rff,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_rff.summary_text())
Example 3: KNN catchment basis (nearest-neighbor matching)
This example treats nearest-neighbor matching as a special case of squared-loss Riesz regression. TreatmentInteractionBasis creates \([D\,\psi(Z), (1-D)\,\psi(Z)]\), so the standard NN-matching Riesz representer is recovered as a linear model in these features.
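To see why matching fits this framework, the sketch below computes the classic k-nearest-neighbor ATT matching estimator in two ways:

- directly, as the mean over treated units of \(Y_i\) minus the mean outcome of its k matched controls;
- as a weighted sum \(\sum_i w_i Y_i\), with \(w_i = 1/n_1\) for treated units and \(w_j = -K(j)/(k\,n_1)\) for controls, where \(K(j)\) counts how often control \(j\) is used as a match.

Scaling by \(n\), the weights become \(D/\hat\pi - (1-D)K(j)/(k\hat\pi)\). This is the sign structure of an ATT Riesz representer, with \(K(j)/k\) roughly in the role of the odds \(e(Z)/(1-e(Z))\). This is a brute-force numpy illustration with hypothetical helper names, not the KNNCatchmentBasis implementation.

```python
import numpy as np


def knn_matching_att(Z, D, Y, k=5):
    """Return (direct matching ATT, weighted-sum ATT, weights) for k-NN matching with replacement."""
    treated = np.flatnonzero(D == 1)
    control = np.flatnonzero(D == 0)
    n1 = treated.size
    # Brute-force squared distances from each treated unit to each control unit.
    d2 = ((Z[treated, None, :] - Z[None, control, :]) ** 2).sum(axis=2)
    nn = np.argsort(d2, axis=1)[:, :k]  # positions within `control`
    # Direct form: average over treated of Y_i minus the mean outcome of its k matches.
    att_direct = float(np.mean(Y[treated] - Y[control][nn].mean(axis=1)))
    # Weighted form: K(j) = number of times control j appears among the matches.
    counts = np.bincount(nn.ravel(), minlength=control.size)
    w = np.zeros(D.size)
    w[treated] = 1.0 / n1
    w[control] = -counts / (k * n1)
    return att_direct, float(w @ Y), w


rng = np.random.default_rng(0)
n = 400
Z = rng.normal(size=(n, 2))
D = rng.binomial(1, 1.0 / (1.0 + np.exp(-0.7 * Z[:, 0])))
Y = 0.5 * Z[:, 0] + D * (1.0 + 0.5 * Z[:, 0]) + rng.normal(size=n)
att_direct, att_weighted, w = knn_matching_att(Z, D, Y, k=5)
print(att_direct, att_weighted)
```

The two numbers agree exactly because the weighted form is just the matching estimator rearranged by control unit. The weights also sum to zero, as an ATT representer's sample average should.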
[ ]:
basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)
phi_knn = TreatmentInteractionBasis(base_basis=basis_knn)
res_knn = grr_att(
    X=X,
    Y=Y,
    basis=phi_knn,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_knn.summary_text())
Example 4: Random forest leaf basis (optional)
If you have scikit-learn installed, you can use a random forest as a feature map via leaf indicators. This keeps GRR convex while giving a flexible nonparametric basis.
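With any fixed feature map (leaf indicators included), the squared-loss version of Riesz regression is a ridge-type least-squares problem in the coefficients, which is why it is convex. The sketch below uses the standard squared-loss objective for the ATT:

- objective: \(\frac{1}{n}\sum_i \{\alpha_\beta(X_i)^2 - 2\,m(X_i, \alpha_\beta)\} + \lambda\lVert\beta\rVert^2\);
- moment: \(m(X, g) = D\,\{g(1,Z) - g(0,Z)\}/\hat\pi\);
- model: \(\alpha_\beta(d, z) = [d\,\psi(z), (1-d)\,\psi(z)]^\top\beta\).

Its minimizer solves \((\hat\Sigma + \lambda I)\beta = \bar m\). As a stand-in for leaf indicators, the sketch uses one-hot indicators of a discrete covariate \(S\). With \(\lambda = 0\), the RW estimate then equals the exact stratified ATT. This is a standalone illustration with hypothetical names, not genriesz's solver.

```python
import numpy as np


def squared_riesz_att(psi, D, lam=0.0):
    """Squared-loss Riesz regression for the ATT with basis [D*psi, (1-D)*psi].

    psi: (n, p) fixed features of Z (e.g. leaf or cell indicators).
    Returns the fitted representer values alpha_hat(X_i).
    """
    n, p = psi.shape
    pi_hat = D.mean()
    Dc = D[:, None].astype(float)
    phi = np.hstack([Dc * psi, (1.0 - Dc) * psi])       # phi(D_i, Z_i)
    m_phi = (Dc / pi_hat) * np.hstack([psi, -psi])      # D (phi(1,Z) - phi(0,Z)) / pi_hat
    sigma = phi.T @ phi / n
    m_bar = m_phi.mean(axis=0)
    beta = np.linalg.solve(sigma + lam * np.eye(2 * p), m_bar)  # first-order condition
    return phi @ beta


rng = np.random.default_rng(0)
n = 2000
S = rng.integers(0, 4, size=n)               # discrete "cell" (stand-in for a leaf)
psi = np.eye(4)[S]                           # one-hot cell indicators
D = rng.binomial(1, 0.2 + 0.15 * S)          # treatment probability varies by cell
Y = S + D * (1.0 + 0.5 * S) + rng.normal(size=n)

alpha_hat = squared_riesz_att(psi, D)
att_rw = float(np.mean(alpha_hat * Y))

# Exact stratified ATT: treated mean minus control cell means reweighted to the treated cell mix.
n1 = D.sum()
att_strat = Y[D == 1].mean() - sum(
    (np.sum((S == s) & (D == 1)) / n1) * Y[(S == s) & (D == 0)].mean() for s in range(4)
)
print(att_rw, att_strat)
```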
[ ]:
from sklearn.ensemble import RandomForestRegressor
from genriesz.sklearn_basis import RandomForestLeafBasis
rf = RandomForestRegressor(
    n_estimators=200,
    max_depth=6,
    random_state=0,
)
leaf_basis = RandomForestLeafBasis(rf).fit(X, Y)
res_rf = grr_att(
    X=X,
    Y=Y,
    basis=leaf_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_rf.summary_text())
Example 5: Neural network embedding basis (optional)
If you have PyTorch installed, you can use a small MLP as a basis function. Below, the network is used with its random initialization, which keeps the demo lightweight.
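An untrained network is simply a fixed random nonlinear feature map. The numpy sketch below mirrors the shapes in the cell: input_dim=X.shape[1], one hidden layer of 64 units, 32 outputs, and a leading bias column. Two details are assumptions, not necessarily what MLPEmbeddingNet and TorchEmbeddingBasis use: the ReLU activation and the He-style initialization scale. The helper name is hypothetical.

```python
import numpy as np


def random_mlp_embedding(X, hidden_dim=64, output_dim=32, include_bias=True, seed=0):
    """Embed X with an untrained one-hidden-layer ReLU MLP (fixed random weights)."""
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    W1 = rng.normal(scale=np.sqrt(2.0 / d), size=(d, hidden_dim))
    b1 = np.zeros(hidden_dim)
    W2 = rng.normal(scale=np.sqrt(2.0 / hidden_dim), size=(hidden_dim, output_dim))
    H = np.maximum(X @ W1 + b1, 0.0)   # hidden ReLU activations
    F = H @ W2                          # output embedding, shape (n, output_dim)
    if include_bias:
        F = np.hstack([np.ones((X.shape[0], 1)), F])
    return F


X_demo = np.random.default_rng(1).normal(size=(10, 6))  # 6 = 1 treatment column + 5 covariates
F1 = random_mlp_embedding(X_demo)
F2 = random_mlp_embedding(X_demo)
print(F1.shape)  # (10, 33)
```

Fixing the seed makes the feature map deterministic. This matters for cross-fitting, where the same basis should be applied to every fold.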
[ ]:
import torch
from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis
torch.manual_seed(0)
net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)
nn_basis = TorchEmbeddingBasis(net, include_bias=True, device="cpu")
res_nn = grr_att(
    X=X,
    Y=Y,
    basis=nn_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_nn.summary_text())
Generator / regularization sweep (SQ / UKL / BP)
We compare SQ-Riesz, UKL-Riesz, and BP-Riesz under several regularization norms and strengths. For UKL and BP, we use a branch function that forces the positive branch for treated units (\(D=1\)) and the negative branch for control units (\(D=0\)). This matches the sign structure of common treatment-effect Riesz representers.
We report RA / RW / ARW / TMLE and compare errors to the Monte Carlo “true” ATT.
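For this DGP the propensity score is known, so the oracle ATT Riesz representer \(\alpha_0(D, Z) = D/\pi - (1-D)\,e(Z)/\{\pi(1 - e(Z))\}\) can be evaluated directly. The sketch below is a standalone check that re-implements the propensity and outcome model of draw_population; it does not call genriesz. It confirms two things:

- the sign pattern the branch function encodes: positive for treated units, negative for controls;
- that \(\frac{1}{n}\sum_i \alpha_0(X_i) Y_i\) recovers the ATT.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 200_000
Z = rng.normal(size=(n, 5))
e = 1.0 / (1.0 + np.exp(-(0.7 * Z[:, 0] - 0.3 * Z[:, 1])))   # true propensity e(Z)
D = rng.binomial(1, e)
tau = 1.0 + 0.5 * Z[:, 0]
mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2
Y = mu0 + D * tau + rng.normal(size=n)  # observed outcome with N(0, 1) noise, as in draw_population

pi_hat = D.mean()
alpha0 = D / pi_hat - (1 - D) * e / (pi_hat * (1.0 - e))  # oracle ATT Riesz representer
att_rw_oracle = float(np.mean(alpha0 * Y))
att_tau = float(tau[D == 1].mean())
print(f"oracle RW: {att_rw_oracle:.3f}, ATT from tau: {att_tau:.3f}")
```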
[6]:
import pandas as pd

# Branch: + for treated, - for control (D is the first column of X).
branch = lambda x: int(x[0] == 1.0)
generator_grid = [
    ("SQ", SquaredGenerator(C=0.0).as_generator()),
    ("UKL (C=1)", UKLGenerator(C=1.0, branch_fn=branch).as_generator()),
    ("BP (omega=0.1, C=1)", BPGenerator(C=1.0, omega=0.1, branch_fn=branch).as_generator()),
    ("BP (omega=0.2, C=1)", BPGenerator(C=1.0, omega=0.2, branch_fn=branch).as_generator()),
    ("BP (omega=0.5, C=1)", BPGenerator(C=1.0, omega=0.5, branch_fn=branch).as_generator()),
]
penalty_grid = [
    {"penalty": "l2", "lam": 1e-4, "p_norm": None},
    {"penalty": "l2", "lam": 1e-3, "p_norm": None},
    {"penalty": "l1", "lam": 1e-4, "p_norm": None},
    # The output table below was generated with {"penalty": "l1", "lam": 1e-3}
    # in this slot, so it contains no lp rows.
    {"penalty": "lp", "lam": 1e-3, "p_norm": 1.5},
]
rows = []
for gname, gen_i in generator_grid:
    for cfg in penalty_grid:
        res_i = grr_att(
            X=X,
            Y=Y,
            basis=phi,
            generator=gen_i,
            cross_fit=True,
            folds=3,
            random_state=0,
            estimators=("ra", "rw", "arw", "tmle"),
            outcome_models="shared",
            outcome_link="identity",  # Y is unbounded
            riesz_penalty=cfg["penalty"],
            riesz_lam=cfg["lam"],
            riesz_p_norm=cfg.get("p_norm"),
            max_iter=250,
            tol=1e-8,
        )
        row = {
            "generator": gname,
            "penalty": cfg["penalty"],
            "lam": cfg["lam"],
        }
        for k in ("ra", "rw", "arw", "tmle"):
            e = res_i.estimates[k]
            row[k] = e.estimate
            row[f"{k}_se"] = e.se
            row[f"{k}_err"] = e.estimate - true_att
        rows.append(row)
df = pd.DataFrame(rows)
# Rank configurations by absolute ARW error against the Monte Carlo ATT
df = df.sort_values(by="arw_err", key=lambda s: np.abs(s))
display(df)  # Jupyter/IPython rich display
|  | generator | penalty | lam | ra | ra_se | ra_err | rw | rw_se | rw_err | arw | arw_se | arw_err | tmle | tmle_se | tmle_err |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 11 | BP (omega=0.1, C=1) | l1 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.055435e+00 | 6.848838e-02 | -1.026806e-01 | 1.131193e+00 | 4.746753e-02 | -2.692224e-02 | 9.669026e+50 | 9.669025e+50 | 9.669026e+50 |
| 8 | BP (omega=0.1, C=1) | l2 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 9.667499e-01 | 1.490316e-01 | -1.913653e-01 | 1.121730e+00 | 1.250208e-01 | -3.638561e-02 | 1.093675e+00 | 1.223439e-01 | -6.444051e-02 |
| 12 | BP (omega=0.2, C=1) | l2 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 9.323105e-01 | 1.239197e-01 | -2.258047e-01 | 1.103421e+00 | 9.391140e-02 | -5.469465e-02 | 1.091852e+00 | 9.016083e-02 | -6.626373e-02 |
| 16 | BP (omega=0.5, C=1) | l2 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 9.934959e-01 | 8.863373e-02 | -1.646193e-01 | 1.098363e+00 | 7.062806e-02 | -5.975265e-02 | 1.093084e+00 | 7.047579e-02 | -6.503108e-02 |
| 17 | BP (omega=0.5, C=1) | l2 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.115423e+00 | 5.939799e-02 | -4.269203e-02 | 1.090700e+00 | 4.538244e-02 | -6.741474e-02 | 1.090902e+00 | 4.537922e-02 | -6.721304e-02 |
| 5 | UKL (C=1) | l2 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 9.929128e-01 | 8.730633e-02 | -1.652024e-01 | 1.082664e+00 | 6.354118e-02 | -7.545118e-02 | 1.088462e+00 | 6.352715e-02 | -6.965350e-02 |
| 3 | SQ | l1 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.074762e+00 | 5.261753e-02 | -8.335342e-02 | 1.077929e+00 | 3.807590e-02 | -8.018587e-02 | 1.078362e+00 | 3.795229e-02 | -7.975364e-02 |
| 13 | BP (omega=0.2, C=1) | l2 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.079259e+00 | 6.341705e-02 | -7.885605e-02 | 1.077821e+00 | 4.803699e-02 | -8.029404e-02 | 1.083387e+00 | 4.796557e-02 | -7.472860e-02 |
| 2 | SQ | l1 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 1.075152e+00 | 5.269233e-02 | -8.296282e-02 | 1.077776e+00 | 3.812085e-02 | -8.033969e-02 | 1.078236e+00 | 3.799601e-02 | -7.987879e-02 |
| 1 | SQ | l2 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.073219e+00 | 5.256672e-02 | -8.489596e-02 | 1.077763e+00 | 3.800344e-02 | -8.035199e-02 | 1.078169e+00 | 3.787773e-02 | -7.994649e-02 |
| 0 | SQ | l2 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 1.075010e+00 | 5.268673e-02 | -8.310482e-02 | 1.077762e+00 | 3.811353e-02 | -8.035305e-02 | 1.078220e+00 | 3.798852e-02 | -7.989492e-02 |
| 9 | BP (omega=0.1, C=1) | l2 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.133463e+00 | 6.554599e-02 | -2.465228e-02 | 1.077047e+00 | 4.953755e-02 | -8.106850e-02 | 1.083433e+00 | 4.940757e-02 | -7.468217e-02 |
| 4 | UKL (C=1) | l2 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | -4.913791e-01 | 8.986323e-01 | -1.649494e+00 | 1.272292e+00 | 6.737185e-01 | 1.141767e-01 | 1.094470e+00 | 6.736889e-01 | -6.364497e-02 |
| 18 | BP (omega=0.5, C=1) | l1 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | -1.307005e+17 | 9.786285e+16 | -1.307005e+17 | -8.159928e+16 | 6.899664e+16 | -8.159928e+16 | 1.088943e+00 | 1.309872e+16 | -6.917228e-02 |
| 19 | BP (omega=0.5, C=1) | l1 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | -1.316292e+18 | 1.119353e+18 | -1.316292e+18 | 7.034867e+17 | 7.937043e+17 | 7.034867e+17 | 1.091431e+00 | 1.548131e+17 | -6.668454e-02 |
| 14 | BP (omega=0.2, C=1) | l1 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | -5.591832e+28 | 4.945093e+28 | -5.591832e+28 | 4.957689e+27 | 8.465837e+27 | 4.957689e+27 | 1.091323e+00 | 4.258426e+27 | -6.679187e-02 |
| 15 | BP (omega=0.2, C=1) | l1 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | -1.359638e+35 | 1.350624e+35 | -1.359638e+35 | -5.928117e+34 | 6.163495e+34 | -5.928117e+34 | 1.091295e+00 | 3.255808e+33 | -6.681974e-02 |
| 10 | BP (omega=0.1, C=1) | l1 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 5.327378e+53 | 5.327359e+53 | 5.327378e+53 | 6.551305e+53 | 6.551264e+53 | 6.551305e+53 | 1.567020e+01 | 5.749404e+48 | 1.451209e+01 |
| 7 | UKL (C=1) | l1 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | -1.987935e+301 | NaN | -1.987935e+301 | -2.759111e+300 | NaN | -2.759111e+300 | 1.091298e+00 | NaN | -6.681756e-02 |
| 6 | UKL (C=1) | l1 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | -2.664446e+301 | NaN | -2.664446e+301 | -1.057640e+301 | NaN | -1.057640e+301 | 1.091298e+00 | NaN | -6.681756e-02 |
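Several rows in this table are numerically unstable rather than merely inaccurate:

- Except for the top row, every UKL and BP run with the l1 penalty has RW and ARW estimates of order 1e16 or larger in magnitude. For UKL they reach about 1e301, with NaN standard errors.
- The top row (BP, omega=0.1, l1, lam=1e-3) is ranked first by |arw_err|, yet its TMLE estimate is about 9.7e50.
- By contrast, the SQ rows are stable across all four penalty settings, with ARW errors of about -0.08.

A small filter makes such failures explicit before ranking. The helper below is hypothetical, and the threshold max_abs_err=1.0 is an arbitrary choice for this DGP. The example rows are copied (rounded) from the table above.

```python
import math


def flag_unstable(rows, estimators=("ra", "rw", "arw", "tmle"), max_abs_err=1.0):
    """Return (generator, penalty, lam, estimator) for results that look numerically unstable.

    A result is flagged if its estimate or SE is non-finite, or its error exceeds max_abs_err.
    """
    flagged = []
    for row in rows:
        for k in estimators:
            est, se, err = row[k], row[f"{k}_se"], row[f"{k}_err"]
            if not (math.isfinite(est) and math.isfinite(se)) or abs(err) > max_abs_err:
                flagged.append((row["generator"], row["penalty"], row["lam"], k))
    return flagged


example_rows = [
    {"generator": "SQ", "penalty": "l2", "lam": 1e-4,
     "ra": 1.0913, "ra_se": 0.0180, "ra_err": -0.0668,
     "rw": 1.0750, "rw_se": 0.0527, "rw_err": -0.0831,
     "arw": 1.0778, "arw_se": 0.0381, "arw_err": -0.0804,
     "tmle": 1.0782, "tmle_se": 0.0380, "tmle_err": -0.0799},
    {"generator": "BP (omega=0.1, C=1)", "penalty": "l1", "lam": 1e-3,
     "ra": 1.0913, "ra_se": 0.0180, "ra_err": -0.0668,
     "rw": 1.0554, "rw_se": 0.0685, "rw_err": -0.1027,
     "arw": 1.1312, "arw_se": 0.0475, "arw_err": -0.0269,
     "tmle": 9.669e50, "tmle_se": 9.669e50, "tmle_err": 9.669e50},
]
print(flag_unstable(example_rows))
```

In the notebook, calling flag_unstable(rows) after the sweep would list the diverged configurations so they can be excluded before sorting df.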