{ "cells": [ { "cell_type": "markdown", "id": "f58299f4", "metadata": {}, "source": [ "# ATE end-to-end examples (genriesz)\n", "\n", "This notebook demonstrates how to estimate the **Average Treatment Effect (ATE)** with **genriesz**.\n", "\n", "We assume the following data structure:\n", "\n", "- the regressor is $X = (D, Z)$, where $D$ is a **binary treatment indicator** ($0$ or $1$) and $Z$ is a vector of covariates,\n", "- $Y$ is the observed outcome.\n", "\n", "We will compute the following estimators (optionally with cross-fitting):\n", "\n", "- **RA**: regression adjustment (plug-in)\n", "- **RW**: Riesz weighting (weighting only)\n", "- **ARW**: augmented Riesz weighting\n", "- **TMLE**: targeted maximum likelihood estimation (one-step fluctuation)\n", "\n", "We also show how to swap the **basis**:\n", "\n", "- polynomial features,\n", "- RKHS features (RBF random Fourier features),\n", "- a KNN catchment basis (nearest-neighbor matching),\n", "- random forest leaf features (optional),\n", "- neural network embeddings (optional).\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "71628cec", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from genriesz import (\n", " grr_ate,\n", " SquaredGenerator,\n", " UKLGenerator,\n", " PolynomialBasis,\n", " TreatmentInteractionBasis,\n", " RBFRandomFourierBasis,\n", " KNNCatchmentBasis,\n", ")\n", "\n", "rng = np.random.default_rng(0)\n" ] }, { "cell_type": "markdown", "id": "041e5044", "metadata": {}, "source": [ "## Synthetic data" ] }, { "cell_type": "code", "execution_count": 2, "id": "f744ecd2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X shape: (3000, 6) Y shape: (3000,)\n" ] } ], "source": [ "# Data-generating process\n", "n = 3000\n", "d_z = 5\n", "\n", "Z = rng.normal(size=(n, d_z))\n", "\n", "# Treatment assignment: e(Z) = sigmoid(a'Z)\n", "logits = 0.7 * Z[:, 0] - 0.3 * Z[:, 1]\n", "e = 1.0 / (1.0 + np.exp(-logits))\n", "D = rng.binomial(1, e, size=n).astype(float)\n", "\n", "# Potential outcomes (constant effect for 
simplicity)\n", "tau = 1.0\n", "mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2\n", "Y0 = mu0 + rng.normal(scale=1.0, size=n)\n", "Y1 = mu0 + tau + rng.normal(scale=1.0, size=n)\n", "\n", "Y = D * Y1 + (1.0 - D) * Y0\n", "\n", "# Regressor matrix X = [D, Z...]\n", "X = np.column_stack([D, Z])\n", "\n", "print(\"X shape:\", X.shape, \"Y shape:\", Y.shape)\n" ] }, { "cell_type": "markdown", "id": "1277ad0c", "metadata": {}, "source": [ "## Example 1: Polynomial basis + treatment interactions" ] }, { "cell_type": "code", "execution_count": 3, "id": "31a5fe31", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ATE estimates (n=3000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: alpha_abs_mean=2.0164751561831986, alpha_abs_p95=3.844795773570575, alpha_abs_max=9.334972319665178, max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.01662007897774767, ess_treated=1188.4096325793182, ess_control=1260.0346792242883\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 1.03046 0.00359346 [ 1.02342, 1.03751] 0\n", "RW 1.03302 0.0569279 [ 0.921445, 1.1446] 0\n", "ARW 1.02916 0.0419132 [ 0.947017, 1.11131] 0\n", "TMLE 1.02921 0.0419128 [ 0.947063, 1.11136] 0\n" ] } ], "source": [ "# Basis on Z, then interact with D (ATE-friendly)\n", "psi = PolynomialBasis(degree=2, include_bias=True)\n", "phi = TreatmentInteractionBasis(base_basis=psi)\n", "\n", "# Generator: Squared loss (always safe / no domain constraints)\n", "gen = SquaredGenerator(C=0.0).as_generator()\n", "\n", "res_poly = grr_ate(\n", " X=X,\n", " Y=Y,\n", " basis=phi,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_poly.summary_text())\n" ] }, { "cell_type": 
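"markdown", "id": "sq-riesz-sketch", "metadata": {}, "source": "To make the Riesz step concrete, the following sketch implements squared-loss Riesz regression for the ATE in plain NumPy. It is a minimal illustration under stated assumptions, not the genriesz implementation: the helpers `psi`, `interact`, and `fit_sq_riesz` are hypothetical, `psi` is a hand-rolled quadratic basis (intercept, linear and squared terms of Z, no cross terms) that may differ from `PolynomialBasis`, and there is no cross-fitting. For the ATE the moment is m(X, g) = g(1, Z) - g(0, Z). With a linear model alpha(X) = phi(X)'beta, minimizing mean(alpha^2) - 2 mean(m(X, alpha)) + lam ||beta||^2 has the closed form beta = (mean(phi phi') + lam I)^(-1) mean(phi(1, Z) - phi(0, Z)).

```python
import numpy as np

# Regenerate the synthetic data with the same seed and code (true ATE = 1.0)
rng = np.random.default_rng(0)
n, d_z = 3000, 5
Z = rng.normal(size=(n, d_z))
logits = 0.7 * Z[:, 0] - 0.3 * Z[:, 1]
e = 1.0 / (1.0 + np.exp(-logits))
D = rng.binomial(1, e, size=n).astype(float)
mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2
Y0 = mu0 + rng.normal(scale=1.0, size=n)
Y1 = mu0 + 1.0 + rng.normal(scale=1.0, size=n)
Y = D * Y1 + (1.0 - D) * Y0


def psi(Z):
    # Hypothetical quadratic basis: intercept, linear terms, squared terms (no cross terms)
    return np.column_stack([np.ones(len(Z)), Z, Z ** 2])


def interact(d, Z):
    # Hypothetical treatment-interaction basis [d * psi(Z), (1 - d) * psi(Z)]
    P = psi(Z)
    return np.column_stack([d[:, None] * P, (1.0 - d)[:, None] * P])


def fit_sq_riesz(D, Z, lam=1e-3):
    # Squared-loss Riesz regression with alpha(X) = phi(X) @ beta:
    # minimize mean(alpha^2) - 2 * mean(alpha(1, Z) - alpha(0, Z)) + lam * ||beta||^2
    phi = interact(D, Z)
    ones, zeros = np.ones(len(Z)), np.zeros(len(Z))
    m_phi = interact(ones, Z) - interact(zeros, Z)  # ATE moment applied to each basis function
    G = phi.T @ phi / len(Z)
    b = m_phi.mean(axis=0)
    beta = np.linalg.solve(G + lam * np.eye(G.shape[0]), b)
    return phi @ beta


alpha = fit_sq_riesz(D, Z)

# Outcome regression on the same interacted basis (ordinary least squares)
phi = interact(D, Z)
gamma, *_ = np.linalg.lstsq(phi, Y, rcond=None)
g1 = interact(np.ones(n), Z) @ gamma
g0 = interact(np.zeros(n), Z) @ gamma
g_obs = phi @ gamma

ra = np.mean(g1 - g0)                          # regression adjustment (plug-in)
rw = np.mean(alpha * Y)                        # Riesz weighting
arw = np.mean(g1 - g0 + alpha * (Y - g_obs))   # augmented Riesz weighting
print(f'RA={ra:.3f}  RW={rw:.3f}  ARW={arw:.3f}')
```

Because the outcome regression and the Riesz model share one linear basis in this sketch, ARW matches RA up to floating-point error (the OLS residuals are orthogonal to every basis function, hence to alpha), and RW matches RA exactly when lam = 0. The cross-fitted genriesz estimates in Example 1 need not coincide in this way." }, { "cell_type": 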
"markdown", "id": "c848c3df", "metadata": {}, "source": [ "## Example 2: RKHS basis (RBF random Fourier features)\n", "\n", "This approximates an RBF kernel feature map using random Fourier features, then\n", "interacts the features with treatment.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "e0c9683d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ATE estimates (n=3000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: alpha_abs_mean=2.077446519847265, alpha_abs_p95=3.353495372876411, alpha_abs_max=4.936266573928845, max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.27862992844919376, ess_treated=1363.470230442599, ess_control=1370.9575124421833\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 1.15729 0.00717955 [ 1.14321, 1.17136] 0\n", "RW 1.22541 0.0540494 [ 1.11947, 1.33134] 0\n", "ARW 1.12627 0.0423592 [ 1.04325, 1.20929] 0\n", "TMLE 1.12732 0.0423287 [ 1.04436, 1.21028] 0\n" ] } ], "source": [ "psi_rff = RBFRandomFourierBasis(\n", " n_features=500,\n", " sigma=1.0,\n", " standardize=True,\n", " random_state=0,\n", ")\n", "phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)\n", "\n", "res_rff = grr_ate(\n", " X=X,\n", " Y=Y,\n", " basis=phi_rff,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_rff.summary_text())\n" ] }, { "cell_type": "markdown", "id": "109b8823", "metadata": {}, "source": "## Example 3: KNN catchment basis (nearest-neighbor matching)\n\nNearest-neighbor matching can be expressed using a **catchment-area basis**\n\n$$\n\\phi_j(z) = \\mathbf{1}\\{j \\in \\mathrm{NN}_k(z)\\},\n$$\n\nwhich assigns each point to its $k$ nearest training centers.\nThis is a 
special case of squared-loss Riesz regression, so we can pass it to\n`grr_ate` exactly like any other basis.\n" }, { "cell_type": "code", "execution_count": null, "id": "e50f0e43", "metadata": {}, "outputs": [], "source": "# KNN catchment basis as a nearest-neighbor Riesz basis.\n#\n# phi_j(z) = 1{j in NN_k(z)} assigns each point to its k nearest training centers.\n# TreatmentInteractionBasis then creates [D*psi(Z), (1-D)*psi(Z)],\n# which gives the standard NN-matching Riesz representer as a linear model.\n#\n# With cross-fitting the training fold becomes the centers, so the feature\n# dimension p = n_train >> n_test. The dual (Woodbury) solve handles this\n# in O(n_test^3) instead of O(n_train^3).\n\nbasis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)\nphi_knn = TreatmentInteractionBasis(base_basis=basis_knn)\n\nres_knn = grr_ate(\n X=X,\n Y=Y,\n basis=phi_knn,\n generator=gen,\n cross_fit=True,\n folds=5,\n random_state=0,\n estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n outcome_models=\"shared\",\n riesz_penalty=\"l2\",\n riesz_lam=1e-3,\n max_iter=300,\n tol=1e-8,\n)\n\nprint(res_knn.summary_text())\n" }, { "cell_type": "markdown", "id": "e278c8e4", "metadata": {}, "source": [ "## Example 4: Random forest leaf basis (optional)\n", "\n", "If you have `scikit-learn` installed, you can use a random forest as a **feature map**\n", "via leaf indicators. 
This keeps GRR convex (linear in parameters) while giving a\n", "flexible nonparametric basis.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "a84820a2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ATE estimates (n=3000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.786895849078183, ess_treated=350.8795988622065, ess_control=507.56068177539834\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 1.03187 0.0199079 [ 0.992854, 1.07089] 0\n", "RW 1.69371 0.252221 [ 1.19937, 2.18806] 1.88e-11\n", "ARW 1.25382 0.215261 [ 0.831919, 1.67573] 5.72e-09\n", "TMLE 1.063 0.217126 [ 0.63744, 1.48856] 9.79e-07\n" ] } ], "source": [ "from sklearn.ensemble import RandomForestRegressor\n", "from genriesz.sklearn_basis import RandomForestLeafBasis\n", "\n", "rf = RandomForestRegressor(\n", " n_estimators=200,\n", " max_depth=6,\n", " random_state=0,\n", ")\n", "\n", "leaf_basis = RandomForestLeafBasis(rf, include_bias=True).fit(X, Y)\n", "\n", "res_rf = grr_ate(\n", " X=X,\n", " Y=Y,\n", " basis=leaf_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_rf.summary_text())" ] }, { "cell_type": "markdown", "id": "7111e1ae", "metadata": {}, "source": [ "## Example 5: Neural network embedding basis (optional)\n", "\n", "If you have PyTorch installed, you can use a neural network as a **basis function**.\n", "A recommended procedure is:\n", "\n", "1. train an embedding network on a separate task,\n", "2. wrap the trained network as a basis (`TorchEmbeddingBasis`),\n", "3. 
use its outputs as features in GRR.\n", "\n", "Below we show the mechanics with a small MLP (training is optional).\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "6ba7c964", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ATE estimates (n=3000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.10074303607722811, ess_treated=1234.797658219504, ess_control=1289.906386252168\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 1.0526 0.00233879 [ 1.04801, 1.05718] 0\n", "RW 1.02999 0.0497571 [ 0.932472, 1.12752] 0\n", "ARW 1.02254 0.0372324 [ 0.949568, 1.09552] 0\n", "TMLE 1.02144 0.0372205 [ 0.948489, 1.09439] 0\n" ] } ], "source": [ "import torch\n", "from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis\n", "\n", "torch.manual_seed(0)\n", "\n", "net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)\n", "# (Optional) Train net here on a separate task.\n", "# For a lightweight demo, we skip training and just use the random initialization.\n", "nn_basis = TorchEmbeddingBasis(net, include_bias=True, device=\"cpu\")\n", "\n", "res_nn = grr_ate(\n", " X=X,\n", " Y=Y,\n", " basis=nn_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_nn.summary_text())" ] }, { "cell_type": "markdown", "id": "57afc4f3", "metadata": {}, "source": [ "## Generator (SQ / UKL / BP)\n", "\n", "This section runs **SQ-Riesz**, **UKL-Riesz**, and **BP-Riesz** for the same\n", "polynomial interaction basis, and compares multiple regularization norms and\n", "strengths.\n", "\n", "We use a **branch function** for UKL/BP that 
forces:\n", "\n", "- positive branch for treated units (`D=1`),\n", "- negative branch for control units (`D=0`).\n", "\n", "All four estimators (**RA / RW / ARW / TMLE**) are reported." ] }, { "cell_type": "code", "execution_count": 6, "id": "fc758a37", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | generator | penalty | lam | p_norm | ra | ra_se | ra_err | rw | rw_se | rw_err | arw | arw_se | arw_err | tmle | tmle_se | tmle_err |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 11 | BP (omega=0.1, C=1) | l1 | 0.0010 | 1.5 | 1.027967 | 0.003809 | 0.027967 | 1.019822 | 0.059940 | 0.019822 | 1.039039 | 0.043709 | 0.039039 | 1.038188 | 0.043721 | 0.038188 |\n",
"| 15 | BP (omega=0.2, C=1) | l1 | 0.0010 | 1.5 | 1.027967 | 0.003809 | 0.027967 | 1.023228 | 0.059248 | 0.023228 | 1.039091 | 0.043380 | 0.039091 | 1.038355 | 0.043391 | 0.038355 |\n",
"| 14 | BP (omega=0.2, C=1) | l1 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.021144 | 0.059442 | 0.021144 | 1.039292 | 0.043518 | 0.039292 | 1.038505 | 0.043530 | 0.038505 |\n",
"| 13 | BP (omega=0.2, C=1) | l2 | 0.0010 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.022101 | 0.059411 | 0.022101 | 1.039302 | 0.043490 | 0.039302 | 1.038521 | 0.043502 | 0.038521 |\n",
"| 12 | BP (omega=0.2, C=1) | l2 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.021037 | 0.059462 | 0.021037 | 1.039308 | 0.043530 | 0.039308 | 1.038517 | 0.043542 | 0.038517 |\n",
"| 7 | UKL (C=1) | l1 | 0.0010 | 1.5 | 1.027967 | 0.003809 | 0.027967 | 1.015917 | 0.060709 | 0.015917 | 1.039322 | 0.044135 | 0.039322 | 1.038294 | 0.044140 | 0.038294 |\n",
"| 19 | BP (omega=0.5, C=1) | l1 | 0.0010 | 1.5 | 1.027967 | 0.003809 | 0.027967 | 1.029973 | 0.057821 | 0.029973 | 1.039372 | 0.042636 | 0.039372 | 1.038849 | 0.042643 | 0.038849 |\n",
"| 18 | BP (omega=0.5, C=1) | l1 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.027756 | 0.058034 | 0.027756 | 1.039449 | 0.042746 | 0.039449 | 1.038895 | 0.042754 | 0.038895 |\n",
"| 10 | BP (omega=0.1, C=1) | l1 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.018222 | 0.060047 | 0.018222 | 1.039456 | 0.043853 | 0.039456 | 1.038531 | 0.043864 | 0.038531 |\n",
"| 16 | BP (omega=0.5, C=1) | l2 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.027687 | 0.058050 | 0.027687 | 1.039457 | 0.042753 | 0.039457 | 1.038900 | 0.042761 | 0.038900 |\n",
"| 17 | BP (omega=0.5, C=1) | l2 | 0.0010 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.029078 | 0.057985 | 0.029078 | 1.039464 | 0.042703 | 0.039464 | 1.038918 | 0.042711 | 0.038918 |\n",
"| 9 | BP (omega=0.1, C=1) | l2 | 0.0010 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.019118 | 0.060007 | 0.019118 | 1.039485 | 0.043828 | 0.039485 | 1.038564 | 0.043839 | 0.038564 |\n",
"| 8 | BP (omega=0.1, C=1) | l2 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.018151 | 0.060053 | 0.018151 | 1.039499 | 0.043864 | 0.039499 | 1.038568 | 0.043875 | 0.038568 |\n",
"| 6 | UKL (C=1) | l1 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.014454 | 0.060770 | 0.014454 | 1.039918 | 0.044285 | 0.039918 | 1.038785 | 0.044288 | 0.038785 |\n",
"| 5 | UKL (C=1) | l2 | 0.0010 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.015266 | 0.060731 | 0.015266 | 1.039954 | 0.044264 | 0.039954 | 1.038824 | 0.044267 | 0.038824 |\n",
"| 4 | UKL (C=1) | l2 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.014382 | 0.060769 | 0.014382 | 1.039978 | 0.044297 | 0.039978 | 1.038835 | 0.044299 | 0.038835 |\n",
"| 3 | SQ | l1 | 0.0010 | 1.5 | 1.027967 | 0.003809 | 0.027967 | 1.033380 | 0.056470 | 0.033380 | 1.041214 | 0.041673 | 0.041214 | 1.040825 | 0.041676 | 0.040825 |\n",
"| 2 | SQ | l1 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.033103 | 0.056626 | 0.033103 | 1.041513 | 0.041770 | 0.041513 | 1.041085 | 0.041772 | 0.041085 |\n",
"| 0 | SQ | l2 | 0.0001 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.032923 | 0.056626 | 0.032923 | 1.041555 | 0.041760 | 0.041555 | 1.041128 | 0.041763 | 0.041128 |\n",
"| 1 | SQ | l2 | 0.0010 | NaN | 1.027967 | 0.003809 | 0.027967 | 1.031386 | 0.056477 | 0.031386 | 1.041641 | 0.041583 | 0.041641 | 1.041266 | 0.041585 | 0.041266 |\n", "