{ "cells": [ { "cell_type": "markdown", "id": "d262ca93", "metadata": {}, "source": [ "# Panel DID simulation with true value (genriesz)\n", "\n", "We implement DID as the **ATT** on the differenced outcome\n", "\n", "$$\n", "\\Delta Y = Y_1 - Y_0,\n", "$$\n", "\n", "where:\n", "\n", "- $Y_0$ is the pre-period outcome,\n", "- $Y_1$ is the post-period outcome,\n", "- the same units are observed in both periods (panel),\n", "- $D$ is a binary treatment indicator (treatment happens in the post period).\n", "\n", "With a standard panel DID setup:\n", "\n", "$$\n", "Y_{0} = \\mu(Z) + u + \\varepsilon_0, \\qquad\n", "Y_{1} = \\mu(Z) + \\text{trend}(Z) + u + \\tau D + \\varepsilon_1,\n", "$$\n", "\n", "the DID effect equals the constant treatment effect $\\tau$, provided the\n", "parallel trends condition holds after conditioning on $Z$.\n", "\n", "This notebook:\n", "\n", "1. simulates a large population to compute the naive (unconditional) DID, which differs from the true effect $\\tau$ because the trend and the treatment assignment both depend on $Z$,\n", "2. samples a dataset and calls `genriesz.grr_did(X, Y0=..., Y1=...)`.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "a7a4d81c", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from genriesz import (\n", " grr_did,\n", " SquaredGenerator,\n", " UKLGenerator,\n", " BPGenerator,\n", " PolynomialBasis,\n", " TreatmentInteractionBasis,\n", " RBFRandomFourierBasis,\n", " KNNCatchmentBasis,\n", ")\n", "\n", "rng = np.random.default_rng(0)" ] }, { "cell_type": "markdown", "id": "748b75c3", "metadata": {}, "source": [ "## DGP" ] }, { "cell_type": "code", "execution_count": 7, "id": "eaa7cf23", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True tau (by construction): 1.0\n", "Naive DID (difference in mean \u0394Y): 1.0477956893531037\n" ] } ], "source": [ "def draw_panel(n: int, d_z: int, tau: float, seed: int = 0):\n", " rng = np.random.default_rng(seed)\n", " Z = rng.normal(size=(n, d_z))\n", "\n", " logits = 0.6 * Z[:, 0] - 0.25 * Z[:, 1]\n", " e = 1.0 / 
(1.0 + np.exp(-logits))\n", " D = rng.binomial(1, e, size=n).astype(int)\n", "\n", " mu = 0.5 * Z[:, 0] - 0.2 * Z[:, 1] ** 2\n", " trend = 0.5 + 0.1 * Z[:, 0] # common trend that depends on Z\n", "\n", " u = rng.normal(scale=1.0, size=n) # unit fixed effect\n", "\n", " Y0 = mu + u + rng.normal(scale=1.0, size=n)\n", " Y1 = mu + trend + u + tau * D + rng.normal(scale=1.0, size=n)\n", "\n", " X = np.column_stack([D.astype(float), Z])\n", " return X, Y0, Y1, D\n", "\n", "tau_true = 1.0\n", "\n", "# Large population for an approximate truth\n", "X_pop, Y0_pop, Y1_pop, D_pop = draw_panel(n=200_000, d_z=5, tau=tau_true, seed=1)\n", "\n", "true_did = np.mean((Y1_pop - Y0_pop)[D_pop == 1]) - np.mean((Y1_pop - Y0_pop)[D_pop == 0]) # naive DID\n", "# Our target here is \"ATT on \u0394Y\", whose true value equals tau_true by construction.\n", "print(\"True tau (by construction):\", tau_true)\n", "print(\"Naive DID (difference in mean \u0394Y):\", true_did)\n" ] }, { "cell_type": "markdown", "id": "cfaf7418", "metadata": {}, "source": [ "## Example 1: Polynomial basis + treatment interactions" ] }, { "cell_type": "code", "execution_count": 9, "id": "fc6d817e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DID estimates (n=6000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: max_abs_smd_unweighted=0.5468841195092563, max_abs_smd_weighted=0.002194706023289376, ess_treated=3017.871638373246, ess_control=2036.5721815918366\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 0.983907 0.0129607 [ 0.958505, 1.00931] 0\n", "RW 0.976091 0.0495386 [ 0.878997, 1.07318] 0\n", "ARW 0.98479 0.0421389 [ 0.9022, 1.06738] 0\n", "TMLE 0.984776 0.0421433 [ 0.902176, 1.06738] 0\n" ] } ], "source": [ "# Sample a dataset from the same DGP\n", "X, Y0, Y1, D = draw_panel(n=6000, d_z=5, tau=tau_true, seed=0)\n", "\n", "psi = PolynomialBasis(degree=2, include_bias=True)\n", "phi = 
TreatmentInteractionBasis(base_basis=psi)\n", "\n", "gen = SquaredGenerator(C=0.0).as_generator()\n", "\n", "res = grr_did(\n", " X=X,\n", " Y0=Y0,\n", " Y1=Y1,\n", " basis=phi,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res.summary_text())\n" ] }, { "cell_type": "markdown", "id": "3309d26f-090f-da23-565a-d314e376-7b1", "metadata": {}, "source": [ "## Example 2: RKHS basis (RBF random Fourier features)" ] }, { "cell_type": "code", "execution_count": null, "id": "e171676d-ba03-7a92-382e-74e494de-6de", "metadata": {}, "outputs": [], "source": [ "psi_rff = RBFRandomFourierBasis(\n", " n_features=500,\n", " sigma=1.0,\n", " standardize=True,\n", " random_state=0,\n", ")\n", "phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)\n", "\n", "res_phi_rff = grr_did(\n", " X=X,\n", " Y0=Y0,\n", " Y1=Y1,\n", " basis=phi_rff,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_phi_rff.summary_text())" ] }, { "cell_type": "markdown", "id": "421d9f16-ce0f-3565-82ac-1c5c8b3e-1d5", "metadata": {}, "source": [ "## Example 3: KNN catchment basis (nearest-neighbor matching)\n", "\n", "Nearest-neighbor matching as a special case of squared-loss Riesz regression." 
] }, { "cell_type": "code", "execution_count": null, "id": "2b080415-87fe-c45e-a5aa-f3f68e1d-eaa", "metadata": {}, "outputs": [], "source": [ "basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)\n", "phi_knn = TreatmentInteractionBasis(base_basis=basis_knn)\n", "\n", "res_phi_knn = grr_did(\n", " X=X,\n", " Y0=Y0,\n", " Y1=Y1,\n", " basis=phi_knn,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_phi_knn.summary_text())" ] }, { "cell_type": "markdown", "id": "b739a4df-dd09-03f5-6aae-69abb216-519", "metadata": {}, "source": [ "## Example 4: Random forest leaf basis (optional)" ] }, { "cell_type": "code", "execution_count": null, "id": "6cee9022-23d0-21aa-f19c-2fec0e01-85a", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestRegressor\n", "from genriesz.sklearn_basis import RandomForestLeafBasis\n", "\n", "rf = RandomForestRegressor(n_estimators=200, max_depth=6, random_state=0)\n", "leaf_basis = RandomForestLeafBasis(rf).fit(X, Y1 - Y0)\n", "\n", "res_leaf_basis = grr_did(\n", " X=X,\n", " Y0=Y0,\n", " Y1=Y1,\n", " basis=leaf_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_leaf_basis.summary_text())" ] }, { "cell_type": "markdown", "id": "048e3c39-7b39-94c0-3dd5-cf13a533-d35", "metadata": {}, "source": [ "## Example 5: Neural network embedding basis (optional)" ] }, { "cell_type": "code", "execution_count": null, "id": "b42099ca-2ad5-fb5c-4076-543cdc7e-727", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from genriesz.torch_basis import 
MLPEmbeddingNet, TorchEmbeddingBasis\n", "\n", "torch.manual_seed(0)\n", "net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)\n", "nn_basis = TorchEmbeddingBasis(net, include_bias=True, device=\"cpu\")\n", "\n", "res_nn_basis = grr_did(\n", " X=X,\n", " Y0=Y0,\n", " Y1=Y1,\n", " basis=nn_basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_nn_basis.summary_text())" ] }, { "cell_type": "markdown", "id": "e1855369", "metadata": {}, "source": [ "## Generator / regularization sweep (SQ / UKL / BP)\n", "\n", "We repeat the DID estimation (implemented as ATT on the differenced outcome)\n", "under SQ-Riesz / UKL-Riesz / BP-Riesz, multiple regularization norms, and\n", "multiple regularization strengths.\n", "\n", "For UKL/BP we set a branch function to match the treatment/control sign pattern." ] }, { "cell_type": "code", "execution_count": 15, "id": "5a7e5516", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | generator | penalty | lam | ra | ra_se | ra_err | rw | rw_se | rw_err | arw | arw_se | arw_err | tmle | tmle_se | tmle_err |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 6 | UKL (C=1) | l1 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.974852 | 0.051020 | -0.025148 | 0.991554 | 0.043837 | -0.008446 | 0.990610 | 0.043874 | -0.009390 |
| 4 | UKL (C=1) | l2 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.974713 | 0.051015 | -0.025287 | 0.991546 | 0.043832 | -0.008454 | 0.990603 | 0.043870 | -0.009397 |
| 5 | UKL (C=1) | l2 | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.975179 | 0.050895 | -0.024821 | 0.991356 | 0.043705 | -0.008644 | 0.990484 | 0.043742 | -0.009516 |
| 7 | UKL (C=1) | lp | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.975179 | 0.050895 | -0.024821 | 0.991356 | 0.043705 | -0.008644 | 0.990484 | 0.043742 | -0.009516 |
| 19 | BP (omega=0.5, C=1) | lp | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.990781 | 0.050259 | -0.009219 | 0.991199 | 0.042990 | -0.008801 | 0.990698 | 0.043028 | -0.009302 |
| 17 | BP (omega=0.5, C=1) | l2 | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.990781 | 0.050259 | -0.009219 | 0.991199 | 0.042990 | -0.008801 | 0.990698 | 0.043028 | -0.009302 |
| 16 | BP (omega=0.5, C=1) | l2 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.986536 | 0.050288 | -0.013464 | 0.990830 | 0.043032 | -0.009170 | 0.990316 | 0.043068 | -0.009684 |
| 18 | BP (omega=0.5, C=1) | l1 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.986272 | 0.050383 | -0.013728 | 0.990558 | 0.043127 | -0.009442 | 0.990050 | 0.043162 | -0.009950 |
| 11 | BP (omega=0.1, C=1) | lp | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 1.009891 | 0.046637 | 0.009891 | 0.986427 | 0.039662 | -0.013573 | 0.986603 | 0.039681 | -0.013397 |
| 9 | BP (omega=0.1, C=1) | l2 | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 1.009891 | 0.046637 | 0.009891 | 0.986427 | 0.039662 | -0.013573 | 0.986603 | 0.039681 | -0.013397 |
| 8 | BP (omega=0.1, C=1) | l2 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 1.009828 | 0.046649 | 0.009828 | 0.986413 | 0.039671 | -0.013587 | 0.986587 | 0.039689 | -0.013413 |
| 10 | BP (omega=0.1, C=1) | l1 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 1.009280 | 0.046642 | 0.009280 | 0.986367 | 0.039666 | -0.013633 | 0.986539 | 0.039684 | -0.013461 |
| 3 | SQ | lp | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.975879 | 0.049509 | -0.024121 | 0.986051 | 0.042223 | -0.013949 | 0.986004 | 0.042238 | -0.013996 |
| 1 | SQ | l2 | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.975879 | 0.049509 | -0.024121 | 0.986051 | 0.042223 | -0.013949 | 0.986004 | 0.042238 | -0.013996 |
| 0 | SQ | l2 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.978916 | 0.049698 | -0.021084 | 0.985968 | 0.042369 | -0.014032 | 0.985912 | 0.042383 | -0.014088 |
| 2 | SQ | l1 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.979128 | 0.049712 | -0.020872 | 0.985944 | 0.042379 | -0.014056 | 0.985887 | 0.042394 | -0.014113 |
| 13 | BP (omega=0.2, C=1) | l2 | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 1.011729 | 0.048423 | 0.011729 | 0.980771 | 0.041029 | -0.019229 | 0.980738 | 0.041015 | -0.019262 |
| 15 | BP (omega=0.2, C=1) | lp | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 1.011729 | 0.048423 | 0.011729 | 0.980771 | 0.041029 | -0.019229 | 0.980738 | 0.041015 | -0.019262 |
| 14 | BP (omega=0.2, C=1) | l1 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 1.010995 | 0.048484 | 0.010995 | 0.980671 | 0.041077 | -0.019329 | 0.980640 | 0.041062 | -0.019360 |
| 12 | BP (omega=0.2, C=1) | l2 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 1.010725 | 0.048487 | 0.010725 | 0.980661 | 0.041081 | -0.019339 | 0.980630 | 0.041066 | -0.019370 |