{ "cells": [ { "cell_type": "markdown", "id": "47bd7f90", "metadata": {}, "source": [ "# ATT simulation with (approximate) true value (genriesz)\n", "\n", "This notebook demonstrates **ATT** estimation:\n", "\n", "$$\n", "\\theta = \\mathbb{E}[Y(1)-Y(0) \\mid D=1].\n", "$$\n", "\n", "We generate a synthetic population with **heterogeneous treatment effects**, so\n", "in general **ATT \u2260 ATE**. We approximate the \"true\" ATT by Monte Carlo on a\n", "large simulated population and compare it with estimators based on generalized\n", "Riesz regression (GRR).\n", "\n", "We assume each row of the regressor matrix has the form $X = [D, Z]$, where\n", "$D \\in \\{0, 1\\}$ is the treatment indicator and $Z$ is a vector of covariates.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "0b460663", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from genriesz import (\n", " grr_att,\n", " SquaredGenerator,\n", " UKLGenerator,\n", " BPGenerator,\n", " PolynomialBasis,\n", " TreatmentInteractionBasis,\n", " RBFRandomFourierBasis,\n", " KNNCatchmentBasis,\n", ")\n", "\n", "rng = np.random.default_rng(0)" ] }, { "cell_type": "markdown", "id": "1a4d1944", "metadata": {}, "source": [ "## Data generating process\n", "\n", "The propensity score increases with the first covariate $Z_0$ (`Z[:, 0]` in the code),\n", "and so does the treatment effect $\\tau(Z) = 1 + 0.5\\,Z_0$. Treated units therefore have\n", "larger effects on average, so the ATT exceeds the ATE, which equals 1 by construction\n", "because $\\mathbb{E}[Z_0] = 0$.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "b9c6f8d6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Approx. true ATT (Monte Carlo): 1.158115237841178\n", "Approx. 
true ATE (Monte Carlo): 1.0038058584371061\n" ] } ], "source": [ "def draw_population(n: int, d_z: int, seed: int = 0):\n", " rng = np.random.default_rng(seed)\n", " Z = rng.normal(size=(n, d_z))\n", "\n", " logits = 0.7 * Z[:, 0] - 0.3 * Z[:, 1]\n", " e = 1.0 / (1.0 + np.exp(-logits))\n", " D = rng.binomial(1, e, size=n).astype(int)\n", "\n", " # Heterogeneous treatment effect\n", " tau = 1.0 + 0.5 * Z[:, 0]\n", " mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2\n", "\n", " Y0 = mu0 + rng.normal(scale=1.0, size=n)\n", " Y1 = mu0 + tau + rng.normal(scale=1.0, size=n)\n", " Y = D * Y1 + (1 - D) * Y0\n", "\n", " X = np.column_stack([D.astype(float), Z])\n", " return X, Y, Y0, Y1, D, tau\n", "\n", "# Large population for an approximate truth\n", "X_pop, Y_pop, Y0_pop, Y1_pop, D_pop, tau_pop = draw_population(n=200_000, d_z=5, seed=1)\n", "true_att = float(np.mean((Y1_pop - Y0_pop)[D_pop == 1]))\n", "true_ate = float(np.mean(Y1_pop - Y0_pop))\n", "\n", "print(\"Approx. true ATT (Monte Carlo):\", true_att)\n", "print(\"Approx. 
true ATE (Monte Carlo):\", true_ate)\n" ] }, { "cell_type": "markdown", "id": "21f52f87", "metadata": {}, "source": [ "## Example 1: Polynomial basis + treatment interactions" ] }, { "cell_type": "code", "execution_count": 3, "id": "841088ca", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ATT estimates (n=5000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: max_abs_smd_unweighted=0.6705657462527745, max_abs_smd_weighted=0.005400089251320891, ess_treated=2500.896209806344, ess_control=1558.1240998523388\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 1.09227 0.0179759 [ 1.05703, 1.1275] 0\n", "RW 1.08225 0.0522539 [ 0.979834, 1.18467] 0\n", "ARW 1.09178 0.0377112 [ 1.01786, 1.16569] 0\n", "TMLE 1.09178 0.0377068 [ 1.01788, 1.16569] 0\n" ] } ], "source": [ "# Sample a dataset from the same DGP\n", "X, Y, Y0, Y1, D, tau = draw_population(n=5000, d_z=5, seed=0)\n", "\n", "# Polynomial basis psi(Z), interacted with treatment as [D*psi(Z), (1-D)*psi(Z)],\n", "# so treated and control units get separate coefficients\n", "psi = PolynomialBasis(degree=2, include_bias=True)\n", "phi = TreatmentInteractionBasis(base_basis=psi)\n", "\n", "gen = SquaredGenerator(C=0.0).as_generator()\n", "\n", "res = grr_att(\n", " X=X,\n", " Y=Y,\n", " basis=phi,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res.summary_text())\n" ] }, { "cell_type": "markdown", "id": "f2607309-d1ed-9c58-4132-c24e62df-0e8", "metadata": {}, "source": [ "## Example 2: RKHS basis (RBF random Fourier features)\n", "\n", "This approximates the feature map of an RBF (Gaussian) kernel with random Fourier\n", "features and then interacts the features with treatment, as in Example 1." 
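, "\n", "\n", "For reference, the standard random Fourier feature construction approximates the Gaussian (RBF) kernel with bandwidth $\\sigma$ as\n", "\n", "$$\n", "k(z, z') = \\exp\\left(-\\frac{\\|z - z'\\|^2}{2\\sigma^2}\\right) \\approx \\varphi(z)^\\top \\varphi(z'),\n", "\\qquad\n", "\\varphi(z) = \\sqrt{2/m}\\,\\cos(W z + b),\n", "$$\n", "\n", "where the $m$ rows of $W$ are drawn i.i.d. from $N(0, \\sigma^{-2} I)$ and the entries of $b$ i.i.d. from $\\mathrm{Unif}[0, 2\\pi]$.\n", "We assume, without checking the `genriesz` source, that `n_features` plays the role of $m$ and `sigma` the role of $\\sigma$.\n", "Under that assumption, the interacted basis is $[D\\,\\varphi(Z), (1-D)\\,\\varphi(Z)]$.\n"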
] }, { "cell_type": "code", "execution_count": null, "id": "2787a9d9-2f22-49dc-4427-a073aaf6-4fd", "metadata": {}, "outputs": [], "source": [ "psi_rff = RBFRandomFourierBasis(\n", " n_features=500,\n", " sigma=1.0,\n", " standardize=True,\n", " random_state=0,\n", ")\n", "phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)\n", "\n", "res_rff = grr_att(\n", " X=X,\n", " Y=Y,\n", " basis=phi_rff,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_rff.summary_text())" ] }, { "cell_type": "markdown", "id": "88d99f52-437e-fb85-e106-60d4384c-225", "metadata": {}, "source": [ "## Example 3: KNN catchment basis (nearest-neighbor matching)\n", "\n", "Nearest-neighbor matching can be viewed as a special case of squared-loss Riesz regression.\n", "Here ``TreatmentInteractionBasis`` creates ``[D\u00b7\u03c8(Z), (1-D)\u00b7\u03c8(Z)]`` with \u03c8 the KNN\n", "catchment basis, so the standard NN-matching Riesz representer is recovered as a linear model." 
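, "\n", "\n", "For reference, under unconfoundedness and overlap the ATT can be written as $\\theta = \\mathbb{E}[\\alpha_0(D, Z)\\,Y]$,\n", "where $p = P(D=1)$, $e(Z) = P(D=1 \\mid Z)$ is the propensity score, and the Riesz representer is\n", "\n", "$$\n", "\\alpha_0(D, Z) = \\frac{D}{p} - \\frac{(1-D)\\,e(Z)}{p\\,(1-e(Z))}.\n", "$$\n", "\n", "The classical $M$-nearest-neighbor matching estimator of the ATT (matching with replacement, $n_1$ treated units out of $n$)\n", "is a weighting estimator of the same form:\n", "\n", "$$\n", "\\hat\\theta = \\frac{1}{n_1} \\sum_{i: D_i = 1} \\Big( Y_i - \\frac{1}{M} \\sum_{j \\in \\mathcal{J}_M(i)} Y_j \\Big) = \\frac{1}{n} \\sum_{i=1}^{n} \\hat\\alpha_i Y_i,\n", "$$\n", "\n", "$$\n", "\\hat\\alpha_i = \\begin{cases} n / n_1, & D_i = 1, \\\\ -(n / n_1)\\, K_M(i) / M, & D_i = 0, \\end{cases}\n", "$$\n", "\n", "where $\\mathcal{J}_M(i)$ is the set of the $M$ controls closest to treated unit $i$, and $K_M(i)$ is the number of treated units\n", "that use control $i$ as a match. Like $\\alpha_0$, these weights are positive for treated units and negative for controls.\n", "We assume, without checking the `genriesz` source, that `n_neighbors` plays the role of $M$.\n"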
] }, { "cell_type": "code", "execution_count": null, "id": "2d317509-d583-946b-f348-d35494fa-e9b", "metadata": {}, "outputs": [], "source": [ "basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)\n", "phi_knn = TreatmentInteractionBasis(base_basis=basis_knn)\n", "\n", "res_knn = grr_att(\n", " X=X,\n", " Y=Y,\n", " basis=phi_knn,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_knn.summary_text())" ] }, { "cell_type": "markdown", "id": "beef6c9d-1689-f5ca-2532-6e0b8cbf-5b5", "metadata": {}, "source": [ "## Example 4: Random forest leaf basis (optional)\n", "\n", "If you have ``scikit-learn`` installed, you can use a random forest as a **feature map**\n", "via leaf indicators. Once the forest is fit, the Riesz model is linear in these leaf-indicator\n", "features, so GRR stays convex while the basis itself is flexible and nonparametric." ] }, { "cell_type": "code", "execution_count": null, "id": "b01204fb-ef83-e95d-0483-0733d4b7-acc", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestRegressor\n", "from genriesz.sklearn_basis import RandomForestLeafBasis\n", "\n", "rf = RandomForestRegressor(\n", " n_estimators=200,\n", " max_depth=6,\n", " random_state=0,\n", ")\n", "\n", "# Fit the forest on (X, Y); its leaf-membership indicators define the basis\n", "leaf_basis = RandomForestLeafBasis(rf).fit(X, Y)\n", "\n", "res_rf = grr_att(\n", " X=X,\n", " Y=Y,\n", " basis=leaf_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_rf.summary_text())" ] }, { "cell_type": "markdown", "id": "e47c0730-9653-89a6-3a3b-7c51f190-dab", "metadata": {}, "source": [ "## Example 5: Neural network embedding basis (optional)\n", "\n", "If you have PyTorch 
installed, you can use the embedding of a small MLP as a **basis function**.\n", "For a lightweight demo, the network below is used at its random initialization rather than trained." ] }, { "cell_type": "code", "execution_count": null, "id": "5bc98a6b-0929-665d-3b0a-f51ca65f-629", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis\n", "\n", "torch.manual_seed(0)\n", "\n", "# Randomly initialized MLP (seeded above), used as an embedding basis\n", "net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)\n", "nn_basis = TorchEmbeddingBasis(net, include_bias=True, device=\"cpu\")\n", "\n", "res_nn = grr_att(\n", " X=X,\n", " Y=Y,\n", " basis=nn_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_nn.summary_text())" ] }, { "cell_type": "markdown", "id": "25cae85a", "metadata": {}, "source": [ "## Generator / regularization sweep (SQ / UKL / BP)\n", "\n", "We compare SQ-Riesz, UKL-Riesz, and BP-Riesz under L1 and L2 penalties at two\n", "regularization strengths. For **UKL/BP**, we use a **branch function** that selects:\n", "\n", "- the positive branch for treated units ($D=1$),\n", "- the negative branch for control units ($D=0$),\n", "\n", "which matches the sign structure of the ATT Riesz representer: positive for treated\n", "units and negative for control units.\n", "\n", "We report **RA / RW / ARW / TMLE** estimates with their standard errors (`*_se`) and\n", "their errors relative to the Monte Carlo \"true\" ATT (`*_err` = estimate minus truth)." ] }, { "cell_type": "code", "execution_count": 6, "id": "3291c25d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| | generator | penalty | lam | ra | ra_se | ra_err | rw | rw_se | rw_err | arw | arw_se | arw_err | tmle | tmle_se | tmle_err |\n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n", "
| 11 | BP (omega=0.1, C=1) | l1 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.055435e+00 | 6.848838e-02 | -1.026806e-01 | 1.131193e+00 | 4.746753e-02 | -2.692224e-02 | 9.669026e+50 | 9.669025e+50 | 9.669026e+50 |\n", "
| 8 | BP (omega=0.1, C=1) | l2 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 9.667499e-01 | 1.490316e-01 | -1.913653e-01 | 1.121730e+00 | 1.250208e-01 | -3.638561e-02 | 1.093675e+00 | 1.223439e-01 | -6.444051e-02 |\n", "
| 12 | BP (omega=0.2, C=1) | l2 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 9.323105e-01 | 1.239197e-01 | -2.258047e-01 | 1.103421e+00 | 9.391140e-02 | -5.469465e-02 | 1.091852e+00 | 9.016083e-02 | -6.626373e-02 |\n", "
| 16 | BP (omega=0.5, C=1) | l2 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 9.934959e-01 | 8.863373e-02 | -1.646193e-01 | 1.098363e+00 | 7.062806e-02 | -5.975265e-02 | 1.093084e+00 | 7.047579e-02 | -6.503108e-02 |\n", "
| 17 | BP (omega=0.5, C=1) | l2 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.115423e+00 | 5.939799e-02 | -4.269203e-02 | 1.090700e+00 | 4.538244e-02 | -6.741474e-02 | 1.090902e+00 | 4.537922e-02 | -6.721304e-02 |\n", "
| 5 | UKL (C=1) | l2 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 9.929128e-01 | 8.730633e-02 | -1.652024e-01 | 1.082664e+00 | 6.354118e-02 | -7.545118e-02 | 1.088462e+00 | 6.352715e-02 | -6.965350e-02 |\n", "
| 3 | SQ | l1 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.074762e+00 | 5.261753e-02 | -8.335342e-02 | 1.077929e+00 | 3.807590e-02 | -8.018587e-02 | 1.078362e+00 | 3.795229e-02 | -7.975364e-02 |\n", "
| 13 | BP (omega=0.2, C=1) | l2 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.079259e+00 | 6.341705e-02 | -7.885605e-02 | 1.077821e+00 | 4.803699e-02 | -8.029404e-02 | 1.083387e+00 | 4.796557e-02 | -7.472860e-02 |\n", "
| 2 | SQ | l1 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 1.075152e+00 | 5.269233e-02 | -8.296282e-02 | 1.077776e+00 | 3.812085e-02 | -8.033969e-02 | 1.078236e+00 | 3.799601e-02 | -7.987879e-02 |\n", "
| 1 | SQ | l2 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.073219e+00 | 5.256672e-02 | -8.489596e-02 | 1.077763e+00 | 3.800344e-02 | -8.035199e-02 | 1.078169e+00 | 3.787773e-02 | -7.994649e-02 |\n", "
| 0 | SQ | l2 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 1.075010e+00 | 5.268673e-02 | -8.310482e-02 | 1.077762e+00 | 3.811353e-02 | -8.035305e-02 | 1.078220e+00 | 3.798852e-02 | -7.989492e-02 |\n", "
| 9 | BP (omega=0.1, C=1) | l2 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | 1.133463e+00 | 6.554599e-02 | -2.465228e-02 | 1.077047e+00 | 4.953755e-02 | -8.106850e-02 | 1.083433e+00 | 4.940757e-02 | -7.468217e-02 |\n", "
| 4 | UKL (C=1) | l2 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | -4.913791e-01 | 8.986323e-01 | -1.649494e+00 | 1.272292e+00 | 6.737185e-01 | 1.141767e-01 | 1.094470e+00 | 6.736889e-01 | -6.364497e-02 |\n", "
| 18 | BP (omega=0.5, C=1) | l1 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | -1.307005e+17 | 9.786285e+16 | -1.307005e+17 | -8.159928e+16 | 6.899664e+16 | -8.159928e+16 | 1.088943e+00 | 1.309872e+16 | -6.917228e-02 |\n", "
| 19 | BP (omega=0.5, C=1) | l1 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | -1.316292e+18 | 1.119353e+18 | -1.316292e+18 | 7.034867e+17 | 7.937043e+17 | 7.034867e+17 | 1.091431e+00 | 1.548131e+17 | -6.668454e-02 |\n", "
| 14 | BP (omega=0.2, C=1) | l1 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | -5.591832e+28 | 4.945093e+28 | -5.591832e+28 | 4.957689e+27 | 8.465837e+27 | 4.957689e+27 | 1.091323e+00 | 4.258426e+27 | -6.679187e-02 |\n", "
| 15 | BP (omega=0.2, C=1) | l1 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | -1.359638e+35 | 1.350624e+35 | -1.359638e+35 | -5.928117e+34 | 6.163495e+34 | -5.928117e+34 | 1.091295e+00 | 3.255808e+33 | -6.681974e-02 |\n", "
| 10 | BP (omega=0.1, C=1) | l1 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | 5.327378e+53 | 5.327359e+53 | 5.327378e+53 | 6.551305e+53 | 6.551264e+53 | 6.551305e+53 | 1.567020e+01 | 5.749404e+48 | 1.451209e+01 |\n", "
| 7 | UKL (C=1) | l1 | 0.0010 | 1.091298 | 0.018015 | -0.066818 | -1.987935e+301 | NaN | -1.987935e+301 | -2.759111e+300 | NaN | -2.759111e+300 | 1.091298e+00 | NaN | -6.681756e-02 |\n", "
| 6 | UKL (C=1) | l1 | 0.0001 | 1.091298 | 0.018015 | -0.066818 | -2.664446e+301 | NaN | -2.664446e+301 | -1.057640e+301 | NaN | -1.057640e+301 | 1.091298e+00 | NaN | -6.681756e-02 |\n", "