Panel DID simulation with true value (genriesz)
We implement DID as ATT on the differenced outcome \(\Delta Y = Y1 - Y0\), where:
\(Y0\) is the pre-period outcome,
\(Y1\) is the post-period outcome,
the same units are observed in both periods (panel),
\(D\) is a binary treatment indicator (treatment happens in the post period).
With a standard panel DID setup, the DID effect equals the constant treatment effect \(\tau\), provided the parallel trends condition holds after conditioning on the covariates \(Z\).
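The claim above can be checked against the data-generating process used later in this notebook. The block below is a short sketch of that argument, read off from the `draw_panel` code; the symbols \(\mu(Z)\), \(g(Z)\), \(u\), \(\varepsilon_0\), \(\varepsilon_1\) are labels introduced here for the baseline mean, the covariate-dependent trend, the unit fixed effect, and the two period noises, and are not notation from the library.

```latex
\begin{aligned}
Y0 &= \mu(Z) + u + \varepsilon_0, \\
Y1 &= \mu(Z) + g(Z) + u + \tau D + \varepsilon_1, \\
\Delta Y = Y1 - Y0 &= g(Z) + \tau D + (\varepsilon_1 - \varepsilon_0), \\
E[\Delta Y \mid D = 1, Z] - E[\Delta Y \mid D = 0, Z] &= \tau .
\end{aligned}
```

Differencing removes \(\mu(Z)\) and the fixed effect \(u\). The last line uses the fact that the noise terms are drawn independently of \(D\) and \(Z\), so \(\tau\) is recovered once \(Z\) is held fixed. Without conditioning on \(Z\), the trend \(g(Z)\) does not cancel between groups, because treatment assignment also depends on \(Z\).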
This notebook:
simulates a large population to compare the naive DID (difference in mean ΔY) with the true effect \(\tau\),
samples a dataset and calls genriesz.grr_did(X, Y0=..., Y1=...).
[1]:
import numpy as np
from genriesz import (
    grr_did,
    SquaredGenerator,
    UKLGenerator,
    BPGenerator,
    PolynomialBasis,
    TreatmentInteractionBasis,
    RBFRandomFourierBasis,
    KNNCatchmentBasis,
)
rng = np.random.default_rng(0)
DGP
[7]:
def draw_panel(n: int, d_z: int, tau: float, seed: int = 0):
    rng = np.random.default_rng(seed)
    Z = rng.normal(size=(n, d_z))
    logits = 0.6 * Z[:, 0] - 0.25 * Z[:, 1]
    e = 1.0 / (1.0 + np.exp(-logits))
    D = rng.binomial(1, e, size=n).astype(int)
    mu = 0.5 * Z[:, 0] - 0.2 * Z[:, 1] ** 2
    trend = 0.5 + 0.1 * Z[:, 0]  # common trend that depends on Z
    u = rng.normal(scale=1.0, size=n)  # unit fixed effect
    Y0 = mu + u + rng.normal(scale=1.0, size=n)
    Y1 = mu + trend + u + tau * D + rng.normal(scale=1.0, size=n)
    X = np.column_stack([D.astype(float), Z])
    return X, Y0, Y1, D
tau_true = 1.0
# Large population: the naive DID below is close to its population value, which
# differs from tau because both the trend and the treatment probability depend on Z[:, 0].
X_pop, Y0_pop, Y1_pop, D_pop = draw_panel(n=200_000, d_z=5, tau=tau_true, seed=1)
dY_pop = Y1_pop - Y0_pop
naive_did = np.mean(dY_pop[D_pop == 1]) - np.mean(dY_pop[D_pop == 0])  # no adjustment for Z
# Our target here is "ATT on ΔY", whose true value equals tau_true by construction.
print("True tau (by construction):", tau_true)
print("Naive DID (difference in mean ΔY):", naive_did)
True tau (by construction): 1.0
Naive DID (difference in mean ΔY): 1.0477956893531037
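The gap between the naive DID and \(\tau\) in the output above is what one would expect when the trend depends on a covariate that also drives treatment. The following is a minimal NumPy-only sketch of that reasoning. It re-implements the ΔY part of the DGP (the helper `simulate_delta` is hypothetical and does not reproduce the exact random draws of `draw_panel`), then compares the naive gap with a simple linear adjustment for \(Z\). The OLS adjustment is a stand-in for illustration, not the estimator used by `grr_did`.

```python
import numpy as np


def simulate_delta(n, tau=1.0, seed=1):
    """Hypothetical helper: draw (Z, D, dY) with the same structure as draw_panel."""
    rng = np.random.default_rng(seed)
    Z = rng.normal(size=(n, 5))
    e = 1.0 / (1.0 + np.exp(-(0.6 * Z[:, 0] - 0.25 * Z[:, 1])))
    D = rng.binomial(1, e)
    trend = 0.5 + 0.1 * Z[:, 0]
    # mu(Z) and the unit fixed effect cancel in Y1 - Y0, leaving trend, tau * D and noise
    dY = trend + tau * D + rng.normal(size=n) - rng.normal(size=n)
    return Z, D, dY


Z, D, dY = simulate_delta(200_000)

# Naive DID: difference in mean dY between treated and control, ignoring Z
naive = dY[D == 1].mean() - dY[D == 0].mean()

# The trend loads 0.1 on Z[:, 0], so the naive gap should be roughly
# tau + 0.1 * (mean of Z0 among treated - mean of Z0 among controls)
gap_z0 = Z[D == 1, 0].mean() - Z[D == 0, 0].mean()
predicted_naive = 1.0 + 0.1 * gap_z0

# Linear adjustment: OLS of dY on [1, D, Z]; the coefficient on D estimates tau
A = np.column_stack([np.ones(len(D)), D, Z])
coef, *_ = np.linalg.lstsq(A, dY, rcond=None)
adjusted = coef[1]

print("naive:", naive, "predicted naive:", predicted_naive, "adjusted:", adjusted)
```

Because the trend is linear in \(Z\) in this DGP, a linear adjustment is enough here. With a nonlinear trend, a more flexible adjustment would be needed, which is the motivation for the basis choices in the examples that follow.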
Example 1: Polynomial basis + treatment interactions
[9]:
# Sample a dataset from the same DGP
X, Y0, Y1, D = draw_panel(n=6000, d_z=5, tau=tau_true, seed=0)
psi = PolynomialBasis(degree=2, include_bias=True)
phi = TreatmentInteractionBasis(base_basis=psi)
gen = SquaredGenerator(C=0.0).as_generator()
res = grr_did(
    X=X,
    Y0=Y0,
    Y1=Y1,
    basis=phi,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res.summary_text())
DID estimates (n=6000)
alpha=0.05 | null=0.0
diagnostics: max_abs_smd_unweighted=0.5468841195092563, max_abs_smd_weighted=0.002194706023289376, ess_treated=3017.871638373246, ess_control=2036.5721815918366
Estimator Estimate SE CI p-value
---------------------------------------------------------------------------------
RA 0.983907 0.0129607 [ 0.958505, 1.00931] 0
RW 0.976091 0.0495386 [ 0.878997, 1.07318] 0
ARW 0.98479 0.0421389 [ 0.9022, 1.06738] 0
TMLE 0.984776 0.0421433 [ 0.902176, 1.06738] 0
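As a quick reading aid for the summary above, the sketch below recomputes normal-approximation (Wald) intervals from the printed estimates and standard errors and checks whether each one covers `tau_true = 1.0`. The numbers are copied from the output. That the library forms its intervals this way is an inference from the printed values matching, not something confirmed from its documentation.

```python
from statistics import NormalDist

# (estimate, SE) pairs copied from the Example 1 summary
reported = {
    "RA": (0.983907, 0.0129607),
    "RW": (0.976091, 0.0495386),
    "ARW": (0.98479, 0.0421389),
    "TMLE": (0.984776, 0.0421433),
}
tau_true = 1.0
alpha = 0.05
z = NormalDist().inv_cdf(1 - alpha / 2)  # two-sided critical value for level alpha

intervals = {}
for name, (est, se) in reported.items():
    lo, hi = est - z * se, est + z * se
    intervals[name] = (lo, hi)
    covers = lo <= tau_true <= hi
    print(f"{name:5s} [{lo:.6f}, {hi:.6f}]  covers tau_true: {covers}")
```

All four intervals contain 1.0. The RA interval is much narrower than the others, and 1.0 sits close to its upper end.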
Example 2: RKHS basis (RBF random Fourier features)
[ ]:
psi_rff = RBFRandomFourierBasis(
    n_features=500,
    sigma=1.0,
    standardize=True,
    random_state=0,
)
phi_rff = TreatmentInteractionBasis(base_basis=psi_rff)
res_phi_rff = grr_did(
    X=X,
    Y0=Y0,
    Y1=Y1,
    basis=phi_rff,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_phi_rff.summary_text())
Example 3: KNN catchment basis (nearest-neighbor matching)
Nearest-neighbor matching can be viewed as a special case of squared-loss Riesz regression; this example uses a KNN catchment basis with the squared generator.
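For orientation, the matching idea itself can be sketched with plain NumPy: each treated unit is paired with its k nearest controls in covariate space, and the ATT on ΔY is the average treated-minus-matched-control difference. This is ordinary k-NN matching on a small simulated sample with the same structure as the DGP above. It is an illustration only and makes no claim about how `KNNCatchmentBasis` is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)
n, tau, k = 2000, 1.0, 5

# Small sample with the same structure as the notebook's DGP (differenced outcome only)
Z = rng.normal(size=(n, 5))
e = 1.0 / (1.0 + np.exp(-(0.6 * Z[:, 0] - 0.25 * Z[:, 1])))
D = rng.binomial(1, e)
dY = 0.5 + 0.1 * Z[:, 0] + tau * D + rng.normal(scale=np.sqrt(2.0), size=n)

Zt, Zc = Z[D == 1], Z[D == 0]
dYt, dYc = dY[D == 1], dY[D == 0]

# Squared Euclidean distance from every treated unit to every control unit
dist2 = ((Zt[:, None, :] - Zc[None, :, :]) ** 2).sum(axis=2)

# Indices of the k nearest controls for each treated unit (order within the k is arbitrary)
nn_idx = np.argpartition(dist2, k, axis=1)[:, :k]

# ATT on dY: treated outcome minus the mean outcome of its matched controls
att_knn = (dYt - dYc[nn_idx].mean(axis=1)).mean()
print("k-NN matching ATT on dY:", att_knn)
```

Controls can be reused across treated units in this sketch (matching with replacement), which keeps the code short at the cost of some extra variance.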
[ ]:
basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)
phi_knn = TreatmentInteractionBasis(base_basis=basis_knn)
res_phi_knn = grr_did(
    X=X,
    Y0=Y0,
    Y1=Y1,
    basis=phi_knn,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_phi_knn.summary_text())
Example 4: Random forest leaf basis (optional)
[ ]:
from sklearn.ensemble import RandomForestRegressor
from genriesz.sklearn_basis import RandomForestLeafBasis
rf = RandomForestRegressor(n_estimators=200, max_depth=6, random_state=0)
leaf_basis = RandomForestLeafBasis(rf).fit(X, Y1 - Y0)
res_leaf_basis = grr_did(
    X=X,
    Y0=Y0,
    Y1=Y1,
    basis=leaf_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_leaf_basis.summary_text())
Example 5: Neural network embedding basis (optional)
[ ]:
import torch
from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis
torch.manual_seed(0)
net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64,), output_dim=32)
nn_basis = TorchEmbeddingBasis(net, include_bias=True, device="cpu")
res_nn_basis = grr_did(
    X=X,
    Y0=Y0,
    Y1=Y1,
    basis=nn_basis,
    generator=gen,
    cross_fit=True,
    folds=5,
    random_state=0,
    estimators=("ra", "rw", "arw", "tmle"),
    outcome_models="shared",
    riesz_penalty="l2",
    riesz_lam=1e-3,
    max_iter=300,
    tol=1e-8,
)
print(res_nn_basis.summary_text())
Generator / regularization sweep (SQ / UKL / BP)
We repeat the DID estimation (implemented as ATT on the differenced outcome) under SQ-Riesz / UKL-Riesz / BP-Riesz, multiple regularization norms, and multiple regularization strengths.
For UKL/BP we set a branch function to match the treatment/control sign pattern.
[15]:
def branch(x):
    # The first column of X is the treatment indicator D (see draw_panel)
    return int(x[0] == 1.0)
generator_grid = [
    ("SQ", SquaredGenerator(C=0.0).as_generator()),
    ("UKL (C=1)", UKLGenerator(C=1.0, branch_fn=branch).as_generator()),
    ("BP (omega=0.1, C=1)", BPGenerator(C=1.0, omega=0.1, branch_fn=branch).as_generator()),
    ("BP (omega=0.2, C=1)", BPGenerator(C=1.0, omega=0.2, branch_fn=branch).as_generator()),
    ("BP (omega=0.5, C=1)", BPGenerator(C=1.0, omega=0.5, branch_fn=branch).as_generator()),
]
penalty_grid = [
    {"penalty": "l2", "lam": 1e-4, "p_norm": None},
    {"penalty": "l2", "lam": 1e-3, "p_norm": None},
    {"penalty": "l1", "lam": 1e-4, "p_norm": None},
    {"penalty": "lp", "lam": 1e-3, "p_norm": 1.5},
]
rows = []
for gname, gen_i in generator_grid:
    for cfg in penalty_grid:
        res_i = grr_did(
            X=X,
            Y0=Y0,
            Y1=Y1,
            basis=phi,
            generator=gen_i,
            cross_fit=True,
            folds=3,
            random_state=0,
            estimators=("ra", "rw", "arw", "tmle"),
            outcome_models="shared",
            outcome_link="identity",
            riesz_penalty=cfg["penalty"],
            riesz_lam=cfg["lam"],
            riesz_p_norm=cfg.get("p_norm"),
            max_iter=250,
            tol=1e-8,
        )
        row = {
            "generator": gname,
            "penalty": cfg["penalty"],
            "lam": cfg["lam"],
        }
        for k in ("ra", "rw", "arw", "tmle"):
            e = res_i.estimates[k]
            row[k] = e.estimate
            row[f"{k}_se"] = e.se
            row[f"{k}_err"] = e.estimate - tau_true
        rows.append(row)
import pandas as pd
df = pd.DataFrame(rows)
df = df.sort_values(by="arw_err", key=lambda s: np.abs(s))
display(df)
| | generator | penalty | lam | ra | ra_se | ra_err | rw | rw_se | rw_err | arw | arw_se | arw_err | tmle | tmle_se | tmle_err |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 6 | UKL (C=1) | l1 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.974852 | 0.051020 | -0.025148 | 0.991554 | 0.043837 | -0.008446 | 0.990610 | 0.043874 | -0.009390 |
| 4 | UKL (C=1) | l2 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.974713 | 0.051015 | -0.025287 | 0.991546 | 0.043832 | -0.008454 | 0.990603 | 0.043870 | -0.009397 |
| 5 | UKL (C=1) | l2 | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.975179 | 0.050895 | -0.024821 | 0.991356 | 0.043705 | -0.008644 | 0.990484 | 0.043742 | -0.009516 |
| 7 | UKL (C=1) | lp | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.975179 | 0.050895 | -0.024821 | 0.991356 | 0.043705 | -0.008644 | 0.990484 | 0.043742 | -0.009516 |
| 19 | BP (omega=0.5, C=1) | lp | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.990781 | 0.050259 | -0.009219 | 0.991199 | 0.042990 | -0.008801 | 0.990698 | 0.043028 | -0.009302 |
| 17 | BP (omega=0.5, C=1) | l2 | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.990781 | 0.050259 | -0.009219 | 0.991199 | 0.042990 | -0.008801 | 0.990698 | 0.043028 | -0.009302 |
| 16 | BP (omega=0.5, C=1) | l2 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.986536 | 0.050288 | -0.013464 | 0.990830 | 0.043032 | -0.009170 | 0.990316 | 0.043068 | -0.009684 |
| 18 | BP (omega=0.5, C=1) | l1 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.986272 | 0.050383 | -0.013728 | 0.990558 | 0.043127 | -0.009442 | 0.990050 | 0.043162 | -0.009950 |
| 11 | BP (omega=0.1, C=1) | lp | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 1.009891 | 0.046637 | 0.009891 | 0.986427 | 0.039662 | -0.013573 | 0.986603 | 0.039681 | -0.013397 |
| 9 | BP (omega=0.1, C=1) | l2 | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 1.009891 | 0.046637 | 0.009891 | 0.986427 | 0.039662 | -0.013573 | 0.986603 | 0.039681 | -0.013397 |
| 8 | BP (omega=0.1, C=1) | l2 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 1.009828 | 0.046649 | 0.009828 | 0.986413 | 0.039671 | -0.013587 | 0.986587 | 0.039689 | -0.013413 |
| 10 | BP (omega=0.1, C=1) | l1 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 1.009280 | 0.046642 | 0.009280 | 0.986367 | 0.039666 | -0.013633 | 0.986539 | 0.039684 | -0.013461 |
| 3 | SQ | lp | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.975879 | 0.049509 | -0.024121 | 0.986051 | 0.042223 | -0.013949 | 0.986004 | 0.042238 | -0.013996 |
| 1 | SQ | l2 | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 0.975879 | 0.049509 | -0.024121 | 0.986051 | 0.042223 | -0.013949 | 0.986004 | 0.042238 | -0.013996 |
| 0 | SQ | l2 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.978916 | 0.049698 | -0.021084 | 0.985968 | 0.042369 | -0.014032 | 0.985912 | 0.042383 | -0.014088 |
| 2 | SQ | l1 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 0.979128 | 0.049712 | -0.020872 | 0.985944 | 0.042379 | -0.014056 | 0.985887 | 0.042394 | -0.014113 |
| 13 | BP (omega=0.2, C=1) | l2 | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 1.011729 | 0.048423 | 0.011729 | 0.980771 | 0.041029 | -0.019229 | 0.980738 | 0.041015 | -0.019262 |
| 15 | BP (omega=0.2, C=1) | lp | 0.0010 | 0.983161 | 0.013075 | -0.016839 | 1.011729 | 0.048423 | 0.011729 | 0.980771 | 0.041029 | -0.019229 | 0.980738 | 0.041015 | -0.019262 |
| 14 | BP (omega=0.2, C=1) | l1 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 1.010995 | 0.048484 | 0.010995 | 0.980671 | 0.041077 | -0.019329 | 0.980640 | 0.041062 | -0.019360 |
| 12 | BP (omega=0.2, C=1) | l2 | 0.0001 | 0.983161 | 0.013075 | -0.016839 | 1.010725 | 0.048487 | 0.010725 | 0.980661 | 0.041081 | -0.019339 | 0.980630 | 0.041066 | -0.019370 |
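When reading the ranking in the table above, it helps to compare each error with its standard error. The sketch below does this for the ARW column, using the best- and worst-ranked row of each generator, with values copied from the table. It is a reading aid in plain Python, not part of the original sweep.

```python
# (arw_err, arw_se) for the best- and worst-ranked row of each generator,
# copied from the sweep table
arw = {
    "UKL (C=1)": [(-0.008446, 0.043837), (-0.008644, 0.043705)],
    "BP (omega=0.5, C=1)": [(-0.008801, 0.042990), (-0.009442, 0.043127)],
    "BP (omega=0.1, C=1)": [(-0.013573, 0.039662), (-0.013633, 0.039666)],
    "SQ": [(-0.013949, 0.042223), (-0.014056, 0.042379)],
    "BP (omega=0.2, C=1)": [(-0.019229, 0.041029), (-0.019339, 0.041081)],
}

# Largest |error| / SE per generator: values well below 1 mean the error is
# small relative to the estimated sampling noise
ratios = {
    g: max(abs(err) / se for err, se in pairs)
    for g, pairs in arw.items()
}

for g, r in ratios.items():
    print(f"{g:22s} max |arw_err| / arw_se = {r:.3f}")
```

Every ratio is below 0.5, so all ARW errors are less than half a standard error from `tau_true`. Since the whole sweep uses a single sampled dataset, the ordering of generators in the table should be read as descriptive for this draw rather than as evidence that one configuration is more accurate in general.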