{ "cells": [ { "cell_type": "markdown", "id": "f58299f4", "metadata": {}, "source": [ "# ATE end-to-end examples (genriesz)\n", "\n", "This notebook demonstrates how to estimate the **Average Treatment Effect (ATE)** with **genriesz**.\n", "\n", "We assume the regressor has the form:\n", "\n", "- $X = (D, Z)$, where $D$ is a **binary treatment indicator** ($0$ or $1$),\n", "- $Y$ is the observed outcome.\n", "\n", "We will compute (optionally with cross-fitting):\n", "\n", "- **RA**: regression adjustment (plug-in)\n", "- **RW**: Riesz weighting (weighting only)\n", "- **ARW**: augmented Riesz weighting\n", "- **TMLE**: targeted maximum likelihood estimation (one-step fluctuation)\n", "\n", "We also show how to swap the **basis**:\n", "\n", "- polynomial features,\n", "- RKHS features (RBF random Fourier features),\n", "- nearest-neighbor matching (KNN catchment basis),\n", "- random forest leaf features (optional),\n", "- neural network embeddings (optional).\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "71628cec", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from genriesz import (\n", "    grr_ate,\n", "    SquaredGenerator,\n", "    UKLGenerator,\n", "    PolynomialBasis,\n", "    TreatmentInteractionBasis,\n", "    RBFRandomFourierBasis,\n", "    KNNCatchmentBasis,\n", ")\n", "\n", "rng = np.random.default_rng(0)\n" ] }, { "cell_type": "markdown", "id": "041e5044", "metadata": {}, "source": [ "## Synthetic data" ] }, { "cell_type": "code", "execution_count": 2, "id": "f744ecd2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X shape: (3000, 6) Y shape: (3000,)\n" ] } ], "source": [ "# Data-generating process\n", "n = 3000\n", "d_z = 5\n", "\n", "Z = rng.normal(size=(n, d_z))\n", "\n", "# Treatment assignment: e(Z) = sigmoid(a'Z)\n", "logits = 0.7 * Z[:, 0] - 0.3 * Z[:, 1]\n", "e = 1.0 / (1.0 + np.exp(-logits))\n", "D = rng.binomial(1, e, size=n).astype(float)\n", "\n", "# Potential outcomes (constant effect for 
simplicity)\n", "tau = 1.0\n", "mu0 = 0.5 * Z[:, 0] + 0.25 * Z[:, 1] ** 2\n", "Y0 = mu0 + rng.normal(scale=1.0, size=n)\n", "Y1 = mu0 + tau + rng.normal(scale=1.0, size=n)\n", "\n", "Y = D * Y1 + (1.0 - D) * Y0\n", "\n", "# Regressor matrix X = [D, Z...]\n", "X = np.column_stack([D, Z])\n", "\n", "print(\"X shape:\", X.shape, \"Y shape:\", Y.shape)\n" ] }, { "cell_type": "markdown", "id": "1277ad0c", "metadata": {}, "source": [ "## Example 1: Polynomial basis + treatment interactions" ] }, { "cell_type": "code", "execution_count": 3, "id": "31a5fe31", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ATE estimates (n=3000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: alpha_abs_mean=2.0164751561831986, alpha_abs_p95=3.844795773570575, alpha_abs_max=9.334972319665178, max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.01662007897774767, ess_treated=1188.4096325793182, ess_control=1260.0346792242883\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 1.03046 0.00359346 [ 1.02342, 1.03751] 0\n", "RW 1.03302 0.0569279 [ 0.921445, 1.1446] 0\n", "ARW 1.02916 0.0419132 [ 0.947017, 1.11131] 0\n", "TMLE 1.02921 0.0419128 [ 0.947063, 1.11136] 0\n" ] } ], "source": [ "# Basis on Z, then interact with D (ATE-friendly)\n", "psi = PolynomialBasis(degree=2, include_bias=True)\n", "phi = TreatmentInteractionBasis(base_basis=psi)\n", "\n", "# Generator: Squared loss (always safe / no domain constraints)\n", "gen = SquaredGenerator(C=0.0).as_generator()\n", "\n", "res_poly = grr_ate(\n", " X=X,\n", " Y=Y,\n", " basis=phi,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_poly.summary_text())\n" ] }, { "cell_type": 
"markdown", "id": "c848c3df", "metadata": {}, "source": [ "## Example 2: RKHS basis (RBF random Fourier features)\n", "\n", "This approximates an RBF kernel feature map using random Fourier features, then\n", "interacts the features with treatment.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "e0c9683d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ATE estimates (n=3000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: alpha_abs_mean=2.077446519847265, alpha_abs_p95=3.353495372876411, alpha_abs_max=4.936266573928845, max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.27862992844919376, ess_treated=1363.470230442599, ess_control=1370.9575124421833\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 1.15729 0.00717955 [ 1.14321, 1.17136] 0\n", "RW 1.22541 0.0540494 [ 1.11947, 1.33134] 0\n", "ARW 1.12627 0.0423592 [ 1.04325, 1.20929] 0\n", "TMLE 1.12732 0.0423287 [ 1.04436, 1.21028] 0\n" ] } ], "source": [ "psi_rff = RBFRandomFourierBasis(\n", " n_features=500,\n", " sigma=1.0,\n", " standardize=True,\n", " random_state=0,\n", ")\n", "phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)\n", "\n", "res_rff = grr_ate(\n", " X=X,\n", " Y=Y,\n", " basis=phi_rff,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_rff.summary_text())\n" ] }, { "cell_type": "markdown", "id": "109b8823", "metadata": {}, "source": "## Example 3: KNN catchment basis (nearest-neighbor matching)\n\nNearest-neighbor matching can be expressed using a **catchment-area basis**\n\n$$\n\\phi_j(z) = \\mathbf{1}\\{j \\in \\mathrm{NN}_k(z)\\},\n$$\n\nwhich assigns each point to its $k$ nearest training centers.\nThis is a 
special case of squared-loss Riesz regression, so we can pass it to\n`grr_ate` exactly like any other basis.\n" }, { "cell_type": "code", "execution_count": null, "id": "e50f0e43", "metadata": {}, "outputs": [], "source": "# KNN catchment basis as a nearest-neighbor Riesz basis.\n#\n# phi_j(z) = 1{j in NN_k(z)} assigns each point to its k nearest training centers.\n# TreatmentInteractionBasis then creates [D*psi(Z), (1-D)*psi(Z)],\n# which gives the standard NN-matching Riesz representer as a linear model.\n#\n# With cross-fitting the training fold becomes the centers, so the feature\n# dimension p = n_train >> n_test. The dual (Woodbury) solve handles this\n# in O(n_test^3) instead of O(n_train^3).\n\nbasis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)\nphi_knn = TreatmentInteractionBasis(base_basis=basis_knn)\n\nres_knn = grr_ate(\n X=X,\n Y=Y,\n basis=phi_knn,\n generator=gen,\n cross_fit=True,\n folds=5,\n random_state=0,\n estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n outcome_models=\"shared\",\n riesz_penalty=\"l2\",\n riesz_lam=1e-3,\n max_iter=300,\n tol=1e-8,\n)\n\nprint(res_knn.summary_text())\n" }, { "cell_type": "markdown", "id": "e278c8e4", "metadata": {}, "source": [ "## Example 4: Random forest leaf basis (optional)\n", "\n", "If you have `scikit-learn` installed, you can use a random forest as a **feature map**\n", "via leaf indicators. 
This keeps GRR convex (linear in parameters) while giving a\n", "flexible nonparametric basis.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "a84820a2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ATE estimates (n=3000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.786895849078183, ess_treated=350.8795988622065, ess_control=507.56068177539834\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 1.03187 0.0199079 [ 0.992854, 1.07089] 0\n", "RW 1.69371 0.252221 [ 1.19937, 2.18806] 1.88e-11\n", "ARW 1.25382 0.215261 [ 0.831919, 1.67573] 5.72e-09\n", "TMLE 1.063 0.217126 [ 0.63744, 1.48856] 9.79e-07\n" ] } ], "source": [ "from sklearn.ensemble import RandomForestRegressor\n", "from genriesz.sklearn_basis import RandomForestLeafBasis\n", "\n", "rf = RandomForestRegressor(\n", " n_estimators=200,\n", " max_depth=6,\n", " random_state=0,\n", ")\n", "\n", "leaf_basis = RandomForestLeafBasis(rf, include_bias=True).fit(X, Y)\n", "\n", "res_rf = grr_ate(\n", " X=X,\n", " Y=Y,\n", " basis=leaf_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_rf.summary_text())" ] }, { "cell_type": "markdown", "id": "7111e1ae", "metadata": {}, "source": [ "## Example 5: Neural network embedding basis (optional)\n", "\n", "If you have PyTorch installed, you can use a neural network as a **basis function**.\n", "A recommended procedure is:\n", "\n", "1. train an embedding network on a separate task,\n", "2. use it as a basis function,\n", "3. 
use its outputs as features in GRR.\n", "\n", "Below we show the mechanics with a small MLP (training is optional).\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "6ba7c964", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ATE estimates (n=3000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: max_abs_smd_unweighted=0.7095986621457625, max_abs_smd_weighted=0.10074303607722811, ess_treated=1234.797658219504, ess_control=1289.906386252168\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 1.0526 0.00233879 [ 1.04801, 1.05718] 0\n", "RW 1.02999 0.0497571 [ 0.932472, 1.12752] 0\n", "ARW 1.02254 0.0372324 [ 0.949568, 1.09552] 0\n", "TMLE 1.02144 0.0372205 [ 0.948489, 1.09439] 0\n" ] } ], "source": [ "import torch\n", "from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis\n", "\n", "torch.manual_seed(0)\n", "\n", "net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)\n", "# (Optional) Train net here on a separate task.\n", "# For a lightweight demo, we skip training and just use the random initialization.\n", "nn_basis = TorchEmbeddingBasis(net, include_bias=True, device=\"cpu\")\n", "\n", "res_nn = grr_ate(\n", " X=X,\n", " Y=Y,\n", " basis=nn_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_nn.summary_text())" ] }, { "cell_type": "markdown", "id": "57afc4f3", "metadata": {}, "source": [ "## Generator (SQ / UKL / BP)\n", "\n", "This section runs **SQ-Riesz**, **UKL-Riesz**, and **BP-Riesz** for the same\n", "polynomial interaction basis, and compares multiple regularization norms and\n", "strengths.\n", "\n", "We use a **branch function** for UKL/BP that 
forces:\n", "\n", "- positive branch for treated units (`D=1`),\n", "- negative branch for control units (`D=0`).\n", "\n", "All four estimators (**RA / RW / ARW / TMLE**) are reported." ] }, { "cell_type": "code", "execution_count": 6, "id": "fc758a37", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead>\n",
"<tr style=\"text-align: right;\"><th></th><th>generator</th><th>penalty</th><th>lam</th><th>p_norm</th><th>ra</th><th>ra_se</th><th>ra_err</th><th>rw</th><th>rw_se</th><th>rw_err</th><th>arw</th><th>arw_se</th><th>arw_err</th><th>tmle</th><th>tmle_se</th><th>tmle_err</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>11</th><td>BP (omega=0.1, C=1)</td><td>l1</td><td>0.0010</td><td>1.5</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.019822</td><td>0.059940</td><td>0.019822</td><td>1.039039</td><td>0.043709</td><td>0.039039</td><td>1.038188</td><td>0.043721</td><td>0.038188</td></tr>\n",
"<tr><th>15</th><td>BP (omega=0.2, C=1)</td><td>l1</td><td>0.0010</td><td>1.5</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.023228</td><td>0.059248</td><td>0.023228</td><td>1.039091</td><td>0.043380</td><td>0.039091</td><td>1.038355</td><td>0.043391</td><td>0.038355</td></tr>\n",
"<tr><th>14</th><td>BP (omega=0.2, C=1)</td><td>l1</td><td>0.0001</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.021144</td><td>0.059442</td><td>0.021144</td><td>1.039292</td><td>0.043518</td><td>0.039292</td><td>1.038505</td><td>0.043530</td><td>0.038505</td></tr>\n",
"<tr><th>13</th><td>BP (omega=0.2, C=1)</td><td>l2</td><td>0.0010</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.022101</td><td>0.059411</td><td>0.022101</td><td>1.039302</td><td>0.043490</td><td>0.039302</td><td>1.038521</td><td>0.043502</td><td>0.038521</td></tr>\n",
"<tr><th>12</th><td>BP (omega=0.2, C=1)</td><td>l2</td><td>0.0001</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.021037</td><td>0.059462</td><td>0.021037</td><td>1.039308</td><td>0.043530</td><td>0.039308</td><td>1.038517</td><td>0.043542</td><td>0.038517</td></tr>\n",
"<tr><th>7</th><td>UKL (C=1)</td><td>l1</td><td>0.0010</td><td>1.5</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.015917</td><td>0.060709</td><td>0.015917</td><td>1.039322</td><td>0.044135</td><td>0.039322</td><td>1.038294</td><td>0.044140</td><td>0.038294</td></tr>\n",
"<tr><th>19</th><td>BP (omega=0.5, C=1)</td><td>l1</td><td>0.0010</td><td>1.5</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.029973</td><td>0.057821</td><td>0.029973</td><td>1.039372</td><td>0.042636</td><td>0.039372</td><td>1.038849</td><td>0.042643</td><td>0.038849</td></tr>\n",
"<tr><th>18</th><td>BP (omega=0.5, C=1)</td><td>l1</td><td>0.0001</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.027756</td><td>0.058034</td><td>0.027756</td><td>1.039449</td><td>0.042746</td><td>0.039449</td><td>1.038895</td><td>0.042754</td><td>0.038895</td></tr>\n",
"<tr><th>10</th><td>BP (omega=0.1, C=1)</td><td>l1</td><td>0.0001</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.018222</td><td>0.060047</td><td>0.018222</td><td>1.039456</td><td>0.043853</td><td>0.039456</td><td>1.038531</td><td>0.043864</td><td>0.038531</td></tr>\n",
"<tr><th>16</th><td>BP (omega=0.5, C=1)</td><td>l2</td><td>0.0001</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.027687</td><td>0.058050</td><td>0.027687</td><td>1.039457</td><td>0.042753</td><td>0.039457</td><td>1.038900</td><td>0.042761</td><td>0.038900</td></tr>\n",
"<tr><th>17</th><td>BP (omega=0.5, C=1)</td><td>l2</td><td>0.0010</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.029078</td><td>0.057985</td><td>0.029078</td><td>1.039464</td><td>0.042703</td><td>0.039464</td><td>1.038918</td><td>0.042711</td><td>0.038918</td></tr>\n",
"<tr><th>9</th><td>BP (omega=0.1, C=1)</td><td>l2</td><td>0.0010</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.019118</td><td>0.060007</td><td>0.019118</td><td>1.039485</td><td>0.043828</td><td>0.039485</td><td>1.038564</td><td>0.043839</td><td>0.038564</td></tr>\n",
"<tr><th>8</th><td>BP (omega=0.1, C=1)</td><td>l2</td><td>0.0001</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.018151</td><td>0.060053</td><td>0.018151</td><td>1.039499</td><td>0.043864</td><td>0.039499</td><td>1.038568</td><td>0.043875</td><td>0.038568</td></tr>\n",
"<tr><th>6</th><td>UKL (C=1)</td><td>l1</td><td>0.0001</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.014454</td><td>0.060770</td><td>0.014454</td><td>1.039918</td><td>0.044285</td><td>0.039918</td><td>1.038785</td><td>0.044288</td><td>0.038785</td></tr>\n",
"<tr><th>5</th><td>UKL (C=1)</td><td>l2</td><td>0.0010</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.015266</td><td>0.060731</td><td>0.015266</td><td>1.039954</td><td>0.044264</td><td>0.039954</td><td>1.038824</td><td>0.044267</td><td>0.038824</td></tr>\n",
"<tr><th>4</th><td>UKL (C=1)</td><td>l2</td><td>0.0001</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.014382</td><td>0.060769</td><td>0.014382</td><td>1.039978</td><td>0.044297</td><td>0.039978</td><td>1.038835</td><td>0.044299</td><td>0.038835</td></tr>\n",
"<tr><th>3</th><td>SQ</td><td>l1</td><td>0.0010</td><td>1.5</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.033380</td><td>0.056470</td><td>0.033380</td><td>1.041214</td><td>0.041673</td><td>0.041214</td><td>1.040825</td><td>0.041676</td><td>0.040825</td></tr>\n",
"<tr><th>2</th><td>SQ</td><td>l1</td><td>0.0001</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.033103</td><td>0.056626</td><td>0.033103</td><td>1.041513</td><td>0.041770</td><td>0.041513</td><td>1.041085</td><td>0.041772</td><td>0.041085</td></tr>\n",
"<tr><th>0</th><td>SQ</td><td>l2</td><td>0.0001</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.032923</td><td>0.056626</td><td>0.032923</td><td>1.041555</td><td>0.041760</td><td>0.041555</td><td>1.041128</td><td>0.041763</td><td>0.041128</td></tr>\n",
"<tr><th>1</th><td>SQ</td><td>l2</td><td>0.0010</td><td>NaN</td><td>1.027967</td><td>0.003809</td><td>0.027967</td><td>1.031386</td><td>0.056477</td><td>0.031386</td><td>1.041641</td><td>0.041583</td><td>0.041641</td><td>1.041266</td><td>0.041585</td><td>0.041266</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " generator penalty lam p_norm ra ra_se ra_err \\\n", "11 BP (omega=0.1, C=1) l1 0.0010 1.5 1.027967 0.003809 0.027967 \n", "15 BP (omega=0.2, C=1) l1 0.0010 1.5 1.027967 0.003809 0.027967 \n", "14 BP (omega=0.2, C=1) l1 0.0001 NaN 1.027967 0.003809 0.027967 \n", "13 BP (omega=0.2, C=1) l2 0.0010 NaN 1.027967 0.003809 0.027967 \n", "12 BP (omega=0.2, C=1) l2 0.0001 NaN 1.027967 0.003809 0.027967 \n", "7 UKL (C=1) l1 0.0010 1.5 1.027967 0.003809 0.027967 \n", "19 BP (omega=0.5, C=1) l1 0.0010 1.5 1.027967 0.003809 0.027967 \n", "18 BP (omega=0.5, C=1) l1 0.0001 NaN 1.027967 0.003809 0.027967 \n", "10 BP (omega=0.1, C=1) l1 0.0001 NaN 1.027967 0.003809 0.027967 \n", "16 BP (omega=0.5, C=1) l2 0.0001 NaN 1.027967 0.003809 0.027967 \n", "17 BP (omega=0.5, C=1) l2 0.0010 NaN 1.027967 0.003809 0.027967 \n", "9 BP (omega=0.1, C=1) l2 0.0010 NaN 1.027967 0.003809 0.027967 \n", "8 BP (omega=0.1, C=1) l2 0.0001 NaN 1.027967 0.003809 0.027967 \n", "6 UKL (C=1) l1 0.0001 NaN 1.027967 0.003809 0.027967 \n", "5 UKL (C=1) l2 0.0010 NaN 1.027967 0.003809 0.027967 \n", "4 UKL (C=1) l2 0.0001 NaN 1.027967 0.003809 0.027967 \n", "3 SQ l1 0.0010 1.5 1.027967 0.003809 0.027967 \n", "2 SQ l1 0.0001 NaN 1.027967 0.003809 0.027967 \n", "0 SQ l2 0.0001 NaN 1.027967 0.003809 0.027967 \n", "1 SQ l2 0.0010 NaN 1.027967 0.003809 0.027967 \n", "\n", " rw rw_se rw_err arw arw_se arw_err tmle \\\n", "11 1.019822 0.059940 0.019822 1.039039 0.043709 0.039039 1.038188 \n", "15 1.023228 0.059248 0.023228 1.039091 0.043380 0.039091 1.038355 \n", "14 1.021144 0.059442 0.021144 1.039292 0.043518 0.039292 1.038505 \n", "13 1.022101 0.059411 0.022101 1.039302 0.043490 0.039302 1.038521 \n", "12 1.021037 0.059462 0.021037 1.039308 0.043530 0.039308 1.038517 \n", "7 1.015917 0.060709 0.015917 1.039322 0.044135 0.039322 1.038294 \n", "19 1.029973 0.057821 0.029973 1.039372 0.042636 0.039372 1.038849 \n", "18 1.027756 0.058034 0.027756 1.039449 0.042746 0.039449 1.038895 \n", "10 1.018222 
0.060047 0.018222 1.039456 0.043853 0.039456 1.038531 \n", "16 1.027687 0.058050 0.027687 1.039457 0.042753 0.039457 1.038900 \n", "17 1.029078 0.057985 0.029078 1.039464 0.042703 0.039464 1.038918 \n", "9 1.019118 0.060007 0.019118 1.039485 0.043828 0.039485 1.038564 \n", "8 1.018151 0.060053 0.018151 1.039499 0.043864 0.039499 1.038568 \n", "6 1.014454 0.060770 0.014454 1.039918 0.044285 0.039918 1.038785 \n", "5 1.015266 0.060731 0.015266 1.039954 0.044264 0.039954 1.038824 \n", "4 1.014382 0.060769 0.014382 1.039978 0.044297 0.039978 1.038835 \n", "3 1.033380 0.056470 0.033380 1.041214 0.041673 0.041214 1.040825 \n", "2 1.033103 0.056626 0.033103 1.041513 0.041770 0.041513 1.041085 \n", "0 1.032923 0.056626 0.032923 1.041555 0.041760 0.041555 1.041128 \n", "1 1.031386 0.056477 0.031386 1.041641 0.041583 0.041641 1.041266 \n", "\n", " tmle_se tmle_err \n", "11 0.043721 0.038188 \n", "15 0.043391 0.038355 \n", "14 0.043530 0.038505 \n", "13 0.043502 0.038521 \n", "12 0.043542 0.038517 \n", "7 0.044140 0.038294 \n", "19 0.042643 0.038849 \n", "18 0.042754 0.038895 \n", "10 0.043864 0.038531 \n", "16 0.042761 0.038900 \n", "17 0.042711 0.038918 \n", "9 0.043839 0.038564 \n", "8 0.043875 0.038568 \n", "6 0.044288 0.038785 \n", "5 0.044267 0.038824 \n", "4 0.044299 0.038835 \n", "3 0.041676 0.040825 \n", "2 0.041772 0.041085 \n", "0 0.041763 0.041128 \n", "1 0.041585 0.041266 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from genriesz import BPGenerator\n", "\n", "branch = lambda x: int(x[0] == 1.0)\n", "\n", "generator_grid = [\n", " (\"SQ\", SquaredGenerator(C=0.0).as_generator()),\n", " (\"UKL (C=1)\", UKLGenerator(C=1.0, branch_fn=branch).as_generator()),\n", " (\"BP (omega=0.1, C=1)\", BPGenerator(C=1.0, omega=0.1, branch_fn=branch).as_generator()),\n", " (\"BP (omega=0.2, C=1)\", BPGenerator(C=1.0, omega=0.2, branch_fn=branch).as_generator()),\n", " (\"BP (omega=0.5, C=1)\", BPGenerator(C=1.0, omega=0.5, 
branch_fn=branch).as_generator()),\n", "]\n", "\n", "penalty_grid = [\n", "    {\"penalty\": \"l2\", \"lam\": 1e-4, \"p_norm\": None},\n", "    {\"penalty\": \"l2\", \"lam\": 1e-3, \"p_norm\": None},\n", "    {\"penalty\": \"l1\", \"lam\": 1e-4, \"p_norm\": None},\n", "    {\"penalty\": \"lp\", \"lam\": 1e-3, \"p_norm\": 1.5},\n", "]\n", "\n", "rows = []\n", "for gname, gen_i in generator_grid:\n", "    for cfg in penalty_grid:\n", "        res_i = grr_ate(\n", "            X=X,\n", "            Y=Y,\n", "            basis=phi,\n", "            generator=gen_i,\n", "            cross_fit=True,\n", "            folds=3,\n", "            random_state=0,\n", "            estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", "            outcome_models=\"shared\",\n", "            outcome_link=\"identity\",\n", "            riesz_penalty=cfg[\"penalty\"],\n", "            riesz_lam=cfg[\"lam\"],\n", "            riesz_p_norm=cfg[\"p_norm\"],\n", "            max_iter=250,\n", "            tol=1e-8,\n", "        )\n", "\n", "        row = {\n", "            \"generator\": gname,\n", "            \"penalty\": cfg[\"penalty\"],\n", "            \"lam\": cfg[\"lam\"],\n", "            \"p_norm\": cfg[\"p_norm\"],\n", "        }\n", "        for k in (\"ra\", \"rw\", \"arw\", \"tmle\"):\n", "            est = res_i.estimates[k]  # named `est` so the propensity score `e` is not overwritten\n", "            row[k] = est.estimate\n", "            row[f\"{k}_se\"] = est.se\n", "            row[f\"{k}_err\"] = est.estimate - tau  # tau is the true constant effect in this DGP\n", "        rows.append(row)\n", "\n", "import pandas as pd\n", "\n", "# Sort configurations by the absolute ARW error (smallest first)\n", "df = pd.DataFrame(rows)\n", "df = df.sort_values(by=\"arw_err\", key=lambda s: np.abs(s))\n", "display(df)" ] }, { "cell_type": "code", "execution_count": null, "id": "aab95cba", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.6" } }, "nbformat": 4, "nbformat_minor": 5 }