{ "cells": [ { "cell_type": "markdown", "id": "d262ca93", "metadata": {}, "source": [ "# Panel DID simulation with true value (genriesz)\n", "\n", "We implement DID as the **ATT on the differenced outcome**\n", "\n", "$$\n", "\\Delta Y = Y_1 - Y_0,\n", "$$\n", "\n", "where:\n", "\n", "- $Y_0$ is the pre-period outcome,\n", "- $Y_1$ is the post-period outcome,\n", "- the same units are observed in both periods (panel),\n", "- $D$ is a binary treatment indicator (treatment occurs only in the post period).\n", "\n", "Consider the standard panel DID model\n", "\n", "$$\n", "Y_{0} = \\mu(Z) + u + \\varepsilon_0, \\qquad\n", "Y_{1} = \\mu(Z) + \\text{trend}(Z) + u + \\tau D + \\varepsilon_1.\n", "$$\n", "\n", "Differencing removes $\\mu(Z)$ and the unit effect $u$, so\n", "$\\Delta Y = \\text{trend}(Z) + \\tau D + (\\varepsilon_1 - \\varepsilon_0)$.\n", "Because both the trend and the treatment probability depend on $Z$, parallel\n", "trends holds only after conditioning on $Z$; under that condition the DID effect\n", "(the ATT on $\\Delta Y$) equals the constant treatment effect $\\tau$.\n", "\n", "This notebook:\n", "\n", "1. simulates a large population to show that the naive (unadjusted) DID is biased, while the target ATT on $\\Delta Y$ equals $\\tau$ by construction,\n", "2. samples a dataset and calls `genriesz.grr_did(X, Y0=..., Y1=...)` with several bases, generators, and penalties.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "a7a4d81c", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from genriesz import (\n", " grr_did,\n", " SquaredGenerator,\n", " UKLGenerator,\n", " BPGenerator,\n", " PolynomialBasis,\n", " TreatmentInteractionBasis,\n", " RBFRandomFourierBasis,\n", " KNNCatchmentBasis,\n", ")\n", "\n", "rng = np.random.default_rng(0)" ] }, { "cell_type": "markdown", "id": "748b75c3", "metadata": {}, "source": [ "## Data-generating process (DGP)" ] }, { "cell_type": "code", "execution_count": 7, "id": "eaa7cf23", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True tau (by construction): 1.0\n", "Naive DID (difference in mean \u0394Y): 1.0477956893531037\n" ] } ], "source": [ "def draw_panel(n: int, d_z: int, tau: float, seed: int = 0):\n", " rng = np.random.default_rng(seed)\n", " Z = rng.normal(size=(n, d_z))\n", "\n", " logits = 0.6 * Z[:, 0] - 0.25 * Z[:, 1]\n", " e = 1.0 / 
(1.0 + np.exp(-logits))\n", " D = rng.binomial(1, e, size=n).astype(int)\n", "\n", " mu = 0.5 * Z[:, 0] - 0.2 * Z[:, 1] ** 2\n", " trend = 0.5 + 0.1 * Z[:, 0] # time trend that depends on Z (parallel trends holds only given Z)\n", "\n", " u = rng.normal(scale=1.0, size=n) # unit fixed effect\n", "\n", " Y0 = mu + u + rng.normal(scale=1.0, size=n)\n", " Y1 = mu + trend + u + tau * D + rng.normal(scale=1.0, size=n)\n", "\n", " # Column 0 of X is the treatment indicator D; the remaining columns are Z\n", " X = np.column_stack([D.astype(float), Z])\n", " return X, Y0, Y1, D\n", "\n", "tau_true = 1.0\n", "\n", "# Large population: illustrates the bias of the naive (unadjusted) DID\n", "X_pop, Y0_pop, Y1_pop, D_pop = draw_panel(n=200_000, d_z=5, tau=tau_true, seed=1)\n", "\n", "naive_did = np.mean((Y1_pop - Y0_pop)[D_pop == 1]) - np.mean((Y1_pop - Y0_pop)[D_pop == 0])\n", "# Our target here is \"ATT on \u0394Y\", whose true value equals tau_true by construction.\n", "print(\"True tau (by construction):\", tau_true)\n", "print(\"Naive DID (difference in mean \u0394Y):\", naive_did)\n" ] }, { "cell_type": "markdown", "id": "cfaf7418", "metadata": {}, "source": [ "## Example 1: Polynomial basis + treatment interactions" ] }, { "cell_type": "code", "execution_count": 9, "id": "fc6d817e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DID estimates (n=6000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: max_abs_smd_unweighted=0.5468841195092563, max_abs_smd_weighted=0.002194706023289376, ess_treated=3017.871638373246, ess_control=2036.5721815918366\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 0.983907 0.0129607 [ 0.958505, 1.00931] 0\n", "RW 0.976091 0.0495386 [ 0.878997, 1.07318] 0\n", "ARW 0.98479 0.0421389 [ 0.9022, 1.06738] 0\n", "TMLE 0.984776 0.0421433 [ 0.902176, 1.06738] 0\n" ] } ], "source": [ "# Sample a dataset from the same DGP\n", "X, Y0, Y1, D = draw_panel(n=6000, d_z=5, tau=tau_true, seed=0)\n", "\n", "psi = PolynomialBasis(degree=2, include_bias=True)\n", "phi = 
TreatmentInteractionBasis(base_basis=psi)\n", "\n", "# Squared-loss (SQ) Riesz generator\n", "gen = SquaredGenerator(C=0.0).as_generator()\n", "\n", "res = grr_did(\n", " X=X,\n", " Y0=Y0,\n", " Y1=Y1,\n", " basis=phi,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res.summary_text())\n" ] }, { "cell_type": "markdown", "id": "3309d26f-090f-da23-565a-d314e376-7b1", "metadata": {}, "source": [ "## Example 2: RKHS basis (RBF random Fourier features)" ] }, { "cell_type": "code", "execution_count": null, "id": "e171676d-ba03-7a92-382e-74e494de-6de", "metadata": {}, "outputs": [], "source": [ "psi_rff = RBFRandomFourierBasis(\n", " n_features=500,\n", " sigma=1.0,\n", " standardize=True,\n", " random_state=0,\n", ")\n", "phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)\n", "\n", "res_phi_rff = grr_did(\n", " X=X,\n", " Y0=Y0,\n", " Y1=Y1,\n", " basis=phi_rff,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_phi_rff.summary_text())" ] }, { "cell_type": "markdown", "id": "421d9f16-ce0f-3565-82ac-1c5c8b3e-1d5", "metadata": {}, "source": [ "## Example 3: KNN catchment basis (nearest-neighbor matching)\n", "\n", "Nearest-neighbor matching can be viewed as a special case of squared-loss Riesz regression with a KNN catchment basis." 
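, "\n", "\n", "A minimal, library-free sketch of the matching idea, assuming plain NumPy and a brute-force Euclidean nearest-neighbor search. It does not call `KNNCatchmentBasis` and is not meant to reproduce its internals; `draw_dy` and `knn_matching_att` are hypothetical helpers, and `draw_dy` only mirrors the distribution of `draw_panel`. Because $\\mu(Z)$ and the unit effect $u$ cancel in $\\Delta Y$, matching only has to balance the $Z$-dependent trend.\n", "\n", "
```python
import numpy as np


def draw_dy(n, d_z=5, tau=1.0, seed=0):
    # Hypothetical helper mirroring draw_panel: returns Delta Y = Y1 - Y0, D and Z.
    rng = np.random.default_rng(seed)
    z = rng.normal(size=(n, d_z))
    logits = 0.6 * z[:, 0] - 0.25 * z[:, 1]
    d = rng.binomial(1, 1.0 / (1.0 + np.exp(-logits)))
    trend = 0.5 + 0.1 * z[:, 0]
    # mu(Z) and the unit effect u cancel in the difference, so only the
    # trend, the treatment effect and the two period noises remain.
    dy = trend + tau * d + rng.normal(size=n) - rng.normal(size=n)
    return dy, d, z


def knn_matching_att(dy, d, z, k=5):
    # Nearest-neighbor matching estimate of the ATT on dy: each treated unit is
    # compared with the mean dy of its k nearest controls (Euclidean, with replacement).
    zt, zc = z[d == 1], z[d == 0]
    # Squared distances via |a|^2 + |b|^2 - 2 a.b, shape (n_treated, n_control)
    dist2 = (zt ** 2).sum(axis=1)[:, None] + (zc ** 2).sum(axis=1)[None, :] - 2.0 * zt @ zc.T
    nn = np.argsort(dist2, axis=1)[:, :k]
    counterfactual = dy[d == 0][nn].mean(axis=1)
    return float(np.mean(dy[d == 1] - counterfactual))


dy, d, z = draw_dy(n=2000)
naive = float(dy[d == 1].mean() - dy[d == 0].mean())
matched = knn_matching_att(dy, d, z, k=5)
print('naive DID:', naive)
print('5-NN matching ATT:', matched)
```
", "\n", "With `k=1` this is one-to-one matching with replacement; a larger `k` averages over more controls, trading a little bias for lower variance."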
] }, { "cell_type": "code", "execution_count": null, "id": "2b080415-87fe-c45e-a5aa-f3f68e1d-eaa", "metadata": {}, "outputs": [], "source": [ "basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)\n", "phi_knn = TreatmentInteractionBasis(base_basis=basis_knn)\n", "\n", "res_phi_knn = grr_did(\n", " X=X,\n", " Y0=Y0,\n", " Y1=Y1,\n", " basis=phi_knn,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_phi_knn.summary_text())" ] }, { "cell_type": "markdown", "id": "b739a4df-dd09-03f5-6aae-69abb216-519", "metadata": {}, "source": [ "## Example 4: Random forest leaf basis (optional; requires scikit-learn)" ] }, { "cell_type": "code", "execution_count": null, "id": "6cee9022-23d0-21aa-f19c-2fec0e01-85a", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestRegressor\n", "from genriesz.sklearn_basis import RandomForestLeafBasis\n", "\n", "rf = RandomForestRegressor(n_estimators=200, max_depth=6, random_state=0)\n", "# The forest is fit here on the full sample (target: Y1 - Y0) before grr_did is called\n", "leaf_basis = RandomForestLeafBasis(rf).fit(X, Y1 - Y0)\n", "\n", "res_leaf_basis = grr_did(\n", " X=X,\n", " Y0=Y0,\n", " Y1=Y1,\n", " basis=leaf_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_leaf_basis.summary_text())" ] }, { "cell_type": "markdown", "id": "048e3c39-7b39-94c0-3dd5-cf13a533-d35", "metadata": {}, "source": [ "## Example 5: Neural network embedding basis (optional; requires PyTorch)" ] }, { "cell_type": "code", "execution_count": null, "id": "b42099ca-2ad5-fb5c-4076-543cdc7e-727", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from genriesz.torch_basis import 
MLPEmbeddingNet, TorchEmbeddingBasis\n", "\n", "torch.manual_seed(0)\n", "net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)\n", "nn_basis = TorchEmbeddingBasis(net, include_bias=True, device=\"cpu\")\n", "\n", "res_nn_basis = grr_did(\n", " X=X,\n", " Y0=Y0,\n", " Y1=Y1,\n", " basis=nn_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_nn_basis.summary_text())" ] }, { "cell_type": "markdown", "id": "e1855369", "metadata": {}, "source": [ "## Generator / regularization sweep (SQ / UKL / BP)\n", "\n", "We repeat the DID estimation (implemented as the ATT on the differenced outcome)\n", "with the SQ-Riesz, UKL-Riesz, and BP-Riesz generators and four penalty settings:\n", "L2 with $\\lambda = 10^{-4}$ and $10^{-3}$, L1 with $\\lambda = 10^{-4}$, and Lp\n", "($p = 1.5$) with $\\lambda = 10^{-3}$.\n", "\n", "For UKL/BP we pass a branch function that returns 1 for treated rows (the first\n", "column of `X` is $D$) and 0 for controls, matching the sign pattern of the ATT\n", "Riesz representer: positive for treated units and negative for controls.\n", "\n", "In the output table, each `*_err` column is the estimate minus the true value\n", "$\\tau = 1$, and rows are sorted by the absolute ARW error." ] }, { "cell_type": "code", "execution_count": 15, "id": "5a7e5516", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead>\n",
"<tr style=\"text-align: right;\"><th></th><th>generator</th><th>penalty</th><th>lam</th><th>ra</th><th>ra_se</th><th>ra_err</th><th>rw</th><th>rw_se</th><th>rw_err</th><th>arw</th><th>arw_se</th><th>arw_err</th><th>tmle</th><th>tmle_se</th><th>tmle_err</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>6</th><td>UKL (C=1)</td><td>l1</td><td>0.0001</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.974852</td><td>0.051020</td><td>-0.025148</td><td>0.991554</td><td>0.043837</td><td>-0.008446</td><td>0.990610</td><td>0.043874</td><td>-0.009390</td></tr>\n",
"<tr><th>4</th><td>UKL (C=1)</td><td>l2</td><td>0.0001</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.974713</td><td>0.051015</td><td>-0.025287</td><td>0.991546</td><td>0.043832</td><td>-0.008454</td><td>0.990603</td><td>0.043870</td><td>-0.009397</td></tr>\n",
"<tr><th>5</th><td>UKL (C=1)</td><td>l2</td><td>0.0010</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.975179</td><td>0.050895</td><td>-0.024821</td><td>0.991356</td><td>0.043705</td><td>-0.008644</td><td>0.990484</td><td>0.043742</td><td>-0.009516</td></tr>\n",
"<tr><th>7</th><td>UKL (C=1)</td><td>lp</td><td>0.0010</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.975179</td><td>0.050895</td><td>-0.024821</td><td>0.991356</td><td>0.043705</td><td>-0.008644</td><td>0.990484</td><td>0.043742</td><td>-0.009516</td></tr>\n",
"<tr><th>19</th><td>BP (omega=0.5, C=1)</td><td>lp</td><td>0.0010</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.990781</td><td>0.050259</td><td>-0.009219</td><td>0.991199</td><td>0.042990</td><td>-0.008801</td><td>0.990698</td><td>0.043028</td><td>-0.009302</td></tr>\n",
"<tr><th>17</th><td>BP (omega=0.5, C=1)</td><td>l2</td><td>0.0010</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.990781</td><td>0.050259</td><td>-0.009219</td><td>0.991199</td><td>0.042990</td><td>-0.008801</td><td>0.990698</td><td>0.043028</td><td>-0.009302</td></tr>\n",
"<tr><th>16</th><td>BP (omega=0.5, C=1)</td><td>l2</td><td>0.0001</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.986536</td><td>0.050288</td><td>-0.013464</td><td>0.990830</td><td>0.043032</td><td>-0.009170</td><td>0.990316</td><td>0.043068</td><td>-0.009684</td></tr>\n",
"<tr><th>18</th><td>BP (omega=0.5, C=1)</td><td>l1</td><td>0.0001</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.986272</td><td>0.050383</td><td>-0.013728</td><td>0.990558</td><td>0.043127</td><td>-0.009442</td><td>0.990050</td><td>0.043162</td><td>-0.009950</td></tr>\n",
"<tr><th>11</th><td>BP (omega=0.1, C=1)</td><td>lp</td><td>0.0010</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>1.009891</td><td>0.046637</td><td>0.009891</td><td>0.986427</td><td>0.039662</td><td>-0.013573</td><td>0.986603</td><td>0.039681</td><td>-0.013397</td></tr>\n",
"<tr><th>9</th><td>BP (omega=0.1, C=1)</td><td>l2</td><td>0.0010</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>1.009891</td><td>0.046637</td><td>0.009891</td><td>0.986427</td><td>0.039662</td><td>-0.013573</td><td>0.986603</td><td>0.039681</td><td>-0.013397</td></tr>\n",
"<tr><th>8</th><td>BP (omega=0.1, C=1)</td><td>l2</td><td>0.0001</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>1.009828</td><td>0.046649</td><td>0.009828</td><td>0.986413</td><td>0.039671</td><td>-0.013587</td><td>0.986587</td><td>0.039689</td><td>-0.013413</td></tr>\n",
"<tr><th>10</th><td>BP (omega=0.1, C=1)</td><td>l1</td><td>0.0001</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>1.009280</td><td>0.046642</td><td>0.009280</td><td>0.986367</td><td>0.039666</td><td>-0.013633</td><td>0.986539</td><td>0.039684</td><td>-0.013461</td></tr>\n",
"<tr><th>3</th><td>SQ</td><td>lp</td><td>0.0010</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.975879</td><td>0.049509</td><td>-0.024121</td><td>0.986051</td><td>0.042223</td><td>-0.013949</td><td>0.986004</td><td>0.042238</td><td>-0.013996</td></tr>\n",
"<tr><th>1</th><td>SQ</td><td>l2</td><td>0.0010</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.975879</td><td>0.049509</td><td>-0.024121</td><td>0.986051</td><td>0.042223</td><td>-0.013949</td><td>0.986004</td><td>0.042238</td><td>-0.013996</td></tr>\n",
"<tr><th>0</th><td>SQ</td><td>l2</td><td>0.0001</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.978916</td><td>0.049698</td><td>-0.021084</td><td>0.985968</td><td>0.042369</td><td>-0.014032</td><td>0.985912</td><td>0.042383</td><td>-0.014088</td></tr>\n",
"<tr><th>2</th><td>SQ</td><td>l1</td><td>0.0001</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>0.979128</td><td>0.049712</td><td>-0.020872</td><td>0.985944</td><td>0.042379</td><td>-0.014056</td><td>0.985887</td><td>0.042394</td><td>-0.014113</td></tr>\n",
"<tr><th>13</th><td>BP (omega=0.2, C=1)</td><td>l2</td><td>0.0010</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>1.011729</td><td>0.048423</td><td>0.011729</td><td>0.980771</td><td>0.041029</td><td>-0.019229</td><td>0.980738</td><td>0.041015</td><td>-0.019262</td></tr>\n",
"<tr><th>15</th><td>BP (omega=0.2, C=1)</td><td>lp</td><td>0.0010</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>1.011729</td><td>0.048423</td><td>0.011729</td><td>0.980771</td><td>0.041029</td><td>-0.019229</td><td>0.980738</td><td>0.041015</td><td>-0.019262</td></tr>\n",
"<tr><th>14</th><td>BP (omega=0.2, C=1)</td><td>l1</td><td>0.0001</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>1.010995</td><td>0.048484</td><td>0.010995</td><td>0.980671</td><td>0.041077</td><td>-0.019329</td><td>0.980640</td><td>0.041062</td><td>-0.019360</td></tr>\n",
"<tr><th>12</th><td>BP (omega=0.2, C=1)</td><td>l2</td><td>0.0001</td><td>0.983161</td><td>0.013075</td><td>-0.016839</td><td>1.010725</td><td>0.048487</td><td>0.010725</td><td>0.980661</td><td>0.041081</td><td>-0.019339</td><td>0.980630</td><td>0.041066</td><td>-0.019370</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " generator penalty lam ra ra_se ra_err \\\n", "6 UKL (C=1) l1 0.0001 0.983161 0.013075 -0.016839 \n", "4 UKL (C=1) l2 0.0001 0.983161 0.013075 -0.016839 \n", "5 UKL (C=1) l2 0.0010 0.983161 0.013075 -0.016839 \n", "7 UKL (C=1) lp 0.0010 0.983161 0.013075 -0.016839 \n", "19 BP (omega=0.5, C=1) lp 0.0010 0.983161 0.013075 -0.016839 \n", "17 BP (omega=0.5, C=1) l2 0.0010 0.983161 0.013075 -0.016839 \n", "16 BP (omega=0.5, C=1) l2 0.0001 0.983161 0.013075 -0.016839 \n", "18 BP (omega=0.5, C=1) l1 0.0001 0.983161 0.013075 -0.016839 \n", "11 BP (omega=0.1, C=1) lp 0.0010 0.983161 0.013075 -0.016839 \n", "9 BP (omega=0.1, C=1) l2 0.0010 0.983161 0.013075 -0.016839 \n", "8 BP (omega=0.1, C=1) l2 0.0001 0.983161 0.013075 -0.016839 \n", "10 BP (omega=0.1, C=1) l1 0.0001 0.983161 0.013075 -0.016839 \n", "3 SQ lp 0.0010 0.983161 0.013075 -0.016839 \n", "1 SQ l2 0.0010 0.983161 0.013075 -0.016839 \n", "0 SQ l2 0.0001 0.983161 0.013075 -0.016839 \n", "2 SQ l1 0.0001 0.983161 0.013075 -0.016839 \n", "13 BP (omega=0.2, C=1) l2 0.0010 0.983161 0.013075 -0.016839 \n", "15 BP (omega=0.2, C=1) lp 0.0010 0.983161 0.013075 -0.016839 \n", "14 BP (omega=0.2, C=1) l1 0.0001 0.983161 0.013075 -0.016839 \n", "12 BP (omega=0.2, C=1) l2 0.0001 0.983161 0.013075 -0.016839 \n", "\n", " rw rw_se rw_err arw arw_se arw_err tmle \\\n", "6 0.974852 0.051020 -0.025148 0.991554 0.043837 -0.008446 0.990610 \n", "4 0.974713 0.051015 -0.025287 0.991546 0.043832 -0.008454 0.990603 \n", "5 0.975179 0.050895 -0.024821 0.991356 0.043705 -0.008644 0.990484 \n", "7 0.975179 0.050895 -0.024821 0.991356 0.043705 -0.008644 0.990484 \n", "19 0.990781 0.050259 -0.009219 0.991199 0.042990 -0.008801 0.990698 \n", "17 0.990781 0.050259 -0.009219 0.991199 0.042990 -0.008801 0.990698 \n", "16 0.986536 0.050288 -0.013464 0.990830 0.043032 -0.009170 0.990316 \n", "18 0.986272 0.050383 -0.013728 0.990558 0.043127 -0.009442 0.990050 \n", "11 1.009891 0.046637 0.009891 0.986427 0.039662 -0.013573 
0.986603 \n", "9 1.009891 0.046637 0.009891 0.986427 0.039662 -0.013573 0.986603 \n", "8 1.009828 0.046649 0.009828 0.986413 0.039671 -0.013587 0.986587 \n", "10 1.009280 0.046642 0.009280 0.986367 0.039666 -0.013633 0.986539 \n", "3 0.975879 0.049509 -0.024121 0.986051 0.042223 -0.013949 0.986004 \n", "1 0.975879 0.049509 -0.024121 0.986051 0.042223 -0.013949 0.986004 \n", "0 0.978916 0.049698 -0.021084 0.985968 0.042369 -0.014032 0.985912 \n", "2 0.979128 0.049712 -0.020872 0.985944 0.042379 -0.014056 0.985887 \n", "13 1.011729 0.048423 0.011729 0.980771 0.041029 -0.019229 0.980738 \n", "15 1.011729 0.048423 0.011729 0.980771 0.041029 -0.019229 0.980738 \n", "14 1.010995 0.048484 0.010995 0.980671 0.041077 -0.019329 0.980640 \n", "12 1.010725 0.048487 0.010725 0.980661 0.041081 -0.019339 0.980630 \n", "\n", " tmle_se tmle_err \n", "6 0.043874 -0.009390 \n", "4 0.043870 -0.009397 \n", "5 0.043742 -0.009516 \n", "7 0.043742 -0.009516 \n", "19 0.043028 -0.009302 \n", "17 0.043028 -0.009302 \n", "16 0.043068 -0.009684 \n", "18 0.043162 -0.009950 \n", "11 0.039681 -0.013397 \n", "9 0.039681 -0.013397 \n", "8 0.039689 -0.013413 \n", "10 0.039684 -0.013461 \n", "3 0.042238 -0.013996 \n", "1 0.042238 -0.013996 \n", "0 0.042383 -0.014088 \n", "2 0.042394 -0.014113 \n", "13 0.041015 -0.019262 \n", "15 0.041015 -0.019262 \n", "14 0.041062 -0.019360 \n", "12 0.041066 -0.019370 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "branch = lambda x: int(x[0] == 1.0)\n", "\n", "generator_grid = [\n", " (\"SQ\", SquaredGenerator(C=0.0).as_generator()),\n", " (\"UKL (C=1)\", UKLGenerator(C=1.0, branch_fn=branch).as_generator()),\n", " (\"BP (omega=0.1, C=1)\", BPGenerator(C=1.0, omega=0.1, branch_fn=branch).as_generator()),\n", " (\"BP (omega=0.2, C=1)\", BPGenerator(C=1.0, omega=0.2, branch_fn=branch).as_generator()),\n", " (\"BP (omega=0.5, C=1)\", BPGenerator(C=1.0, omega=0.5, branch_fn=branch).as_generator()),\n", "]\n", "\n", "penalty_grid = [\n", " 
{\"penalty\": \"l2\", \"lam\": 1e-4, \"p_norm\": None},\n", " {\"penalty\": \"l2\", \"lam\": 1e-3, \"p_norm\": None},\n", " {\"penalty\": \"l1\", \"lam\": 1e-4, \"p_norm\": None},\n", " {\"penalty\": \"lp\", \"lam\": 1e-3, \"p_norm\": 1.5},\n", "]\n", "\n", "rows = []\n", "for gname, gen_i in generator_grid:\n", "    for cfg in penalty_grid:\n", "        res_i = grr_did(\n", "            X=X,\n", "            Y0=Y0,\n", "            Y1=Y1,\n", "            basis=phi,\n", "            generator=gen_i,\n", "            cross_fit=True,\n", "            folds=3,\n", "            random_state=0,\n", "            estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", "            outcome_models=\"shared\",\n", "            outcome_link=\"identity\",\n", "            riesz_penalty=cfg[\"penalty\"],\n", "            riesz_lam=cfg[\"lam\"],\n", "            riesz_p_norm=cfg.get(\"p_norm\"),\n", "            max_iter=250,\n", "            tol=1e-8,\n", "        )\n", "\n", "        row = {\n", "            \"generator\": gname,\n", "            \"penalty\": cfg[\"penalty\"],\n", "            \"lam\": cfg[\"lam\"],\n", "        }\n", "        for k in (\"ra\", \"rw\", \"arw\", \"tmle\"):\n", "            est = res_i.estimates[k]\n", "            row[k] = est.estimate\n", "            row[f\"{k}_se\"] = est.se\n", "            # Error relative to the true effect tau_true\n", "            row[f\"{k}_err\"] = est.estimate - tau_true\n", "        rows.append(row)\n", "\n", "import pandas as pd\n", "\n", "df = pd.DataFrame(rows)\n", "# Sort rows by absolute ARW error (closest to tau_true first)\n", "df = df.sort_values(by=\"arw_err\", key=lambda s: np.abs(s))\n", "display(df)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.6" } }, "nbformat": 4, "nbformat_minor": 5 }