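{ "cells": [ { "cell_type": "markdown", "id": "47bd7f90", "metadata": {}, "source": [ "# ATT simulation with (approximate) true value (genriesz)\n", "\n", "This notebook demonstrates **ATT** estimation:\n", "\n", "$$\n", "\\theta = \\mathbb{E}[Y(1)-Y(0) \\mid D=1].\n", "$$\n", "\n", "We generate a synthetic population with **heterogeneous treatment effects**, so\n", "in general **ATT \u2260 ATE**. We compute an approximate \"true\" ATT by Monte Carlo\n", "from a large simulated population and compare it to estimators based on\n", "generalized Riesz regression (GRR).\n", "\n", "We assume the regressor has the form $X = [D, Z]$, where the first column\n", "$D \\in \\{0, 1\\}$ is the treatment indicator and $Z$ collects the covariates.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "0b460663", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from genriesz import (\n", " grr_att,\n", " SquaredGenerator,\n", " UKLGenerator,\n", " BPGenerator,\n", " PolynomialBasis,\n", " TreatmentInteractionBasis,\n", " RBFRandomFourierBasis,\n", " KNNCatchmentBasis,\n", ")\n", "\n", "rng = np.random.default_rng(0)" ] }, { "cell_type": "markdown", "id": "1a4d1944", "metadata": {}, "source": [ "## Data generating process\n", "\n", "Covariates are drawn as $Z \\sim N(0, I_5)$. Writing $Z_0, Z_1$ for `Z[:, 0]` and `Z[:, 1]`:\n", "\n", "- propensity score: $e(Z) = 1 / (1 + \\exp(-(0.7 Z_0 - 0.3 Z_1)))$,\n", "- treatment effect: $\\tau(Z) = 1 + 0.5 Z_0$,\n", "- control mean: $\\mu_0(Z) = 0.5 Z_0 + 0.25 Z_1^2$, with independent $N(0, 1)$ noise on each potential outcome.\n", "\n", "Because $\\mathbb{E}[Z_0] = 0$, the true ATE equals 1. The propensity score increases in $Z_0$,\n", "so treated units have larger $Z_0$ on average and the true ATT exceeds 1. The Monte Carlo\n", "values printed below approximate these quantities." ] }, { "cell_type": "code", "execution_count": 2, "id": "b9c6f8d6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Approx. true ATT (Monte Carlo): 1.158115237841178\n", "Approx. 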
true ATE (Monte Carlo): 1.0038058584371061\n" ] } ], "source": [ "def draw_population(n: int, d_z: int, seed: int = 0):\n", " rng = np.random.default_rng(seed)\n", " Z = rng.normal(size=(n, d_z))\n", "\n", " logits = 0.7 * Z[:, 0] - 0.3 * Z[:, 1]\n", " e = 1.0 / (1.0 + np.exp(-logits))\n", " D = rng.binomial(1, e, size=n).astype(int)\n", "\n", " # Heterogeneous treatment effect\n", " tau = 1.0 + 0.5 * Z[:, 0]\n", " mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2\n", "\n", " Y0 = mu0 + rng.normal(scale=1.0, size=n)\n", " Y1 = mu0 + tau + rng.normal(scale=1.0, size=n)\n", " Y = D * Y1 + (1 - D) * Y0\n", "\n", " X = np.column_stack([D.astype(float), Z])\n", " return X, Y, Y0, Y1, D, tau\n", "\n", "# Large population for an approximate truth\n", "X_pop, Y_pop, Y0_pop, Y1_pop, D_pop, tau_pop = draw_population(n=200_000, d_z=5, seed=1)\n", "true_att = float(np.mean((Y1_pop - Y0_pop)[D_pop == 1]))\n", "true_ate = float(np.mean(Y1_pop - Y0_pop))\n", "\n", "print(\"Approx. true ATT (Monte Carlo):\", true_att)\n", "print(\"Approx. 
true ATE (Monte Carlo):\", true_ate)\n" ] }, { "cell_type": "markdown", "id": "21f52f87", "metadata": {}, "source": [ "## Example 1: Polynomial basis + treatment interactions" ] }, { "cell_type": "code", "execution_count": 3, "id": "841088ca", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ATT estimates (n=5000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: max_abs_smd_unweighted=0.6705657462527745, max_abs_smd_weighted=0.005400089251320891, ess_treated=2500.896209806344, ess_control=1558.1240998523388\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 1.09227 0.0179759 [ 1.05703, 1.1275] 0\n", "RW 1.08225 0.0522539 [ 0.979834, 1.18467] 0\n", "ARW 1.09178 0.0377112 [ 1.01786, 1.16569] 0\n", "TMLE 1.09178 0.0377068 [ 1.01788, 1.16569] 0\n" ] } ], "source": [ "# Sample a dataset from the same DGP\n", "X, Y, Y0, Y1, D, tau = draw_population(n=5000, d_z=5, seed=0)\n", "\n", "# Basis on Z, then interact with D (works well for treatment-effect functionals)\n", "psi = PolynomialBasis(degree=2, include_bias=True)\n", "phi = TreatmentInteractionBasis(base_basis=psi)\n", "\n", "gen = SquaredGenerator(C=0.0).as_generator()\n", "\n", "res = grr_att(\n", " X=X,\n", " Y=Y,\n", " basis=phi,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res.summary_text())\n" ] }, { "cell_type": "markdown", "id": "f2607309-d1ed-9c58-4132-c24e62df-0e8", "metadata": {}, "source": [ "## Example 2: RKHS basis (RBF random Fourier features)\n", "\n", "This approximates an RBF kernel feature map using random Fourier features, then\n", "interacts the features with treatment." 
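] }, { "cell_type": "markdown", "id": "rff-sketch-md", "metadata": {}, "source": [
"As a rough illustration of what a random Fourier feature map does, the next cell is a standalone NumPy sketch of the Rahimi and Recht construction for the Gaussian kernel\n",
"$k(z, z') = \\exp(-\\|z - z'\\|^2 / (2\\sigma^2))$. With $W \\sim N(0, \\sigma^{-2} I)$ and $b \\sim \\mathrm{Unif}[0, 2\\pi]$, the features\n",
"$\\psi(z) = \\sqrt{2/m}\\,\\cos(W^\\top z + b)$ satisfy $\\psi(z)^\\top \\psi(z') \\approx k(z, z')$ when the number of features $m$ is large.\n",
"\n",
"This is not the implementation of `RBFRandomFourierBasis`. The helper `rff_features` is hypothetical, the mapping of `sigma` to the kernel bandwidth is an assumption, and the sketch ignores `standardize=True`. The last lines build the treatment interaction `[D\u00b7\u03c8(Z), (1-D)\u00b7\u03c8(Z)]` by hand."
] }, { "cell_type": "code", "execution_count": null, "id": "rff-sketch-code", "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"def rff_features(Z, n_features=500, sigma=1.0, seed=0):\n",
"    # Hypothetical helper (not part of genriesz): random Fourier features whose inner\n",
"    # products approximate the Gaussian kernel exp(-||z - z'||^2 / (2 sigma^2)).\n",
"    rng_rff = np.random.default_rng(seed)\n",
"    W = rng_rff.normal(scale=1.0 / sigma, size=(Z.shape[1], n_features))\n",
"    b = rng_rff.uniform(0.0, 2.0 * np.pi, size=n_features)\n",
"    return np.sqrt(2.0 / n_features) * np.cos(Z @ W + b)\n",
"\n",
"# Compare feature inner products with the exact kernel (sigma = 1) on a few points.\n",
"Z_demo = np.random.default_rng(1).normal(size=(4, 5))\n",
"F = rff_features(Z_demo, n_features=20_000)\n",
"sq_dists = ((Z_demo[:, None, :] - Z_demo[None, :, :]) ** 2).sum(axis=-1)\n",
"K_exact = np.exp(-sq_dists / 2.0)\n",
"print('max abs kernel error:', np.max(np.abs(K_exact - F @ F.T)))\n",
"\n",
"# Treatment interaction [D*psi(Z), (1-D)*psi(Z)] built by hand.\n",
"D_demo = np.array([1.0, 0.0, 1.0, 0.0])\n",
"Phi = np.hstack([D_demo[:, None] * F, (1.0 - D_demo)[:, None] * F])\n",
"print(Phi.shape)  # (4, 40000)"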
] }, { "cell_type": "code", "execution_count": null, "id": "2787a9d9-2f22-49dc-4427-a073aaf6-4fd", "metadata": {}, "outputs": [], "source": [ "psi_rff = RBFRandomFourierBasis(\n", " n_features=500,\n", " sigma=1.0,\n", " standardize=True,\n", " random_state=0,\n", ")\n", "phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)\n", "\n", "res_rff = grr_att(\n", " X=X,\n", " Y=Y,\n", " basis=phi_rff,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_rff.summary_text())" ] }, { "cell_type": "markdown", "id": "88d99f52-437e-fb85-e106-60d4384c-225", "metadata": {}, "source": [ "## Example 3: KNN catchment basis (nearest-neighbor matching)\n", "\n", "This example recovers nearest-neighbor matching as a special case of squared-loss Riesz regression.\n", "``TreatmentInteractionBasis`` creates ``[D\u00b7\u03c8(Z), (1-D)\u00b7\u03c8(Z)]`` from the KNN catchment features ``\u03c8(Z)``,\n", "so the standard NN-matching Riesz representer is a linear model in these features."
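] }, { "cell_type": "markdown", "id": "nn-matching-sketch-md", "metadata": {}, "source": [
"To make the matching connection concrete, the next cell is a standalone NumPy sketch of $M$-nearest-neighbor matching for the ATT. It assumes matching with replacement on Euclidean distance in $Z$. It does not describe what `KNNCatchmentBasis` does internally, and `att_matching_weights` is a hypothetical helper.\n",
"\n",
"Let $J_M(i)$ be the $M$ matched controls of treated unit $i$, and let $K_M(j)$ count the treated units that use control $j$ as a match. Then the matching estimator\n",
"$\\frac{1}{n_1}\\sum_{i: D_i = 1}\\big(Y_i - \\frac{1}{M}\\sum_{j \\in J_M(i)} Y_j\\big)$ equals the weighted mean $\\frac{1}{n}\\sum_i \\hat\\alpha_i Y_i$, where\n",
"$\\hat\\alpha_i = n/n_1$ for treated units and $\\hat\\alpha_j = -(n/n_1)\\,K_M(j)/M$ for controls. These weights act as an estimated ATT Riesz representer, so the two printed numbers should agree up to floating-point error."
] }, { "cell_type": "code", "execution_count": null, "id": "nn-matching-sketch-code", "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"def att_matching_weights(Z, D, M=5):\n",
"    # Hypothetical helper (not part of genriesz): weights implied by M-nearest-neighbor\n",
"    # matching with replacement on Euclidean distance, for the ATT.\n",
"    Z = np.asarray(Z, dtype=float)\n",
"    treated = np.asarray(D).astype(bool)\n",
"    Zt, Zc = Z[treated], Z[~treated]\n",
"    # Squared distances from every treated unit to every control unit.\n",
"    dist = ((Zt[:, None, :] - Zc[None, :, :]) ** 2).sum(axis=-1)\n",
"    nn_idx = np.argsort(dist, axis=1)[:, :M]  # M nearest controls of each treated unit\n",
"    K = np.bincount(nn_idx.ravel(), minlength=Zc.shape[0])  # K_M(j): times control j is used\n",
"    n, n1 = Z.shape[0], treated.sum()\n",
"    alpha = np.empty(n)\n",
"    alpha[treated] = n / n1\n",
"    alpha[~treated] = -(n / n1) * K / M\n",
"    return alpha, nn_idx\n",
"\n",
"rng_m = np.random.default_rng(2)\n",
"Z_m = rng_m.normal(size=(300, 2))\n",
"D_m = rng_m.binomial(1, 0.4, size=300)\n",
"Y_m = Z_m[:, 0] + D_m + rng_m.normal(size=300)\n",
"\n",
"alpha_m, nn_idx = att_matching_weights(Z_m, D_m, M=5)\n",
"theta_weighted = np.mean(alpha_m * Y_m)  # (1/n) * sum_i alpha_i * Y_i\n",
"\n",
"# Direct matching estimator: each treated outcome minus the mean of its M matched controls.\n",
"Y_c = Y_m[D_m == 0]\n",
"theta_matching = np.mean(Y_m[D_m == 1] - Y_c[nn_idx].mean(axis=1))\n",
"print(theta_weighted, theta_matching)"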
] }, { "cell_type": "code", "execution_count": null, "id": "2d317509-d583-946b-f348-d35494fa-e9b", "metadata": {}, "outputs": [], "source": [ "basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)\n", "phi_knn = TreatmentInteractionBasis(base_basis=basis_knn)\n", "\n", "res_knn = grr_att(\n", " X=X,\n", " Y=Y,\n", " basis=phi_knn,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_knn.summary_text())" ] }, { "cell_type": "markdown", "id": "beef6c9d-1689-f5ca-2532-6e0b8cbf-5b5", "metadata": {}, "source": [ "## Example 4: Random forest leaf basis (optional)\n", "\n", "If you have ``scikit-learn`` installed, you can use a random forest as a **feature map**\n", "via leaf indicators. This keeps GRR convex while giving a flexible nonparametric basis." ] }, { "cell_type": "code", "execution_count": null, "id": "b01204fb-ef83-e95d-0483-0733d4b7-acc", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestRegressor\n", "from genriesz.sklearn_basis import RandomForestLeafBasis\n", "\n", "rf = RandomForestRegressor(\n", " n_estimators=200,\n", " max_depth=6,\n", " random_state=0,\n", ")\n", "\n", "leaf_basis = RandomForestLeafBasis(rf).fit(X, Y)\n", "\n", "res_rf = grr_att(\n", " X=X,\n", " Y=Y,\n", " basis=leaf_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_rf.summary_text())" ] }, { "cell_type": "markdown", "id": "e47c0730-9653-89a6-3a3b-7c51f190-dab", "metadata": {}, "source": [ "## Example 5: Neural network embedding basis (optional)\n", "\n", "If you have PyTorch 
installed, you can use the output of a small MLP as a **feature map** (an embedding basis).\n", "For a lightweight demo, the network below keeps its random initialization and is not trained, so it acts as a fixed random nonlinear feature map." ] }, { "cell_type": "code", "execution_count": null, "id": "5bc98a6b-0929-665d-3b0a-f51ca65f-629", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis\n", "\n", "torch.manual_seed(0)\n", "\n", "net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)\n", "nn_basis = TorchEmbeddingBasis(net, include_bias=True, device=\"cpu\")\n", "\n", "res_nn = grr_att(\n", " X=X,\n", " Y=Y,\n", " basis=nn_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_nn.summary_text())" ] }, { "cell_type": "markdown", "id": "25cae85a", "metadata": {}, "source": [ "## Generator / regularization sweep (SQ / UKL / BP)\n", "\n", "We compare SQ-Riesz / UKL-Riesz / BP-Riesz under multiple regularization norms\n", "and strengths. For **UKL/BP**, we use a **branch function** that forces:\n", "\n", "- positive branch for treated units ($D=1$),\n", "- negative branch for control units ($D=0$),\n", "\n", "which matches the sign structure of common treatment-effect Riesz representers. For the ATT, with\n", "$p = P(D=1)$ and $e(z) = P(D=1 \\mid Z=z)$, the representer is $\\alpha(1, z) = 1/p > 0$ and\n", "$\\alpha(0, z) = -e(z) / \\{(1 - e(z))\\,p\\} < 0$.\n", "\n", "We report **RA / RW / ARW / TMLE** and compare errors to the Monte Carlo \"true\" ATT." ] }, { "cell_type": "code", "execution_count": 6, "id": "3291c25d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ " generator penalty lam ra ra_se ra_err \\\n", "11 BP (omega=0.1, C=1) l1 0.0010 1.091298 0.018015 -0.066818 \n", "8 BP (omega=0.1, C=1) l2 0.0001 1.091298 0.018015 -0.066818 \n", "12 BP (omega=0.2, C=1) l2 0.0001 1.091298 0.018015 -0.066818 \n", "16 BP (omega=0.5, C=1) l2 0.0001 1.091298 0.018015 -0.066818 \n", "17 BP (omega=0.5, C=1) l2 0.0010 1.091298 0.018015 -0.066818 \n", "5 UKL (C=1) l2 0.0010 1.091298 0.018015 -0.066818 \n", "3 SQ l1 0.0010 1.091298 0.018015 -0.066818 \n", "13 BP (omega=0.2, C=1) l2 0.0010 1.091298 0.018015 -0.066818 \n", "2 SQ l1 0.0001 1.091298 0.018015 -0.066818 \n", "1 SQ l2 0.0010 1.091298 0.018015 -0.066818 \n", "0 SQ l2 0.0001 1.091298 0.018015 -0.066818 \n", "9 BP (omega=0.1, C=1) l2 0.0010 1.091298 0.018015 -0.066818 \n", "4 UKL (C=1) l2 0.0001 1.091298 0.018015 -0.066818 \n", "18 BP (omega=0.5, C=1) l1 0.0001 1.091298 0.018015 -0.066818 \n", "19 BP (omega=0.5, C=1) l1 0.0010 1.091298 0.018015 -0.066818 \n", "14 BP (omega=0.2, C=1) l1 0.0001 1.091298 0.018015 -0.066818 \n", "15 BP (omega=0.2, C=1) l1 0.0010 1.091298 0.018015 -0.066818 \n", "10 BP (omega=0.1, C=1) l1 0.0001 1.091298 0.018015 -0.066818 \n", "7 UKL (C=1) l1 0.0010 1.091298 0.018015 -0.066818 \n", "6 UKL (C=1) l1 0.0001 1.091298 0.018015 -0.066818 \n", "\n", " rw rw_se rw_err arw arw_se \\\n", "11 1.055435e+00 6.848838e-02 -1.026806e-01 1.131193e+00 4.746753e-02 \n", "8 9.667499e-01 1.490316e-01 -1.913653e-01 1.121730e+00 1.250208e-01 \n", "12 9.323105e-01 1.239197e-01 -2.258047e-01 1.103421e+00 9.391140e-02 \n", "16 9.934959e-01 8.863373e-02 -1.646193e-01 1.098363e+00 7.062806e-02 \n", "17 1.115423e+00 5.939799e-02 -4.269203e-02 1.090700e+00 4.538244e-02 \n", "5 9.929128e-01 8.730633e-02 -1.652024e-01 1.082664e+00 6.354118e-02 \n", "3 1.074762e+00 5.261753e-02 -8.335342e-02 1.077929e+00 3.807590e-02 \n", "13 1.079259e+00 6.341705e-02 -7.885605e-02 1.077821e+00 4.803699e-02 \n", "2 1.075152e+00 5.269233e-02 -8.296282e-02 1.077776e+00 3.812085e-02 
\n", "1 1.073219e+00 5.256672e-02 -8.489596e-02 1.077763e+00 3.800344e-02 \n", "0 1.075010e+00 5.268673e-02 -8.310482e-02 1.077762e+00 3.811353e-02 \n", "9 1.133463e+00 6.554599e-02 -2.465228e-02 1.077047e+00 4.953755e-02 \n", "4 -4.913791e-01 8.986323e-01 -1.649494e+00 1.272292e+00 6.737185e-01 \n", "18 -1.307005e+17 9.786285e+16 -1.307005e+17 -8.159928e+16 6.899664e+16 \n", "19 -1.316292e+18 1.119353e+18 -1.316292e+18 7.034867e+17 7.937043e+17 \n", "14 -5.591832e+28 4.945093e+28 -5.591832e+28 4.957689e+27 8.465837e+27 \n", "15 -1.359638e+35 1.350624e+35 -1.359638e+35 -5.928117e+34 6.163495e+34 \n", "10 5.327378e+53 5.327359e+53 5.327378e+53 6.551305e+53 6.551264e+53 \n", "7 -1.987935e+301 NaN -1.987935e+301 -2.759111e+300 NaN \n", "6 -2.664446e+301 NaN -2.664446e+301 -1.057640e+301 NaN \n", "\n", " arw_err tmle tmle_se tmle_err \n", "11 -2.692224e-02 9.669026e+50 9.669025e+50 9.669026e+50 \n", "8 -3.638561e-02 1.093675e+00 1.223439e-01 -6.444051e-02 \n", "12 -5.469465e-02 1.091852e+00 9.016083e-02 -6.626373e-02 \n", "16 -5.975265e-02 1.093084e+00 7.047579e-02 -6.503108e-02 \n", "17 -6.741474e-02 1.090902e+00 4.537922e-02 -6.721304e-02 \n", "5 -7.545118e-02 1.088462e+00 6.352715e-02 -6.965350e-02 \n", "3 -8.018587e-02 1.078362e+00 3.795229e-02 -7.975364e-02 \n", "13 -8.029404e-02 1.083387e+00 4.796557e-02 -7.472860e-02 \n", "2 -8.033969e-02 1.078236e+00 3.799601e-02 -7.987879e-02 \n", "1 -8.035199e-02 1.078169e+00 3.787773e-02 -7.994649e-02 \n", "0 -8.035305e-02 1.078220e+00 3.798852e-02 -7.989492e-02 \n", "9 -8.106850e-02 1.083433e+00 4.940757e-02 -7.468217e-02 \n", "4 1.141767e-01 1.094470e+00 6.736889e-01 -6.364497e-02 \n", "18 -8.159928e+16 1.088943e+00 1.309872e+16 -6.917228e-02 \n", "19 7.034867e+17 1.091431e+00 1.548131e+17 -6.668454e-02 \n", "14 4.957689e+27 1.091323e+00 4.258426e+27 -6.679187e-02 \n", "15 -5.928117e+34 1.091295e+00 3.255808e+33 -6.681974e-02 \n", "10 6.551305e+53 1.567020e+01 5.749404e+48 1.451209e+01 \n", "7 -2.759111e+300 1.091298e+00 
NaN -6.681756e-02 \n", "6 -1.057640e+301 1.091298e+00 NaN -6.681756e-02 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Branch: + for treated, - for control (D is the first column of X).\n", "branch = lambda x: int(x[0] == 1.0)\n", "\n", "generator_grid = [\n", "    (\"SQ\", SquaredGenerator(C=0.0).as_generator()),\n", "    (\"UKL (C=1)\", UKLGenerator(C=1.0, branch_fn=branch).as_generator()),\n", "    (\"BP (omega=0.1, C=1)\", BPGenerator(C=1.0, omega=0.1, branch_fn=branch).as_generator()),\n", "    (\"BP (omega=0.2, C=1)\", BPGenerator(C=1.0, omega=0.2, branch_fn=branch).as_generator()),\n", "    (\"BP (omega=0.5, C=1)\", BPGenerator(C=1.0, omega=0.5, branch_fn=branch).as_generator()),\n", "]\n", "\n", "penalty_grid = [\n", "    {\"penalty\": \"l2\", \"lam\": 1e-4, \"p_norm\": None},\n", "    {\"penalty\": \"l2\", \"lam\": 1e-3, \"p_norm\": None},\n", "    {\"penalty\": \"l1\", \"lam\": 1e-4, \"p_norm\": None},\n", "    {\"penalty\": \"lp\", \"lam\": 1e-3, \"p_norm\": 1.5},\n", "]\n", "\n", "rows = []\n", "for gname, gen_i in generator_grid:\n", "    for cfg in penalty_grid:\n", "        res_i = grr_att(\n", "            X=X,\n", "            Y=Y,\n", "            basis=phi,\n", "            generator=gen_i,\n", "            cross_fit=True,\n", "            folds=3,\n", "            random_state=0,\n", "            estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", "            outcome_models=\"shared\",\n", "            outcome_link=\"identity\",  # Y is unbounded\n", "            riesz_penalty=cfg[\"penalty\"],\n", "            riesz_lam=cfg[\"lam\"],\n", "            riesz_p_norm=cfg.get(\"p_norm\"),\n", "            max_iter=250,\n", "            tol=1e-8,\n", "        )\n", "\n", "        row = {\n", "            \"generator\": gname,\n", "            \"penalty\": cfg[\"penalty\"],\n", "            \"lam\": cfg[\"lam\"],\n", "        }\n", "\n", "        # Record estimate, standard error, and error vs. the Monte Carlo ATT.\n", "        for k in (\"ra\", \"rw\", \"arw\", \"tmle\"):\n", "            e = res_i.estimates[k]\n", "            row[f\"{k}\"] = e.estimate\n", "            row[f\"{k}_se\"] = e.se\n", "            row[f\"{k}_err\"] = e.estimate - true_att\n", "\n", "        rows.append(row)\n", "\n", "import pandas as pd\n", "\n", "df = pd.DataFrame(rows)\n", "# Sort rows by absolute ARW error, smallest first.\n", "df = df.sort_values(by=\"arw_err\", key=lambda s: 
np.abs(s))\n", "display(df)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.6" } }, "nbformat": 4, "nbformat_minor": 5 }