{ "cells": [ { "cell_type": "markdown", "id": "1346260d", "metadata": {}, "source": [ "# AME end-to-end example (genriesz)\n", "\n", "This notebook demonstrates how to estimate an **Average Marginal Effect (AME)**,\n", "i.e., an **average derivative** of the outcome regression function $\\gamma(x) = \\mathbb{E}[Y \\mid X = x]$.\n", "\n", "We simulate $X \\sim N(0, I_3)$, $\\varepsilon \\sim N(0, 1)$ independent of $X$, and\n", "\n", "$$\n", "Y = \\sin(X_0) + 0.5 X_1^2 + \\varepsilon,\n", "$$\n", "\n", "so $\\gamma(x) = \\sin(x_0) + 0.5 x_1^2$, $\\partial_{x_0} \\gamma(x) = \\cos(x_0)$, and the true AME for coordinate 0 is\n", "\n", "$$\n", "\\mathbb{E}[\\partial_{x_0} \\gamma(X)] = \\mathbb{E}[\\cos(X_0)].\n", "$$\n", "\n", "Since $X_0 \\sim N(0,1)$, $\\mathbb{E}[\\cos(X_0)] = \\exp(-1/2) \\approx 0.6065$.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "610a3b97", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from genriesz import (\n", " grr_ame,\n", " SquaredGenerator,\n", " UKLGenerator,\n", " BPGenerator,\n", " PolynomialBasis,\n", " RBFRandomFourierBasis,\n", ")\n", "\n", "rng = np.random.default_rng(0)" ] }, { "cell_type": "markdown", "id": "a02b2cb6", "metadata": {}, "source": [ "## Synthetic data with known true AME" ] }, { "cell_type": "code", "execution_count": 2, "id": "08261461", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Approx. true AME for coordinate 0: 0.6065306597126334\n" ] } ], "source": [ "n = 4000\n", "d = 3\n", "\n", "X = rng.normal(size=(n, d))\n", "eps = rng.normal(scale=1.0, size=n)\n", "\n", "Y = np.sin(X[:, 0]) + 0.5 * (X[:, 1] ** 2) + eps\n", "\n", "true_ame0 = float(np.exp(-0.5)) # E[cos(N(0,1))]\n", "print(\"Approx. 
true AME for coordinate 0:\", true_ame0)\n" ] }, { "cell_type": "markdown", "id": "284c9d8b", "metadata": {}, "source": [ "## Example 1: Polynomial basis" ] }, { "cell_type": "code", "execution_count": 3, "id": "ba70cc37", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AME(coord=0) estimates (n=4000)\n", "alpha=0.05 | null=0.0\n", "diagnostics: alpha_abs_mean=0.7986024158798944, alpha_abs_p95=1.9652359398234305, alpha_abs_max=4.506401055576029\n", "\n", "Estimator Estimate SE CI p-value\n", "---------------------------------------------------------------------------------\n", "RA 0.568023 0.00659669 [ 0.555094, 0.580952] 0\n", "RW 0.569997 0.0255014 [ 0.520015, 0.619979] 0\n", "ARW 0.568096 0.0169592 [ 0.534856, 0.601335] 0\n", "TMLE 0.568094 0.0169599 [ 0.534853, 0.601334] 0\n" ] } ], "source": [ "# A simple polynomial basis on X\n", "basis = PolynomialBasis(degree=3, include_bias=True)\n", "\n", "gen = SquaredGenerator(C=0.0).as_generator()\n", "\n", "res = grr_ame(\n", " X=X,\n", " Y=Y,\n", " coordinate=0,\n", " basis=basis,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res.summary_text())\n" ] }, { "cell_type": "markdown", "id": "a2711c49-8a1e-a4b8-2454-ce35a31f-53e", "metadata": {}, "source": [ "## Example 2: RKHS basis (RBF random Fourier features)\n", "\n", "RBF random Fourier features are smooth, so the basis derivative\n", "$\\partial \\phi / \\partial x_j$ needed for the AME is implemented analytically."
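 ] }, { "cell_type": "markdown", "id": "rff-derivative-sketch-md", "metadata": {}, "source": [ "For intuition, the next cell is a minimal standalone sketch in plain NumPy, **not** genriesz's implementation. It assumes the standard random Fourier feature map with $D$ features and bandwidth $\\sigma$,\n", "\n", "$$\n", "\\phi_k(x) = \\sqrt{2/D}\\,\\cos(w_k^\\top x + b_k), \\qquad w_k \\sim N(0, \\sigma^{-2} I), \\quad b_k \\sim U[0, 2\\pi),\n", "$$\n", "\n", "whose derivative with respect to $x_j$ has the closed form\n", "\n", "$$\n", "\\frac{\\partial \\phi_k}{\\partial x_j}(x) = -\\sqrt{2/D}\\, w_{kj} \\sin(w_k^\\top x + b_k),\n", "$$\n", "\n", "and it checks that formula against a central finite difference. The helper ``make_rff`` is hypothetical, and ``RBFRandomFourierBasis`` may parametrize, scale, or standardize its features differently." ] }, { "cell_type": "code", "execution_count": null, "id": "rff-derivative-sketch-code", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def make_rff(d, n_features, sigma, seed=0):\n", "    # Hypothetical helper: RBF random Fourier features and their analytic x_j-derivative.\n", "    rng_rff = np.random.default_rng(seed)  # separate generator; the notebook's rng is untouched\n", "    W = rng_rff.normal(scale=1.0 / sigma, size=(d, n_features))  # column k is w_k\n", "    b = rng_rff.uniform(0.0, 2.0 * np.pi, size=n_features)\n", "    scale = np.sqrt(2.0 / n_features)\n", "\n", "    def features(Z):\n", "        # phi_k(z) = sqrt(2/D) * cos(w_k . z + b_k); result has shape (n, D)\n", "        return scale * np.cos(Z @ W + b)\n", "\n", "    def derivative(Z, j):\n", "        # d phi_k / d z_j = -sqrt(2/D) * w_kj * sin(w_k . z + b_k); result has shape (n, D)\n", "        return -scale * W[j] * np.sin(Z @ W + b)\n", "\n", "    return features, derivative\n", "\n", "\n", "rff_features, rff_derivative = make_rff(d=3, n_features=50, sigma=1.0)\n", "X_check = np.random.default_rng(1).normal(size=(5, 3))\n", "\n", "# Central finite difference in coordinate 0 (truncation error is O(h^2)).\n", "h = 1e-5\n", "step = np.zeros(3)\n", "step[0] = h\n", "fd = (rff_features(X_check + step) - rff_features(X_check - step)) / (2 * h)\n", "max_abs_diff = np.max(np.abs(fd - rff_derivative(X_check, 0)))\n", "print(\"max |finite difference - analytic|:\", max_abs_diff)\n", "assert max_abs_diff < 1e-6"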
] }, { "cell_type": "code", "execution_count": null, "id": "9c7c56d5-2dd8-1638-1c58-d1981720-bc9", "metadata": {}, "outputs": [], "source": [ "psi_rff = RBFRandomFourierBasis(\n", " n_features=500,\n", " sigma=1.0,\n", " standardize=True,\n", " random_state=0,\n", ")\n", "\n", "res_rff = grr_ame(\n", " X=X,\n", " Y=Y,\n", " coordinate=0,\n", " basis=psi_rff,\n", " generator=gen,\n", " cross_fit=True,\n", " folds=5,\n", " random_state=0,\n", " estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", " outcome_models=\"shared\",\n", " riesz_penalty=\"l2\",\n", " riesz_lam=1e-3,\n", " max_iter=300,\n", " tol=1e-8,\n", ")\n", "\n", "print(res_rff.summary_text())" ] }, { "cell_type": "markdown", "id": "25f54dc4-56ff-d956-11c9-f4c7b07c-ece", "metadata": {}, "source": [ "## Note: AME requires differentiable bases\n", "\n", "``grr_ame`` requires ``basis.derivative(X, coordinate)``, so the basis must implement this method.\n", "The **piecewise-constant** bases ``KNNCatchmentBasis`` and ``RandomForestLeafBasis`` (whose derivative is zero\n", "almost everywhere) do not implement it, nor does ``TorchEmbeddingBasis`` without autograd, so none of them\n", "can be used with ``grr_ame``.\n", "Use ``PolynomialBasis``, ``RBFRandomFourierBasis``, or another smooth basis that implements ``derivative`` instead." ] }, { "cell_type": "markdown", "id": "7f6869ff", "metadata": {}, "source": [ "## Generator sweep (SQ / UKL / BP)\n", "\n", "Below we compare SQ-Riesz, UKL-Riesz, and BP-Riesz under several regularization\n", "norms and strengths, reusing the polynomial basis from Example 1 with 3-fold cross-fitting.\n", "For each configuration we report the **RA / RW / ARW / TMLE** estimates, their standard errors,\n", "and the signed error against the known true AME, $\\exp(-1/2)$." ] }, { "cell_type": "code", "execution_count": 4, "id": "bf77f950", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:6: UserWarning: UKLGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| > C + 1. For GRR with functionals that require negative alpha (e.g. 
ATE/ATT), provide branch_fn or use SquaredGenerator instead.\n", " (\"UKL (C=0)\", UKLGenerator(C=0.0).as_generator()),\n", "/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:7: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.\n", " (\"BP (omega=0.1, C=0)\", BPGenerator(C=0.0, omega=0.1).as_generator()),\n", "/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:8: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.\n", " (\"BP (omega=0.2, C=0)\", BPGenerator(C=0.0, omega=0.2).as_generator()),\n", "/var/folders/11/m8mvh3fs3jn1tk4r49vpy8rh0000gn/T/ipykernel_17417/127633348.py:9: UserWarning: BPGenerator without branch_fn uses sign(v) to select the alpha branch. This is correct only when |alpha| - C > 1. For GRR with functionals that require negative alpha (e.g. ATE/ATT), provide branch_fn or use SquaredGenerator instead.\n", " (\"BP (omega=0.5, C=0)\", BPGenerator(C=0.0, omega=0.5).as_generator()),\n" ] }, { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead><tr><th></th><th>generator</th><th>penalty</th><th>lam</th><th>ra</th><th>ra_se</th><th>ra_err</th><th>rw</th><th>rw_se</th><th>rw_err</th><th>arw</th><th>arw_se</th><th>arw_err</th><th>tmle</th><th>tmle_se</th><th>tmle_err</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>19</th><td>BP (omega=0.5, C=0)</td><td>l1</td><td>0.0010</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.583555</td><td>0.026705</td><td>-0.022976</td><td>0.582595</td><td>0.019655</td><td>-0.023935</td><td>0.569430</td><td>0.019682</td><td>-0.037101</td></tr>\n",
"<tr><th>17</th><td>BP (omega=0.5, C=0)</td><td>l2</td><td>0.0010</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.584187</td><td>0.026752</td><td>-0.022344</td><td>0.581640</td><td>0.019659</td><td>-0.024891</td><td>0.569258</td><td>0.019684</td><td>-0.037273</td></tr>\n",
"<tr><th>16</th><td>BP (omega=0.5, C=0)</td><td>l2</td><td>0.0001</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.583067</td><td>0.026758</td><td>-0.023463</td><td>0.581609</td><td>0.019660</td><td>-0.024921</td><td>0.569253</td><td>0.019686</td><td>-0.037278</td></tr>\n",
"<tr><th>18</th><td>BP (omega=0.5, C=0)</td><td>l1</td><td>0.0001</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.583012</td><td>0.026763</td><td>-0.023518</td><td>0.580149</td><td>0.019657</td><td>-0.026382</td><td>0.568983</td><td>0.019680</td><td>-0.037547</td></tr>\n",
"<tr><th>9</th><td>BP (omega=0.1, C=0)</td><td>l2</td><td>0.0010</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.578461</td><td>0.027031</td><td>-0.028070</td><td>0.578600</td><td>0.019622</td><td>-0.027931</td><td>0.568672</td><td>0.019629</td><td>-0.037858</td></tr>\n",
"<tr><th>5</th><td>UKL (C=0)</td><td>l2</td><td>0.0010</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.578919</td><td>0.027191</td><td>-0.027611</td><td>0.578271</td><td>0.019641</td><td>-0.028260</td><td>0.568613</td><td>0.019639</td><td>-0.037917</td></tr>\n",
"<tr><th>6</th><td>UKL (C=0)</td><td>l1</td><td>0.0001</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.578065</td><td>0.027179</td><td>-0.028466</td><td>0.578176</td><td>0.019640</td><td>-0.028355</td><td>0.568595</td><td>0.019637</td><td>-0.037935</td></tr>\n",
"<tr><th>10</th><td>BP (omega=0.1, C=0)</td><td>l1</td><td>0.0001</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.578894</td><td>0.027017</td><td>-0.027637</td><td>0.577886</td><td>0.019615</td><td>-0.028645</td><td>0.568540</td><td>0.019622</td><td>-0.037991</td></tr>\n",
"<tr><th>8</th><td>BP (omega=0.1, C=0)</td><td>l2</td><td>0.0001</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.577012</td><td>0.027054</td><td>-0.029518</td><td>0.577699</td><td>0.019618</td><td>-0.028832</td><td>0.568507</td><td>0.019625</td><td>-0.038024</td></tr>\n",
"<tr><th>13</th><td>BP (omega=0.2, C=0)</td><td>l2</td><td>0.0010</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.578107</td><td>0.026959</td><td>-0.028423</td><td>0.577668</td><td>0.019615</td><td>-0.028863</td><td>0.568504</td><td>0.019627</td><td>-0.038026</td></tr>\n",
"<tr><th>12</th><td>BP (omega=0.2, C=0)</td><td>l2</td><td>0.0001</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.576705</td><td>0.026951</td><td>-0.029826</td><td>0.577231</td><td>0.019613</td><td>-0.029299</td><td>0.568424</td><td>0.019624</td><td>-0.038107</td></tr>\n",
"<tr><th>4</th><td>UKL (C=0)</td><td>l2</td><td>0.0001</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.575639</td><td>0.027207</td><td>-0.030892</td><td>0.576981</td><td>0.019636</td><td>-0.029550</td><td>0.568375</td><td>0.019634</td><td>-0.038156</td></tr>\n",
"<tr><th>14</th><td>BP (omega=0.2, C=0)</td><td>l1</td><td>0.0001</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.576432</td><td>0.026943</td><td>-0.030099</td><td>0.576170</td><td>0.019608</td><td>-0.030361</td><td>0.568229</td><td>0.019618</td><td>-0.038301</td></tr>\n",
"<tr><th>7</th><td>UKL (C=0)</td><td>l1</td><td>0.0010</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.574841</td><td>0.027141</td><td>-0.031690</td><td>0.575717</td><td>0.019627</td><td>-0.030813</td><td>0.568142</td><td>0.019625</td><td>-0.038389</td></tr>\n",
"<tr><th>11</th><td>BP (omega=0.1, C=0)</td><td>l1</td><td>0.0010</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.575375</td><td>0.026993</td><td>-0.031156</td><td>0.575641</td><td>0.019606</td><td>-0.030890</td><td>0.568128</td><td>0.019610</td><td>-0.038403</td></tr>\n",
"<tr><th>15</th><td>BP (omega=0.2, C=0)</td><td>l1</td><td>0.0010</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.572222</td><td>0.026900</td><td>-0.034309</td><td>0.573946</td><td>0.019594</td><td>-0.032585</td><td>0.567820</td><td>0.019601</td><td>-0.038710</td></tr>\n",
"<tr><th>1</th><td>SQ</td><td>l2</td><td>0.0010</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.575361</td><td>0.026199</td><td>-0.031170</td><td>0.569663</td><td>0.017072</td><td>-0.036868</td><td>0.569537</td><td>0.017102</td><td>-0.036993</td></tr>\n",
"<tr><th>0</th><td>SQ</td><td>l2</td><td>0.0001</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.576791</td><td>0.026184</td><td>-0.029739</td><td>0.569646</td><td>0.017092</td><td>-0.036885</td><td>0.569517</td><td>0.017121</td><td>-0.037014</td></tr>\n",
"<tr><th>2</th><td>SQ</td><td>l1</td><td>0.0001</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.576779</td><td>0.026169</td><td>-0.029752</td><td>0.569638</td><td>0.017091</td><td>-0.036893</td><td>0.569510</td><td>0.017121</td><td>-0.037021</td></tr>\n",
"<tr><th>3</th><td>SQ</td><td>l1</td><td>0.0010</td><td>0.56645</td><td>0.006685</td><td>-0.040081</td><td>0.575307</td><td>0.026053</td><td>-0.031223</td><td>0.569633</td><td>0.017068</td><td>-0.036898</td><td>0.569511</td><td>0.017097</td><td>-0.037019</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " generator penalty lam ra ra_se ra_err \\\n", "19 BP (omega=0.5, C=0) l1 0.0010 0.56645 0.006685 -0.040081 \n", "17 BP (omega=0.5, C=0) l2 0.0010 0.56645 0.006685 -0.040081 \n", "16 BP (omega=0.5, C=0) l2 0.0001 0.56645 0.006685 -0.040081 \n", "18 BP (omega=0.5, C=0) l1 0.0001 0.56645 0.006685 -0.040081 \n", "9 BP (omega=0.1, C=0) l2 0.0010 0.56645 0.006685 -0.040081 \n", "5 UKL (C=0) l2 0.0010 0.56645 0.006685 -0.040081 \n", "6 UKL (C=0) l1 0.0001 0.56645 0.006685 -0.040081 \n", "10 BP (omega=0.1, C=0) l1 0.0001 0.56645 0.006685 -0.040081 \n", "8 BP (omega=0.1, C=0) l2 0.0001 0.56645 0.006685 -0.040081 \n", "13 BP (omega=0.2, C=0) l2 0.0010 0.56645 0.006685 -0.040081 \n", "12 BP (omega=0.2, C=0) l2 0.0001 0.56645 0.006685 -0.040081 \n", "4 UKL (C=0) l2 0.0001 0.56645 0.006685 -0.040081 \n", "14 BP (omega=0.2, C=0) l1 0.0001 0.56645 0.006685 -0.040081 \n", "7 UKL (C=0) l1 0.0010 0.56645 0.006685 -0.040081 \n", "11 BP (omega=0.1, C=0) l1 0.0010 0.56645 0.006685 -0.040081 \n", "15 BP (omega=0.2, C=0) l1 0.0010 0.56645 0.006685 -0.040081 \n", "1 SQ l2 0.0010 0.56645 0.006685 -0.040081 \n", "0 SQ l2 0.0001 0.56645 0.006685 -0.040081 \n", "2 SQ l1 0.0001 0.56645 0.006685 -0.040081 \n", "3 SQ l1 0.0010 0.56645 0.006685 -0.040081 \n", "\n", " rw rw_se rw_err arw arw_se arw_err tmle \\\n", "19 0.583555 0.026705 -0.022976 0.582595 0.019655 -0.023935 0.569430 \n", "17 0.584187 0.026752 -0.022344 0.581640 0.019659 -0.024891 0.569258 \n", "16 0.583067 0.026758 -0.023463 0.581609 0.019660 -0.024921 0.569253 \n", "18 0.583012 0.026763 -0.023518 0.580149 0.019657 -0.026382 0.568983 \n", "9 0.578461 0.027031 -0.028070 0.578600 0.019622 -0.027931 0.568672 \n", "5 0.578919 0.027191 -0.027611 0.578271 0.019641 -0.028260 0.568613 \n", "6 0.578065 0.027179 -0.028466 0.578176 0.019640 -0.028355 0.568595 \n", "10 0.578894 0.027017 -0.027637 0.577886 0.019615 -0.028645 0.568540 \n", "8 0.577012 0.027054 -0.029518 0.577699 0.019618 -0.028832 0.568507 \n", "13 0.578107 
0.026959 -0.028423 0.577668 0.019615 -0.028863 0.568504 \n", "12 0.576705 0.026951 -0.029826 0.577231 0.019613 -0.029299 0.568424 \n", "4 0.575639 0.027207 -0.030892 0.576981 0.019636 -0.029550 0.568375 \n", "14 0.576432 0.026943 -0.030099 0.576170 0.019608 -0.030361 0.568229 \n", "7 0.574841 0.027141 -0.031690 0.575717 0.019627 -0.030813 0.568142 \n", "11 0.575375 0.026993 -0.031156 0.575641 0.019606 -0.030890 0.568128 \n", "15 0.572222 0.026900 -0.034309 0.573946 0.019594 -0.032585 0.567820 \n", "1 0.575361 0.026199 -0.031170 0.569663 0.017072 -0.036868 0.569537 \n", "0 0.576791 0.026184 -0.029739 0.569646 0.017092 -0.036885 0.569517 \n", "2 0.576779 0.026169 -0.029752 0.569638 0.017091 -0.036893 0.569510 \n", "3 0.575307 0.026053 -0.031223 0.569633 0.017068 -0.036898 0.569511 \n", "\n", " tmle_se tmle_err \n", "19 0.019682 -0.037101 \n", "17 0.019684 -0.037273 \n", "16 0.019686 -0.037278 \n", "18 0.019680 -0.037547 \n", "9 0.019629 -0.037858 \n", "5 0.019639 -0.037917 \n", "6 0.019637 -0.037935 \n", "10 0.019622 -0.037991 \n", "8 0.019625 -0.038024 \n", "13 0.019627 -0.038026 \n", "12 0.019624 -0.038107 \n", "4 0.019634 -0.038156 \n", "14 0.019618 -0.038301 \n", "7 0.019625 -0.038389 \n", "11 0.019610 -0.038403 \n", "15 0.019601 -0.038710 \n", "1 0.017102 -0.036993 \n", "0 0.017121 -0.037014 \n", "2 0.017121 -0.037021 \n", "3 0.017097 -0.037019 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# A small grid over generators and regularization.\n", "generator_grid = [\n", " (\"SQ\", SquaredGenerator(C=0.0).as_generator()),\n", " (\"UKL (C=0)\", UKLGenerator(C=0.0).as_generator()),\n", " (\"BP (omega=0.1, C=0)\", BPGenerator(C=0.0, omega=0.1).as_generator()),\n", " (\"BP (omega=0.2, C=0)\", BPGenerator(C=0.0, omega=0.2).as_generator()),\n", " (\"BP (omega=0.5, C=0)\", BPGenerator(C=0.0, omega=0.5).as_generator()),\n", "]\n", "\n", "penalty_grid = [\n", " {\"penalty\": \"l2\", \"lam\": 1e-4, \"p_norm\": None},\n", " {\"penalty\": \"l2\", 
\"lam\": 1e-3, \"p_norm\": None},\n", " {\"penalty\": \"l1\", \"lam\": 1e-4, \"p_norm\": None},\n", " {\"penalty\": \"lp\", \"lam\": 1e-3, \"p_norm\": 1.5},\n", "]\n", "\n", "rows = []\n", "for gname, gen_i in generator_grid:\n", "    for cfg in penalty_grid:\n", "        res_i = grr_ame(\n", "            X=X,\n", "            Y=Y,\n", "            coordinate=0,\n", "            basis=basis,\n", "            generator=gen_i,\n", "            cross_fit=True,\n", "            folds=3,  # fewer folds than Examples 1 and 2 (which use 5) to keep the sweep fast\n", "            random_state=0,\n", "            estimators=(\"ra\", \"rw\", \"arw\", \"tmle\"),\n", "            outcome_models=\"shared\",\n", "            outcome_link=\"identity\",  # Y is unbounded, so Gaussian TMLE is appropriate\n", "            riesz_penalty=cfg[\"penalty\"],\n", "            riesz_lam=cfg[\"lam\"],\n", "            riesz_p_norm=cfg.get(\"p_norm\"),\n", "            max_iter=250,\n", "            tol=1e-8,\n", "        )\n", "\n", "        row = {\n", "            \"generator\": gname,\n", "            \"penalty\": cfg[\"penalty\"],\n", "            \"lam\": cfg[\"lam\"],\n", "        }\n", "\n", "        for k in (\"ra\", \"rw\", \"arw\", \"tmle\"):\n", "            e = res_i.estimates[k]\n", "            row[k] = e.estimate\n", "            row[f\"{k}_se\"] = e.se\n", "            row[f\"{k}_err\"] = e.estimate - true_ame0  # signed error vs. the true AME\n", "\n", "        rows.append(row)\n", "\n", "import pandas as pd\n", "\n", "df = pd.DataFrame(rows)\n", "# Sort rows by absolute ARW error, smallest first\n", "df = df.sort_values(by=\"arw_err\", key=lambda s: np.abs(s))\n", "display(df)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.6" } }, "nbformat": 4, "nbformat_minor": 5 }