{ "cells": [ { "cell_type": "markdown", "id": "1346260d", "metadata": {}, "source": [ "# AME end-to-end example (genriesz)\n", "\n", "This notebook demonstrates how to estimate an **Average Marginal Effect (AME)**,\n", "i.e., an **average derivative** of the outcome regression function.\n", "\n", "We simulate $X \\sim N(0, I_3)$ and independent noise $\\varepsilon \\sim N(0, 1)$, and set\n", "\n", "$$\n", "Y = \\sin(X_0) + 0.5 X_1^2 + \\varepsilon.\n", "$$\n", "\n", "The outcome regression is $\\gamma(x) = \\mathbb{E}[Y \\mid X = x] = \\sin(x_0) + 0.5 x_1^2$,\n", "so the true AME for coordinate 0 is\n", "\n", "$$\n", "\\mathbb{E}[\\partial_{x_0} \\gamma(X)] = \\mathbb{E}[\\cos(X_0)].\n", "$$\n", "\n", "Since $X_0 \\sim N(0,1)$ has characteristic function $\\mathbb{E}[e^{itX_0}] = e^{-t^2/2}$,\n", "taking the real part at $t = 1$ gives $\\mathbb{E}[\\cos(X_0)] = \\exp(-1/2) \\approx 0.6065$.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "610a3b97", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from genriesz import (\n", " grr_ame,\n", " SquaredGenerator,\n", " UKLGenerator,\n", " BPGenerator,\n", " PolynomialBasis,\n", " RBFRandomFourierBasis,\n", ")\n", "\n", "rng = np.random.default_rng(0)" ] }, { "cell_type": "markdown", "id": "a02b2cb6", "metadata": {}, "source": [ "## Synthetic data with known true AME" ] }, { "cell_type": "code", "execution_count": 2, "id": "08261461", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Approx. true AME for coordinate 0: 0.6065306597126334\n" ] } ], "source": [ "n = 4000\n", "d = 3\n", "\n", "X = rng.normal(size=(n, d))\n", "eps = rng.normal(scale=1.0, size=n)\n", "\n", "Y = np.sin(X[:, 0]) + 0.5 * (X[:, 1] ** 2) + eps\n", "\n", "true_ame0 = float(np.exp(-0.5)) # E[cos(N(0,1))]\n", "print(\"Approx. 
true AME for coordinate 0:\", true_ame0)\n" ] }, { "cell_type": "markdown", "id": "284c9d8b", "metadata": {}, "source": [ "## Example 1: Polynomial basis" ] }, { "cell_type": "code", "execution_count": 3, "id": "ba70cc37", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AME(coord=0) estimates (n=4000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: alpha_abs_mean=0.7986024158798944, alpha_abs_p95=1.9652359398234305, alpha_abs_max=4.506401055576029\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 0.568023 0.00659669 [ 0.555094, 0.580952] 0\n", "RW 0.569997 0.0255014 [ 0.520015, 0.619979] 0\n", "ARW 0.568096 0.0169592 [ 0.534856, 0.601335] 0\n", "TMLE 0.568094 0.0169599 [ 0.534853, 0.601334] 0\n" ] } ], "source": [ "# Polynomial basis on X up to degree 3, including a constant (bias) term\n", "basis = PolynomialBasis(degree=3, include_bias=True)\n", "\n", "# SQ-Riesz generator\n", "gen = SquaredGenerator(C=0.0).as_generator()\n", "\n", "res = grr_ame(\n", " X=X,\n", " Y=Y,\n", " coordinate=0,\n", " basis=basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res.summary_text())\n" ] }, { "cell_type": "markdown", "id": "a2711c49-8a1e-a4b8-2454-ce35a31f-53e", "metadata": {}, "source": [ "## Example 2: RKHS basis (RBF random Fourier features)\n", "\n", "RBF random Fourier features are smooth, so the basis derivative\n", "$\\partial \\phi / \\partial x_j$ required by the AME is implemented analytically."
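, "\n", "\n", "For intuition, consider the standard random Fourier feature map for an RBF kernel with bandwidth $\\sigma$\n", "(an assumption: ``RBFRandomFourierBasis`` may scale its features differently, and if ``standardize=True``\n", "divides $x_j$ by a scale $s_j$, the chain rule adds a factor $1/s_j$). With $D$ features (``n_features``),\n", "frequencies $\\omega_k \\sim N(0, \\sigma^{-2} I_d)$ and phases $b_k \\sim U[0, 2\\pi]$, each feature and its\n", "derivative are available in closed form:\n", "\n", "$$\n", "\\phi_k(x) = \\sqrt{2/D}\\,\\cos(\\omega_k^\\top x + b_k),\n", "\\qquad\n", "\\frac{\\partial \\phi_k(x)}{\\partial x_j} = -\\sqrt{2/D}\\,\\omega_{kj}\\,\\sin(\\omega_k^\\top x + b_k).\n", "$$"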
] }, { "cell_type": "code", "execution_count": null, "id": "9c7c56d5-2dd8-1638-1c58-d1981720-bc9", "metadata": {}, "outputs": [], "source": [ "# RBF random Fourier feature basis on X\n", "psi_rff = RBFRandomFourierBasis(\n", " n_features=500,\n", " sigma=1.0,\n", " standardize=True,\n", " random_state=0,\n", ")\n", "\n", "# Same estimator settings as Example 1; only the basis changes\n", "res_rff = grr_ame(\n", " X=X,\n", " Y=Y,\n", " coordinate=0,\n", " basis=psi_rff,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_rff.summary_text())" ] }, { "cell_type": "markdown", "id": "25f54dc4-56ff-d956-11c9-f4c7b07c-ece", "metadata": {}, "source": [ "## Note: bases requiring smooth derivatives\n", "\n", "AME requires ``basis.derivative(X, coordinate)``. **Piecewise-constant** bases such as\n", "``KNNCatchmentBasis`` and ``RandomForestLeafBasis`` have a zero derivative almost everywhere,\n", "and ``TorchEmbeddingBasis`` without autograd provides no derivative either. None of these\n", "implement this method, so they cannot be used with ``grr_ame``.\n", "Use ``PolynomialBasis``, ``RBFRandomFourierBasis``, or another smooth basis instead." ] }, { "cell_type": "markdown", "id": "7f6869ff", "metadata": {}, "source": [ "## Generator sweep (SQ / UKL / BP)\n", "\n", "Below we compare SQ-Riesz, UKL-Riesz, and BP-Riesz under L1 and L2 regularization\n", "with strengths $\\lambda = 10^{-4}$ and $10^{-3}$ (column ``lam``). For each configuration we report\n", "the **RA / RW / ARW / TMLE** estimates, their standard errors (``*_se``), and their errors\n", "against the known true AME (estimate minus truth, ``*_err``)." ] }, { "cell_type": "code", "execution_count": 4, "id": "bf77f950", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:6: UserWarning: UKLGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| > C + 1. For GRR with functionals that require negative alpha (e.g. 
ATE/ATT), provide branch_fn or use SquaredGenerator instead.\n", " (\"UKL (C=0)\", UKLGenerator(C=0.0).as_generator()),\n", "/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:7: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.\n", " (\"BP (omega=0.1, C=0)\", BPGenerator(C=0.0, omega=0.1).as_generator()),\n", "/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:8: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.\n", " (\"BP (omega=0.2, C=0)\", BPGenerator(C=0.0, omega=0.2).as_generator()),\n", "/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:9: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.\n", " (\"BP (omega=0.5, C=0)\", BPGenerator(C=0.0, omega=0.5).as_generator()),\n" ] }, { "data": { "text/html": [ "
|  | generator | penalty | lam | ra | ra_se | ra_err | rw | rw_se | rw_err | arw | arw_se | arw_err | tmle | tmle_se | tmle_err |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 19 | BP (omega=0.5, C=0) | l1 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.583555 | 0.026705 | -0.022976 | 0.582595 | 0.019655 | -0.023935 | 0.569430 | 0.019682 | -0.037101 |
| 17 | BP (omega=0.5, C=0) | l2 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.584187 | 0.026752 | -0.022344 | 0.581640 | 0.019659 | -0.024891 | 0.569258 | 0.019684 | -0.037273 |
| 16 | BP (omega=0.5, C=0) | l2 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.583067 | 0.026758 | -0.023463 | 0.581609 | 0.019660 | -0.024921 | 0.569253 | 0.019686 | -0.037278 |
| 18 | BP (omega=0.5, C=0) | l1 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.583012 | 0.026763 | -0.023518 | 0.580149 | 0.019657 | -0.026382 | 0.568983 | 0.019680 | -0.037547 |
| 9 | BP (omega=0.1, C=0) | l2 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.578461 | 0.027031 | -0.028070 | 0.578600 | 0.019622 | -0.027931 | 0.568672 | 0.019629 | -0.037858 |
| 5 | UKL (C=0) | l2 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.578919 | 0.027191 | -0.027611 | 0.578271 | 0.019641 | -0.028260 | 0.568613 | 0.019639 | -0.037917 |
| 6 | UKL (C=0) | l1 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.578065 | 0.027179 | -0.028466 | 0.578176 | 0.019640 | -0.028355 | 0.568595 | 0.019637 | -0.037935 |
| 10 | BP (omega=0.1, C=0) | l1 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.578894 | 0.027017 | -0.027637 | 0.577886 | 0.019615 | -0.028645 | 0.568540 | 0.019622 | -0.037991 |
| 8 | BP (omega=0.1, C=0) | l2 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.577012 | 0.027054 | -0.029518 | 0.577699 | 0.019618 | -0.028832 | 0.568507 | 0.019625 | -0.038024 |
| 13 | BP (omega=0.2, C=0) | l2 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.578107 | 0.026959 | -0.028423 | 0.577668 | 0.019615 | -0.028863 | 0.568504 | 0.019627 | -0.038026 |
| 12 | BP (omega=0.2, C=0) | l2 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.576705 | 0.026951 | -0.029826 | 0.577231 | 0.019613 | -0.029299 | 0.568424 | 0.019624 | -0.038107 |
| 4 | UKL (C=0) | l2 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.575639 | 0.027207 | -0.030892 | 0.576981 | 0.019636 | -0.029550 | 0.568375 | 0.019634 | -0.038156 |
| 14 | BP (omega=0.2, C=0) | l1 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.576432 | 0.026943 | -0.030099 | 0.576170 | 0.019608 | -0.030361 | 0.568229 | 0.019618 | -0.038301 |
| 7 | UKL (C=0) | l1 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.574841 | 0.027141 | -0.031690 | 0.575717 | 0.019627 | -0.030813 | 0.568142 | 0.019625 | -0.038389 |
| 11 | BP (omega=0.1, C=0) | l1 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.575375 | 0.026993 | -0.031156 | 0.575641 | 0.019606 | -0.030890 | 0.568128 | 0.019610 | -0.038403 |
| 15 | BP (omega=0.2, C=0) | l1 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.572222 | 0.026900 | -0.034309 | 0.573946 | 0.019594 | -0.032585 | 0.567820 | 0.019601 | -0.038710 |
| 1 | SQ | l2 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.575361 | 0.026199 | -0.031170 | 0.569663 | 0.017072 | -0.036868 | 0.569537 | 0.017102 | -0.036993 |
| 0 | SQ | l2 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.576791 | 0.026184 | -0.029739 | 0.569646 | 0.017092 | -0.036885 | 0.569517 | 0.017121 | -0.037014 |
| 2 | SQ | l1 | 0.0001 | 0.56645 | 0.006685 | -0.040081 | 0.576779 | 0.026169 | -0.029752 | 0.569638 | 0.017091 | -0.036893 | 0.569510 | 0.017121 | -0.037021 |
| 3 | SQ | l1 | 0.0010 | 0.56645 | 0.006685 | -0.040081 | 0.575307 | 0.026053 | -0.031223 | 0.569633 | 0.017068 | -0.036898 | 0.569511 | 0.017097 | -0.037019 |