Panel DID simulation with a known true effect (genriesz)

We implement difference-in-differences (DID) as the average treatment effect on the treated (ATT), estimated on the differenced outcome

\[\Delta Y = Y_1 - Y_0,\]

where:

  • \(Y_0\) is the pre-period outcome,

  • \(Y_1\) is the post-period outcome,

  • the same units are observed in both periods (panel),

  • \(D\) is a binary treatment indicator (treatment happens in the post period).

With a standard panel DID setup:

\[Y_{0} = \mu(Z) + u + \varepsilon_0, \qquad Y_{1} = \mu(Z) + \text{trend}(Z) + u + \tau D + \varepsilon_1,\]

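the DID effect equals the constant treatment effect \(\tau\), provided the parallel trends condition holds conditional on \(Z\). Here \(u\) is a unit fixed effect that, like \(\mu(Z)\), cancels in \(\Delta Y\), and \(\text{trend}(Z)\) is a time trend shared by treated and untreated units with the same \(Z\).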
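Spelling out that step: differencing the two period equations removes \(\mu(Z)\) and \(u\). Assuming \(\mathbb{E}[\varepsilon_1 - \varepsilon_0 \mid D, Z] = 0\) (true in the simulation below, where the noise is drawn independently of \(D\) and \(Z\)), the conditional treated/control contrast equals \(\tau\) for every \(Z\), while the unconditional contrast also picks up any imbalance in \(\text{trend}(Z)\) between the groups:

\[
\begin{aligned}
\Delta Y &= Y_1 - Y_0 = \text{trend}(Z) + \tau D + (\varepsilon_1 - \varepsilon_0), \\
\mathbb{E}[\Delta Y \mid D = 1, Z] - \mathbb{E}[\Delta Y \mid D = 0, Z] &= \tau, \\
\mathbb{E}[\Delta Y \mid D = 1] - \mathbb{E}[\Delta Y \mid D = 0] &= \tau + \mathbb{E}[\text{trend}(Z) \mid D = 1] - \mathbb{E}[\text{trend}(Z) \mid D = 0].
\end{aligned}
\]

The last line is why an unadjusted DID can be biased when the trend varies with covariates that also drive treatment.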

This notebook:

  1. simulates a large population to compare the true effect (\(\tau\), known by construction) with the naive, covariate-unadjusted DID,

  2. samples a dataset and calls genriesz.grr_did(X, Y0=..., Y1=...).

[1]:
import numpy as np

from genriesz import (
    grr_did,
    SquaredGenerator,
    UKLGenerator,
    BPGenerator,
    PolynomialBasis,
    TreatmentInteractionBasis,
    RBFRandomFourierBasis,
    KNNCatchmentBasis,
)

rng = np.random.default_rng(0)

DGP

[7]:
def draw_panel(n: int, d_z: int, tau: float, seed: int = 0):
    rng = np.random.default_rng(seed)
    Z = rng.normal(size=(n, d_z))

    logits = 0.6 * Z[:, 0] - 0.25 * Z[:, 1]
    e = 1.0 / (1.0 + np.exp(-logits))
    D = rng.binomial(1, e, size=n).astype(int)

    mu = 0.5 * Z[:, 0] - 0.2 * Z[:, 1] ** 2
    trend = 0.5 + 0.1 * Z[:, 0]  # common trend that depends on Z

    u = rng.normal(scale=1.0, size=n)  # unit fixed effect

    Y0 = mu + u + rng.normal(scale=1.0, size=n)
    Y1 = mu + trend + u + tau * D + rng.normal(scale=1.0, size=n)

    X = np.column_stack([D.astype(float), Z])
    return X, Y0, Y1, D

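tau_true = 1.0

# Large population: the true effect is tau_true by construction, so the population
# is only used to show how far the naive (unadjusted) DID is from it.
X_pop, Y0_pop, Y1_pop, D_pop = draw_panel(n=200_000, d_z=5, tau=tau_true, seed=1)

dY_pop = Y1_pop - Y0_pop
# Naive DID: difference in mean ΔY between treated and control, ignoring Z.
# It is biased here because trend(Z) depends on Z[:, 0], which also drives treatment.
naive_did = np.mean(dY_pop[D_pop == 1]) - np.mean(dY_pop[D_pop == 0])

# Our target is the ATT on ΔY, whose true value equals tau_true by construction.
print("True tau (by construction):", tau_true)
print("Naive DID (difference in mean ΔY):", naive_did)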

True tau (by construction): 1.0
Naive DID (difference in mean ΔY): 1.0477956893531037
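The gap between the naive DID and \(\tau\) comes from the trend term: trend(Z) = 0.5 + 0.1·Z[:, 0], and Z[:, 0] also raises the treatment probability, so the naive contrast should be close to \(\tau\) + 0.1 × (mean of Z[:, 0] among treated minus among controls). The self-contained NumPy sketch below checks this and adds a simple covariate adjustment. The function draw_panel_arrays is our own copy of draw_panel (same random draws, same order) that returns only Z, D and ΔY; the linear regression adjustment is our illustration, not genriesz's RA implementation, and it is adequate here only because E[ΔY | D = 0, Z] is linear in Z in this DGP.

```python
import numpy as np


def draw_panel_arrays(n, d_z, tau, seed=0):
    """Copy of the notebook's DGP, returning Z, D and the differenced outcome."""
    rng = np.random.default_rng(seed)
    Z = rng.normal(size=(n, d_z))
    e = 1.0 / (1.0 + np.exp(-(0.6 * Z[:, 0] - 0.25 * Z[:, 1])))
    D = rng.binomial(1, e, size=n)
    mu = 0.5 * Z[:, 0] - 0.2 * Z[:, 1] ** 2
    trend = 0.5 + 0.1 * Z[:, 0]
    u = rng.normal(size=n)  # unit fixed effect (cancels in the difference)
    Y0 = mu + u + rng.normal(size=n)
    Y1 = mu + trend + u + tau * D + rng.normal(size=n)
    return Z, D, Y1 - Y0


tau = 1.0
Z, D, dY = draw_panel_arrays(n=200_000, d_z=5, tau=tau, seed=1)
treated, control = D == 1, D == 0

# Naive DID and the trend-imbalance term 0.1 * (mean Z[:, 0] treated - control).
naive_did = dY[treated].mean() - dY[control].mean()
trend_gap = 0.1 * (Z[treated, 0].mean() - Z[control, 0].mean())

# Linear regression adjustment: fit E[dY | D=0, Z] by OLS on controls,
# then average dY minus that prediction over treated units.
A = np.column_stack([np.ones(len(dY)), Z])
coef, *_ = np.linalg.lstsq(A[control], dY[control], rcond=None)
att_ra = np.mean(dY[treated] - A[treated] @ coef)

print(f"naive DID      : {naive_did:.4f}")
print(f"tau + trend gap: {tau + trend_gap:.4f}")
print(f"linear RA ATT  : {att_ra:.4f}")
```

Because the outcome model is fitted only on controls and evaluated on treated units, this estimates the untreated counterfactual for the treated, which is exactly what the ATT on ΔY needs.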

Example 1: Polynomial basis + treatment interactions

[9]:
# Sample a dataset from the same DGP
X, Y0, Y1, D = draw_panel(n=6000, d_z=5, tau=tau_true, seed=0)

psi = PolynomialBasis(degree=2, include_bias=True)
phi = TreatmentInteractionBasis(base_basis=psi)

gen = SquaredGenerator(C=0.0).as_generator()

res = grr_did(
    X=X,
    Y0=Y0,
    Y1=Y1,
    basis=phi,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res.summary_text())

DID estimates (n=6000)
alpha=0.05 | null=0.0
diagnostics: max_abs_smd_unweighted=0.5468841195092563, max_abs_smd_weighted=0.002194706023289376, ess_treated=3017.871638373246, ess_control=2036.5721815918366

Estimator         Estimate            SE                           CI     p-value
---------------------------------------------------------------------------------
RA                0.983907     0.0129607        [ 0.958505,  1.00931]           0
RW                0.976091     0.0495386        [ 0.878997,  1.07318]           0
ARW                0.98479     0.0421389          [ 0.9022,  1.06738]           0
TMLE              0.984776     0.0421433        [ 0.902176,  1.06738]           0
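To make the RW and ARW rows more concrete, here is a NumPy sketch that replaces the fitted Riesz model with the oracle one implied by the DGP's true propensity score e(Z). We assume, without having checked genriesz's internals, that RW averages α(D, Z)·ΔY and ARW averages α(D, Z)·(ΔY − ĝ₀(Z)), where ĝ₀ is an outcome model for controls and α(d, z) = d/p − (1 − d)·e(z)/(p·(1 − e(z))) is the ATT Riesz representer with p = P(D = 1). The linear ĝ₀ is our own choice for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
n, tau = 200_000, 1.0

# Same DGP as draw_panel; only the differenced outcome is needed.
Z = rng.normal(size=(n, 5))
e = 1.0 / (1.0 + np.exp(-(0.6 * Z[:, 0] - 0.25 * Z[:, 1])))  # true propensity
D = rng.binomial(1, e, size=n)
u = rng.normal(size=n)
mu = 0.5 * Z[:, 0] - 0.2 * Z[:, 1] ** 2
Y0 = mu + u + rng.normal(size=n)
Y1 = mu + 0.5 + 0.1 * Z[:, 0] + u + tau * D + rng.normal(size=n)
dY = Y1 - Y0

# Oracle ATT Riesz representer: positive for treated, negative for controls.
p = D.mean()
alpha = D / p - (1 - D) * e / (p * (1 - e))

# RW: plain Riesz weighting of the differenced outcome.
rw = np.mean(alpha * dY)

# ARW: augment with a control outcome model g0(Z), here OLS of dY on (1, Z) in controls.
A = np.column_stack([np.ones(n), Z])
coef, *_ = np.linalg.lstsq(A[D == 0], dY[D == 0], rcond=None)
arw = np.mean(alpha * (dY - A @ coef))

print(f"RW  (oracle alpha): {rw:.4f}")
print(f"ARW (oracle alpha): {arw:.4f}")
```

Since α has mean zero against any function of Z alone, subtracting ĝ₀(Z) leaves the target unchanged but removes outcome variation that the weights would otherwise have to average out, which is why augmented estimators typically have smaller standard errors.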

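The diagnostics line in the Example 1 output reports effective sample sizes and standardized mean differences (SMD). We have not checked genriesz's exact formulas; the sketch below uses two common definitions, Kish's effective sample size (Σw)²/Σw² and the covariate-wise mean difference divided by a pooled standard deviation, applied to oracle ATT weights (1 for treated units, e(Z)/(1 − e(Z)) for controls) on a fresh sample of 6000 units from the same DGP. The helper names kish_ess and max_abs_smd are ours.

```python
import numpy as np


def kish_ess(w):
    """Kish effective sample size (sum w)^2 / sum w^2."""
    return w.sum() ** 2 / np.sum(w ** 2)


def max_abs_smd(Z, D, w):
    """Largest |weighted mean difference| / pooled SD across covariates."""
    t, c = D == 1, D == 0
    m_t = np.average(Z[t], axis=0, weights=w[t])
    m_c = np.average(Z[c], axis=0, weights=w[c])
    pooled_sd = np.sqrt((Z[t].var(axis=0) + Z[c].var(axis=0)) / 2)  # unweighted SDs
    return np.max(np.abs(m_t - m_c) / pooled_sd)


rng = np.random.default_rng(0)
Z = rng.normal(size=(6000, 5))
e = 1.0 / (1.0 + np.exp(-(0.6 * Z[:, 0] - 0.25 * Z[:, 1])))
D = rng.binomial(1, e)

# Oracle ATT weights: treated keep weight 1, controls get the odds e / (1 - e).
w = np.where(D == 1, 1.0, e / (1.0 - e))

ess_treated = kish_ess(w[D == 1])
ess_control = kish_ess(w[D == 0])
smd_unweighted = max_abs_smd(Z, D, np.ones(len(D)))
smd_weighted = max_abs_smd(Z, D, w)

print("ESS treated / n treated:", ess_treated, (D == 1).sum())
print("ESS control / n control:", ess_control, (D == 0).sum())
print("max |SMD| unweighted / weighted:", smd_unweighted, smd_weighted)
```

With equal treated weights the treated ESS equals the treated count, while the unequal control weights shrink the control ESS; this is the same qualitative pattern as ess_treated ≈ 3018 and ess_control ≈ 2037 in the output above.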
Example 2: RKHS basis (RBF random Fourier features)

[ ]:
psi_rff = RBFRandomFourierBasis(
    n_features=500,
    sigma=1.0,
    standardize=True,
    random_state=0,
)
phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)

res_phi_rff = grr_did(
    X=X,
    Y0=Y0,
    Y1=Y1,
    basis=phi_rff,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_phi_rff.summary_text())
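Random Fourier features approximate a Gaussian (RBF) kernel by an explicit finite-dimensional feature map, which is what allows an RKHS-style basis to be used in the same linear Riesz regression. The NumPy sketch below shows the standard construction of Rahimi and Recht; we assume, without having checked, that sigma plays the role of the bandwidth in exp(−‖x − y‖²/(2σ²)), and the sketch ignores the standardize option. The helper rff_features is hypothetical.

```python
import numpy as np


def rff_features(X, n_features, sigma, seed=0):
    """Map rows of X to sqrt(2/D) * cos(X W + b), W ~ N(0, 1/sigma^2), b ~ U[0, 2pi)."""
    rng = np.random.default_rng(seed)
    W = rng.normal(scale=1.0 / sigma, size=(X.shape[1], n_features))
    b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)
    return np.sqrt(2.0 / n_features) * np.cos(X @ W + b)


rng = np.random.default_rng(1)
x, y = rng.normal(size=(2, 5))  # two points in R^5
sigma = 2.0

features = rff_features(np.vstack([x, y]), n_features=20_000, sigma=sigma)
approx = features[0] @ features[1]
exact = np.exp(-np.sum((x - y) ** 2) / (2.0 * sigma ** 2))
print(f"RFF approx: {approx:.4f}   exact RBF: {exact:.4f}")
```

The approximation error shrinks roughly like 1/sqrt(n_features), so n_features=500 in the notebook trades some kernel accuracy for a much smaller basis.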

Example 3: KNN catchment basis (nearest-neighbor matching)

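Nearest-neighbor matching can be cast as a special case of squared-loss Riesz regression. Here the basis is built from KNN catchment areas, and we reuse the squared-loss generator gen from Example 1.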

[ ]:
basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)
phi_knn = TreatmentInteractionBasis(base_basis=basis_knn)

res_phi_knn = grr_did(
    X=X,
    Y0=Y0,
    Y1=Y1,
    basis=phi_knn,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_phi_knn.summary_text())
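The link between matching and Riesz weighting can be checked directly: a K-nearest-neighbor matching estimate of the ATT on ΔY equals a weighted treated-minus-control difference in which each control is weighted by its catchment count (how many treated units have it among their K nearest controls) divided by K. The NumPy sketch below verifies this identity with brute-force Euclidean neighbors and no tie handling; ΔY is drawn directly from its reduced form trend(Z) + τD + noise. It illustrates the idea only and does not reproduce KNNCatchmentBasis.

```python
import numpy as np

rng = np.random.default_rng(0)
n, K = 1_000, 5
Z = rng.normal(size=(n, 5))
D = rng.binomial(1, 1.0 / (1.0 + np.exp(-(0.6 * Z[:, 0] - 0.25 * Z[:, 1]))))
dY = 0.5 + 0.1 * Z[:, 0] + 1.0 * D + rng.normal(scale=np.sqrt(2.0), size=n)

Zt, Zc = Z[D == 1], Z[D == 0]
dYt, dYc = dY[D == 1], dY[D == 0]

# For each treated unit, indices of its K nearest controls (squared Euclidean distance).
dist = ((Zt[:, None, :] - Zc[None, :, :]) ** 2).sum(axis=2)
nn = np.argsort(dist, axis=1)[:, :K]

# (1) Matching form: mean over treated of dY minus the mean dY of matched controls.
att_match = np.mean(dYt - dYc[nn].mean(axis=1))

# (2) Weighting form: control weight = catchment count / K, normalized by n_treated.
catchment = np.bincount(nn.ravel(), minlength=len(Zc))
w_control = catchment / K
att_weight = (dYt.sum() - w_control @ dYc) / len(Zt)

print(att_match, att_weight)
```

The control weights sum to the number of treated units, so after dividing by n_treated they play the role of the control part of an ATT Riesz representer.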

Example 4: Random forest leaf basis (optional)

[ ]:
from sklearn.ensemble import RandomForestRegressor
from genriesz.sklearn_basis import RandomForestLeafBasis

rf = RandomForestRegressor(n_estimators=200, max_depth=6, random_state=0)
# The forest is fit here on the full sample, using ΔY = Y1 - Y0 as the target.
leaf_basis = RandomForestLeafBasis(rf).fit(X, Y1 - Y0)

res_leaf_basis = grr_did(
    X=X,
    Y0=Y0,
    Y1=Y1,
    basis=leaf_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_leaf_basis.summary_text())
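A leaf basis turns a fitted forest into binary features: one indicator per (tree, leaf) pair, so every row has exactly one active leaf in each tree. The scikit-learn sketch below builds such a matrix from RandomForestRegressor.apply and OneHotEncoder on simulated data; that this matches what RandomForestLeafBasis does internally is our assumption, not something we verified.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import OneHotEncoder

rng = np.random.default_rng(0)
n = 2_000
X = rng.normal(size=(n, 6))
X[:, 0] = rng.binomial(1, 0.5, size=n)  # column 0 plays the role of D, as in the notebook
dY = 0.5 + 0.1 * X[:, 1] + 1.0 * X[:, 0] + rng.normal(size=n)

rf = RandomForestRegressor(n_estimators=20, max_depth=6, random_state=0).fit(X, dY)

leaves = rf.apply(X)  # (n, n_estimators): leaf index of each row in each tree
encoder = OneHotEncoder(handle_unknown="ignore").fit(leaves)
Phi = encoder.transform(leaves)  # sparse indicator matrix, one block of columns per tree

print(Phi.shape)
```

Two rows share a feature exactly when they fall in the same leaf of some tree, so the basis encodes a forest-based similarity between units.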

Example 5: Neural network embedding basis (optional)

[ ]:
import torch
from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis

torch.manual_seed(0)
net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)
nn_basis = TorchEmbeddingBasis(net, include_bias=True, device="cpu")

res_nn_basis = grr_did(
    X=X,
    Y0=Y0,
    Y1=Y1,
    basis=nn_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)

print(res_nn_basis.summary_text())

Generator / regularization sweep (SQ / UKL / BP)

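We repeat the DID estimation (ATT on the differenced outcome) with the polynomial interaction basis phi from Example 1 and 3-fold cross-fitting, varying the Riesz loss (SQ-Riesz, UKL-Riesz, BP-Riesz), the regularization norm (l2, l1, and lp with p = 1.5), and the regularization strength (lam = 1e-4 or 1e-3).

For UKL and BP we pass a branch function that matches the treatment/control sign pattern of the ATT Riesz representer, which is positive for treated units and negative for controls: branch returns 1 when the first column of X (the treatment indicator D) equals 1, and 0 otherwise.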
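A quick NumPy check of that sign pattern, using the oracle ATT Riesz representer α(d, z) = d/p − (1 − d)·e(z)/(p·(1 − e(z))) computed from the DGP's true propensity (the UKL/BP fits estimate this object from data; the oracle form here is our illustration). The branch function is applied row by row to an X with the same layout as draw_panel (D first), and it should select exactly the rows where α is positive.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6_000
Z = rng.normal(size=(n, 5))
e = 1.0 / (1.0 + np.exp(-(0.6 * Z[:, 0] - 0.25 * Z[:, 1])))
D = rng.binomial(1, e)
X = np.column_stack([D.astype(float), Z])  # same layout as draw_panel: D first


def branch(x):
    """Same rule as the notebook's branch: 1 for treated rows, 0 otherwise."""
    return int(x[0] == 1.0)


p = D.mean()
alpha = D / p - (1 - D) * e / (p * (1 - e))  # oracle ATT Riesz representer

branches = np.array([branch(row) for row in X])
print("signs where branch == 1:", np.unique(np.sign(alpha[branches == 1])))
print("signs where branch == 0:", np.unique(np.sign(alpha[branches == 0])))
```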

[15]:
branch = lambda x: int(x[0] == 1.0)

generator_grid = [
    ("SQ", SquaredGenerator(C=0.0).as_generator()),
    ("UKL (C=1)", UKLGenerator(C=1.0, branch_fn=branch).as_generator()),
    ("BP (omega=0.1, C=1)", BPGenerator(C=1.0, omega=0.1, branch_fn=branch).as_generator()),
    ("BP (omega=0.2, C=1)", BPGenerator(C=1.0, omega=0.2, branch_fn=branch).as_generator()),
    ("BP (omega=0.5, C=1)", BPGenerator(C=1.0, omega=0.5, branch_fn=branch).as_generator()),
]

penalty_grid = [
    {"penalty": "l2", "lam": 1e-4, "p_norm": None},
    {"penalty": "l2", "lam": 1e-3, "p_norm": None},
    {"penalty": "l1", "lam": 1e-4, "p_norm": None},
    {"penalty": "lp", "lam": 1e-3, "p_norm": 1.5},
]

rows = []
for gname, gen_i in generator_grid:
    for cfg in penalty_grid:
        res_i = grr_did(
            X=X,
            Y0=Y0,
            Y1=Y1,
            basis=phi,
            generator=gen_i,
            cross_fit=True,
            folds=3,
            random_state=0,
            estimators=("ra", "rw", "arw", "tmle"),
            outcome_models="shared",
            outcome_link="identity",
            riesz_penalty=cfg["penalty"],
            riesz_lam=cfg["lam"],
            riesz_p_norm=cfg.get("p_norm"),
            max_iter=250,
            tol=1e-8,
        )

        row = {
            "generator": gname,
            "penalty": cfg["penalty"],
            "lam": cfg["lam"],
        }
        for k in ("ra", "rw", "arw", "tmle"):
            e = res_i.estimates[k]
            row[k] = e.estimate
            row[f"{k}_se"] = e.se
            row[f"{k}_err"] = e.estimate - tau_true
        rows.append(row)

import pandas as pd

df = pd.DataFrame(rows)
df = df.sort_values(by="arw_err", key=lambda s: np.abs(s))
display(df)
    generator            penalty  lam           ra     ra_se     ra_err        rw     rw_se     rw_err       arw    arw_se    arw_err      tmle   tmle_se   tmle_err
 6  UKL (C=1)            l1       0.0001  0.983161  0.013075  -0.016839  0.974852  0.051020  -0.025148  0.991554  0.043837  -0.008446  0.990610  0.043874  -0.009390
 4  UKL (C=1)            l2       0.0001  0.983161  0.013075  -0.016839  0.974713  0.051015  -0.025287  0.991546  0.043832  -0.008454  0.990603  0.043870  -0.009397
 5  UKL (C=1)            l2       0.0010  0.983161  0.013075  -0.016839  0.975179  0.050895  -0.024821  0.991356  0.043705  -0.008644  0.990484  0.043742  -0.009516
 7  UKL (C=1)            lp       0.0010  0.983161  0.013075  -0.016839  0.975179  0.050895  -0.024821  0.991356  0.043705  -0.008644  0.990484  0.043742  -0.009516
19  BP (omega=0.5, C=1)  lp       0.0010  0.983161  0.013075  -0.016839  0.990781  0.050259  -0.009219  0.991199  0.042990  -0.008801  0.990698  0.043028  -0.009302
17  BP (omega=0.5, C=1)  l2       0.0010  0.983161  0.013075  -0.016839  0.990781  0.050259  -0.009219  0.991199  0.042990  -0.008801  0.990698  0.043028  -0.009302
16  BP (omega=0.5, C=1)  l2       0.0001  0.983161  0.013075  -0.016839  0.986536  0.050288  -0.013464  0.990830  0.043032  -0.009170  0.990316  0.043068  -0.009684
18  BP (omega=0.5, C=1)  l1       0.0001  0.983161  0.013075  -0.016839  0.986272  0.050383  -0.013728  0.990558  0.043127  -0.009442  0.990050  0.043162  -0.009950
11  BP (omega=0.1, C=1)  lp       0.0010  0.983161  0.013075  -0.016839  1.009891  0.046637   0.009891  0.986427  0.039662  -0.013573  0.986603  0.039681  -0.013397
 9  BP (omega=0.1, C=1)  l2       0.0010  0.983161  0.013075  -0.016839  1.009891  0.046637   0.009891  0.986427  0.039662  -0.013573  0.986603  0.039681  -0.013397
 8  BP (omega=0.1, C=1)  l2       0.0001  0.983161  0.013075  -0.016839  1.009828  0.046649   0.009828  0.986413  0.039671  -0.013587  0.986587  0.039689  -0.013413
10  BP (omega=0.1, C=1)  l1       0.0001  0.983161  0.013075  -0.016839  1.009280  0.046642   0.009280  0.986367  0.039666  -0.013633  0.986539  0.039684  -0.013461
 3  SQ                   lp       0.0010  0.983161  0.013075  -0.016839  0.975879  0.049509  -0.024121  0.986051  0.042223  -0.013949  0.986004  0.042238  -0.013996
 1  SQ                   l2       0.0010  0.983161  0.013075  -0.016839  0.975879  0.049509  -0.024121  0.986051  0.042223  -0.013949  0.986004  0.042238  -0.013996
 0  SQ                   l2       0.0001  0.983161  0.013075  -0.016839  0.978916  0.049698  -0.021084  0.985968  0.042369  -0.014032  0.985912  0.042383  -0.014088
 2  SQ                   l1       0.0001  0.983161  0.013075  -0.016839  0.979128  0.049712  -0.020872  0.985944  0.042379  -0.014056  0.985887  0.042394  -0.014113
13  BP (omega=0.2, C=1)  l2       0.0010  0.983161  0.013075  -0.016839  1.011729  0.048423   0.011729  0.980771  0.041029  -0.019229  0.980738  0.041015  -0.019262
15  BP (omega=0.2, C=1)  lp       0.0010  0.983161  0.013075  -0.016839  1.011729  0.048423   0.011729  0.980771  0.041029  -0.019229  0.980738  0.041015  -0.019262
14  BP (omega=0.2, C=1)  l1       0.0001  0.983161  0.013075  -0.016839  1.010995  0.048484   0.010995  0.980671  0.041077  -0.019329  0.980640  0.041062  -0.019360
12  BP (omega=0.2, C=1)  l2       0.0001  0.983161  0.013075  -0.016839  1.010725  0.048487   0.010725  0.980661  0.041081  -0.019339  0.980630  0.041066  -0.019370
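To read the sweep at a glance, the short standard-library sketch below averages |arw_err| within each generator, using the ARW errors copied from the sweep table (rounded as displayed). Keep in mind that the spread between generators, about 0.011 at most, is much smaller than the ARW standard errors of roughly 0.04, so a ranking from this single sample is not evidence that one generator is better.

```python
from collections import defaultdict
from statistics import mean

# (generator, arw_err) pairs copied from the sweep table.
arw_err = [
    ("UKL (C=1)", -0.008446), ("UKL (C=1)", -0.008454),
    ("UKL (C=1)", -0.008644), ("UKL (C=1)", -0.008644),
    ("BP (omega=0.5, C=1)", -0.008801), ("BP (omega=0.5, C=1)", -0.008801),
    ("BP (omega=0.5, C=1)", -0.009170), ("BP (omega=0.5, C=1)", -0.009442),
    ("BP (omega=0.1, C=1)", -0.013573), ("BP (omega=0.1, C=1)", -0.013573),
    ("BP (omega=0.1, C=1)", -0.013587), ("BP (omega=0.1, C=1)", -0.013633),
    ("SQ", -0.013949), ("SQ", -0.013949), ("SQ", -0.014032), ("SQ", -0.014056),
    ("BP (omega=0.2, C=1)", -0.019229), ("BP (omega=0.2, C=1)", -0.019229),
    ("BP (omega=0.2, C=1)", -0.019329), ("BP (omega=0.2, C=1)", -0.019339),
]

# Group absolute errors by generator, then average within each group.
by_generator = defaultdict(list)
for name, err in arw_err:
    by_generator[name].append(abs(err))

mean_abs_err = {name: mean(errs) for name, errs in by_generator.items()}
for name, value in sorted(mean_abs_err.items(), key=lambda kv: kv[1]):
    print(f"{name:<20} mean |arw_err| = {value:.6f}")
```

The RA column is identical across all rows, as expected if RA uses only the outcome model and not the Riesz representer, so only RW, ARW and TMLE respond to the generator and penalty settings.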