User guide
This guide summarises how to use Generalized Riesz Regression (GRR) in this
package. The main entry point is genriesz.grr_functional().
Conceptual procedure
To estimate a target parameter \(\theta\) written as a linear functional of the outcome regression \(\gamma(x) = \mathbb{E}[Y\mid X=x]\), you provide:
a functional m (either a built-in genriesz.LinearFunctional or a plain callable m(x_row, gamma)),
a feature map (basis) phi(X), and
either a Bregman generator object generator or a generator function g.
If you pass a plain callable m, genriesz.grr_functional() wraps it as
genriesz.CallableFunctional. The callable must be linear in the
function argument gamma.
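As a rough illustration of the linearity requirement, the sketch below defines an ATE-style callable with the m(x_row, gamma) signature and checks linearity numerically. It assumes, purely for illustration, that the treatment indicator is the first coordinate of x_row; the names m_ate, gamma_a, and gamma_b are hypothetical and not part of the package.

```python
import numpy as np

def m_ate(x_row, gamma):
    """ATE-style functional: gamma(1, z) - gamma(0, z).

    Assumes (for illustration only) that x_row[0] is the treatment indicator.
    """
    x1 = np.array(x_row, dtype=float)
    x0 = x1.copy()
    x1[0], x0[0] = 1.0, 0.0
    return gamma(x1) - gamma(x0)

# Two arbitrary candidate regression functions.
gamma_a = lambda x: 2.0 * x[0] + np.sin(x[1])
gamma_b = lambda x: x[0] * x[1] - 0.5

x_row = np.array([0.0, 0.7])
a, b = 1.5, -2.0
combo = lambda x: a * gamma_a(x) + b * gamma_b(x)

# Linearity in gamma: m(x, a*g1 + b*g2) == a*m(x, g1) + b*m(x, g2).
lhs = m_ate(x_row, combo)
rhs = a * m_ate(x_row, gamma_a) + b * m_ate(x_row, gamma_b)
```

A callable that squares or thresholds gamma would fail this check and would not be a valid functional.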
The package then:
builds the link function induced by the generator (automatic regressor balancing),
fits a Riesz representer model \(\hat\alpha(x)\),
optionally fits an outcome model \(\hat\gamma(x)\),
reports RA, RW, ARW, and TMLE estimates with standard errors, confidence intervals, and p-values.
Bases
All bases in this library implement the same interface:
batched input: basis(X) with X.shape == (n, d) returns (n, p),
single-row input: basis(x) with x.shape == (d,) returns (p,),
basis.fit(X, y=None) fits any data-dependent parameters (centers, standardisation, etc.),
basis.copy() returns a deep copy (used internally by cross-fitting).
Bases passed to genriesz.grr_functional() (and all convenience wrappers) are
copied and re-fitted inside each cross-fitting fold. You do not need to call
fit manually before passing a basis to a grr_* function.
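To make the interface concrete, here is a minimal sketch of a class that follows the four conventions listed above. StandardizedLinearBasis is a hypothetical name invented for this example; it is not a genriesz class, and the package may require more of a custom basis than is shown here.

```python
import copy
import numpy as np

class StandardizedLinearBasis:
    """Hypothetical basis: standardised covariates plus an optional bias column."""

    def __init__(self, include_bias=True):
        self.include_bias = include_bias
        self.mean_ = None
        self.scale_ = None

    def fit(self, X, y=None):
        X = np.asarray(X, dtype=float)
        self.mean_ = X.mean(axis=0)
        scale = X.std(axis=0)
        self.scale_ = np.where(scale > 0, scale, 1.0)  # guard constant columns
        return self

    def __call__(self, X):
        X = np.asarray(X, dtype=float)
        single = X.ndim == 1                 # single-row input of shape (d,)
        X2 = np.atleast_2d(X)
        Phi = (X2 - self.mean_) / self.scale_
        if self.include_bias:
            Phi = np.hstack([np.ones((Phi.shape[0], 1)), Phi])
        return Phi[0] if single else Phi

    def copy(self):
        return copy.deepcopy(self)

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
basis = StandardizedLinearBasis().fit(X)
Phi = basis(X)          # shape (50, 4)
phi_row = basis(X[0])   # shape (4,)
```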
Polynomial
Use genriesz.PolynomialBasis for polynomial expansions up to a given total
degree.
from genriesz import PolynomialBasis
psi = PolynomialBasis(degree=2, include_bias=True)
psi.fit(X) # only needed when calling the basis directly
Phi = psi(X) # (n, p)
PolynomialBasis implements derivative() analytically and is compatible with
genriesz.grr_ame().
Treatment interactions
For binary-treatment causal estimands (ATE, ATT, and DID), interact a base basis \(\psi(Z)\) with the treatment \(D\):
\[\phi(D, Z) = \bigl[D\,\psi(Z),\ (1-D)\,\psi(Z)\bigr].\]
Use genriesz.TreatmentInteractionBasis:
from genriesz import PolynomialBasis, TreatmentInteractionBasis
psi = PolynomialBasis(degree=2, include_bias=True) # base basis on Z
phi = TreatmentInteractionBasis(base_basis=psi) # [D*psi(Z), (1-D)*psi(Z)]
TreatmentInteractionBasis delegates derivative() to the base basis, so it
inherits AME compatibility from the base.
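The interaction layout can be sketched in plain NumPy as follows. This is not the package's implementation: psi and interact are hypothetical helpers, and the sketch assumes that column 0 of X holds the binary treatment D.

```python
import numpy as np

def psi(Z):
    """Hypothetical base basis on Z: bias term plus the raw covariates."""
    Z = np.atleast_2d(Z)
    return np.hstack([np.ones((Z.shape[0], 1)), Z])

def interact(X):
    """Build [D * psi(Z), (1 - D) * psi(Z)], assuming column 0 of X is D."""
    X = np.atleast_2d(np.asarray(X, dtype=float))
    D, Z = X[:, :1], X[:, 1:]
    P = psi(Z)
    return np.hstack([D * P, (1.0 - D) * P])

X = np.array([[1.0, 0.2, -1.0],   # treated unit
              [0.0, 0.5, 3.0]])   # control unit
Phi = interact(X)                 # shape (2, 6)
```

Each row has non-zero entries only in the block for its own treatment arm, which is what lets a single linear model represent separate treated and control functions.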
RKHS bases (RBF kernel)
Three classes approximate the Gaussian (RBF-kernel) RKHS:
genriesz.RBFRandomFourierBasis: random Fourier features (Rahimi & Recht),
genriesz.RBFNystromBasis: Nyström feature map with eigendecomposition whitening,
genriesz.GaussianRKHSBasis: explicit kernel feature map (one basis function per center).
from genriesz import RBFRandomFourierBasis, RBFNystromBasis, GaussianRKHSBasis
rff = RBFRandomFourierBasis(n_features=500, sigma=1.0, standardize=True, random_state=0)
nys = RBFNystromBasis(n_centers=300, sigma=1.0, standardize=True, random_state=0)
krn = GaussianRKHSBasis(n_centers=300, sigma=1.0, standardize=True, random_state=0)
Note
genriesz.RBFRandomFourierBasis implements derivative() analytically
(the derivative of \(\cos(\omega^\top x + b)\) is \(-\omega \sin(\ldots)\)),
and is therefore compatible with genriesz.grr_ame().
genriesz.RBFNystromBasis and genriesz.GaussianRKHSBasis do not
currently implement derivative() and cannot be used with grr_ame.
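For intuition, the following NumPy sketch builds random Fourier features by hand, compares the implied kernel with the exact Gaussian kernel, and checks the analytic derivative against finite differences. It assumes the kernel convention \(k(x, y) = \exp(-\|x - y\|^2 / (2\sigma^2))\); the package's own parameterisation of sigma may differ, and rff and rff_derivative are hypothetical helper names.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_features, sigma = 3, 20000, 1.0

# Random Fourier features for k(x, y) = exp(-||x - y||^2 / (2 sigma^2)).
W = rng.normal(scale=1.0 / sigma, size=(n_features, d))
b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)
scale = np.sqrt(2.0 / n_features)

def rff(x):
    return scale * np.cos(W @ x + b)

def rff_derivative(x, coordinate):
    # d/dx_j cos(w^T x + b) = -w_j sin(w^T x + b)
    return -scale * W[:, coordinate] * np.sin(W @ x + b)

x = np.array([0.3, -0.2, 0.5])
y = np.array([0.1, 0.4, -0.3])
k_exact = np.exp(-np.sum((x - y) ** 2) / (2.0 * sigma**2))
k_approx = rff(x) @ rff(y)

# Central finite-difference check of the analytic derivative in coordinate 0.
eps = 1e-6
e0 = np.array([eps, 0.0, 0.0])
fd = (rff(x + e0) - rff(x - e0)) / (2.0 * eps)
```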
kNN nearest-neighbor indicator basis
Nearest-neighbor matching can be interpreted as a squared-loss Riesz (LSIF) construction with a nearest-neighbor indicator basis:
\[\phi_j(z) = \mathbf{1}\{c_j \in \mathrm{NN}_k(z)\},\]
where \(\{c_j\}\) are fitted centers and \(\mathrm{NN}_k(z)\) denotes the \(k\)-nearest centers of \(z\).
For ATE, ATT, and DID, combine with genriesz.TreatmentInteractionBasis:
import numpy as np
from genriesz import (
    KNNCatchmentBasis, TreatmentInteractionBasis,
    grr_ate, SquaredGenerator,
)
gen = SquaredGenerator(C=0.0).as_generator()
# The basis is fitted on the training fold inside each cross-fitting iteration.
basis_knn = KNNCatchmentBasis(n_neighbors=5, include_bias=False)
phi = TreatmentInteractionBasis(base_basis=basis_knn)
res = grr_ate(X=X, Y=Y, basis=phi, generator=gen,
              cross_fit=True, folds=5)
When cross-fitting, each training fold becomes the set of centers, so the feature dimension \(p = n_\text{train}\). The internal outcome model uses a dual (Woodbury) solve when \(p > n_\text{test}\), keeping cost at \(O(n^3)\).
Note
KNNCatchmentBasis is piecewise-constant and does not implement
derivative(). It cannot be used with genriesz.grr_ame().
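A nearest-neighbor indicator basis of this kind can be sketched in a few lines of NumPy. The function below is a hypothetical stand-in, not the KNNCatchmentBasis implementation; it uses Euclidean distance and breaks ties by argsort order, both of which are assumptions.

```python
import numpy as np

def knn_indicator_basis(Z, centers, k):
    """Phi[i, j] = 1 if centers[j] is among the k nearest centers of Z[i]."""
    Z = np.atleast_2d(Z)
    # Squared Euclidean distances, shape (n, n_centers).
    d2 = ((Z[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    nearest = np.argsort(d2, axis=1)[:, :k]
    Phi = np.zeros(d2.shape)
    np.put_along_axis(Phi, nearest, 1.0, axis=1)
    return Phi

centers = np.array([[0.0], [1.0], [2.0], [5.0]])
Z = np.array([[0.1], [4.0]])
Phi = knn_indicator_basis(Z, centers, k=2)
```

Every row contains exactly k ones, so the feature map is constant wherever the set of nearest centers does not change.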
Random forest leaf basis
genriesz.sklearn_basis.RandomForestLeafBasis turns any scikit-learn
RandomForest* model into a feature map via one-hot leaf indicators.
from sklearn.ensemble import RandomForestRegressor
from genriesz import TreatmentInteractionBasis
from genriesz.sklearn_basis import RandomForestLeafBasis
rf = RandomForestRegressor(n_estimators=200, max_depth=6, random_state=0)
psi = RandomForestLeafBasis(rf) # include_bias=False, normalize=True (defaults)
phi = TreatmentInteractionBasis(base_basis=psi)
This keeps GRR convex (linear in parameters) while using a nonparametric partition of the covariate space.
Note
RandomForestLeafBasis normalizes each leaf-indicator row by
\(1/\!\sqrt{T}\) (where \(T\) is the number of trees) by default
(normalize=True). Without this, the row \(\ell_2\)-norm grows as
\(\sqrt{T}\), making the effective regularisation scale tree-count dependent.
Tree-based bases are piecewise-constant and do not implement
derivative(). They cannot be used with genriesz.grr_ame().
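The effect of the \(1/\!\sqrt{T}\) scaling on row norms can be checked with a small NumPy sketch. The leaf indices here are random stand-ins for what a fitted forest would return (for example via scikit-learn's apply method), and leaf_one_hot is a hypothetical helper that assumes every tree has the same number of leaves.

```python
import numpy as np

def leaf_one_hot(leaf_ids, n_leaves, normalize=True):
    """One-hot encode per-tree leaf indices.

    leaf_ids: integer array of shape (n, T); entry (i, t) is the leaf that
    sample i falls into in tree t. n_leaves: leaves per tree (assumed equal).
    """
    n, T = leaf_ids.shape
    Phi = np.zeros((n, T * n_leaves))
    cols = leaf_ids + np.arange(T) * n_leaves    # offset each tree's block
    Phi[np.arange(n)[:, None], cols] = 1.0
    if normalize:
        Phi /= np.sqrt(T)
    return Phi

rng = np.random.default_rng(0)
T, n_leaves = 200, 8
leaf_ids = rng.integers(0, n_leaves, size=(5, T))  # stand-in for forest output

norm_raw = np.linalg.norm(leaf_one_hot(leaf_ids, n_leaves, normalize=False), axis=1)
norm_scaled = np.linalg.norm(leaf_one_hot(leaf_ids, n_leaves, normalize=True), axis=1)
```

Without scaling each row has norm \(\sqrt{T}\); with scaling it has norm 1 regardless of the number of trees.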
Neural network feature maps (PyTorch)
If you install PyTorch (optional), you can use a neural network as a basis function
via genriesz.torch_basis.TorchEmbeddingBasis.
Important
If you train the neural network jointly inside the GRR objective, you leave the convex (GLM) setting. The recommended approach is:
train the embedding network separately (e.g., supervised pre-training),
call phi.fit(X, y, ...) to run the optional training loop,
use the frozen embedding as a feature map in GRR.
from genriesz.torch_basis import MLPEmbeddingNet, TorchEmbeddingBasis
net = MLPEmbeddingNet(input_dim=X.shape[1], hidden_dims=(64, 32), output_dim=16)
phi = TorchEmbeddingBasis(net=net, include_bias=True)
phi.fit(X, y, epochs=10, lr=1e-3, verbose=True) # optional supervised pretraining
Phi = phi(X) # (n, 17) — 16 embedding dims + 1 bias
Note
TorchEmbeddingBasis does not implement derivative() and cannot be used
with genriesz.grr_ame() without additional autograd integration.
AME-compatible bases summary
genriesz.grr_ame() requires basis.derivative(X, coordinate).
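| Basis | AME-compatible | Note |
|---|---|---|
| PolynomialBasis | Yes | Analytical derivative |
| RBFRandomFourierBasis | Yes | Analytical derivative of \(\cos(\omega^\top x + b)\) |
| TreatmentInteractionBasis | Depends on base | Delegates to base basis |
| KNNCatchmentBasis | No | Piecewise-constant |
| RandomForestLeafBasis | No | Piecewise-constant |
| TorchEmbeddingBasis | No | No autograd bridge |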
Density ratio estimation
The function genriesz.fit_density_ratio() estimates the covariate-shift
density ratio
\[r(x) = \frac{p(x)}{q(x)}\]
from two samples X_num ~ p and X_den ~ q.
You choose a generator (either a built-in name like generator='ukl' or a custom g).
The function constructs the corresponding link function \((\partial g)^{-1}\) automatically.
It fits the ratio by minimising the induced convex objective.
By default a Gaussian-kernel RKHS basis is used. You can optionally cross-validate
the bandwidth sigma and regularisation lam.
from genriesz import fit_density_ratio
res = fit_density_ratio(
    X_num,
    X_den,
    generator="ukl",  # "sq", "bkl", "bp", "pu", or a generator instance
    n_centers=200,
    cv=True,
    folds=5,
    sigma_grid=[0.1, 0.3, 1.0, 3.0],
    lam_grid=[1e-3, 1e-2, 1e-1],
    random_state=0,
)
r_hat = res.predict_ratio(X_test)
Important
For the squared generator (generator='sq', i.e. genriesz.SquaredGenerator),
the fit uses a closed-form ridge solution. For the binary KL generator,
it uses classification-based density ratio estimation. For all other generators,
a numerical optimizer (L-BFGS-B) is used.
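To show what a closed-form ridge solution looks like in the squared-loss case, here is a minimal least-squares importance fitting sketch in NumPy. It is a textbook uLSIF-style construction under stated assumptions (Gaussian features, penalty \(\tfrac{\lambda}{2}\|\beta\|^2\)), not the code used by fit_density_ratio(); all function names are hypothetical.

```python
import numpy as np

def gaussian_features(X, centers, sigma):
    d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return np.exp(-d2 / (2.0 * sigma**2))

def fit_ratio_squared_loss(X_num, X_den, centers, sigma, lam):
    """Minimise 0.5*E_q[r^2] - E_p[r] + 0.5*lam*||beta||^2 with r = Phi @ beta."""
    Phi_num = gaussian_features(X_num, centers, sigma)
    Phi_den = gaussian_features(X_den, centers, sigma)
    H = Phi_den.T @ Phi_den / len(X_den)    # second moment under q
    h = Phi_num.mean(axis=0)                # first moment under p
    beta = np.linalg.solve(H + lam * np.eye(len(centers)), h)
    return lambda X: gaussian_features(X, centers, sigma) @ beta

rng = np.random.default_rng(0)
X_num = rng.normal(loc=0.5, size=(2000, 1))    # sample from p = N(0.5, 1)
X_den = rng.normal(loc=0.0, size=(2000, 1))    # sample from q = N(0, 1)
centers = X_num[:50]

predict_ratio = fit_ratio_squared_loss(X_num, X_den, centers, sigma=1.0, lam=1e-2)
X_test = np.array([[-1.0], [0.0], [1.0]])
r_hat = predict_ratio(X_test)
r_true = np.exp(0.5 * X_test[:, 0] - 0.125)    # exact N(0.5, 1) / N(0, 1) ratio
```

Comparing r_hat with r_true on a grid is a quick sanity check on the bandwidth and penalty before relying on the fitted ratio.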
Generators and automatic links
A Bregman generator defines both the loss and the induced link function used for automatic regressor balancing.
The easiest option is to use one of the built-in generator objects:
genriesz.SquaredGenerator (SQ-Riesz)
genriesz.UKLGenerator (UKL-Riesz)
genriesz.BKLGenerator (BKL-Riesz)
genriesz.BPGenerator (BP-Riesz)
genriesz.PUGenerator (PU-Riesz)
Call .as_generator() to get a genriesz.BregmanGenerator instance:
from genriesz import SquaredGenerator, UKLGenerator
gen_sq = SquaredGenerator(C=0.0).as_generator()
# UKL/BP with a branch function: use + branch for treated, - for control.
branch = lambda x: int(x[0] == 1.0)
gen_ukl = UKLGenerator(C=1.0, branch_fn=branch).as_generator()
You can also define a completely custom generator:
from genriesz import BregmanGenerator
gen = BregmanGenerator(g=g, grad=grad_g, inv_grad=inv_grad_g, grad2=grad2_g)
Custom generator call signatures
A custom generator can be regressor-dependent:
g(x, alpha)
optional: grad_g(x, alpha) (first derivative w.r.t. alpha)
optional: inv_grad_g(x, v) (inverse derivative map)
optional: grad2_g(x, alpha) (second derivative w.r.t. alpha)
If derivatives are omitted, the package falls back to finite differences and a Newton solver.
Important
For speed and numerical stability, providing inv_grad_g (and ideally
grad_g and grad2_g) is strongly recommended for custom generators.
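A sketch of the four callables for a simple, hypothetical generator \(g(x, \alpha) = \alpha^2\) is given below, together with consistency checks that are worth running before handing custom functions to the package. The generator itself is chosen only for illustration and is not claimed to match any built-in.

```python
import numpy as np

# Hypothetical regressor-independent generator g(x, alpha) = alpha**2.
def g(x, alpha):
    return alpha**2

def grad_g(x, alpha):        # first derivative w.r.t. alpha
    return 2.0 * alpha

def inv_grad_g(x, v):        # solves grad_g(x, alpha) = v for alpha
    return v / 2.0

def grad2_g(x, alpha):       # second derivative w.r.t. alpha
    return 2.0 * np.ones_like(alpha)

x = np.zeros(2)              # unused by this particular generator
alpha = np.linspace(-2.0, 2.0, 9)
eps = 1e-5

# Finite-difference and round-trip checks of the supplied derivatives.
fd_grad = (g(x, alpha + eps) - g(x, alpha - eps)) / (2.0 * eps)
fd_grad2 = (grad_g(x, alpha + eps) - grad_g(x, alpha - eps)) / (2.0 * eps)
roundtrip = inv_grad_g(x, grad_g(x, alpha))
```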
Branch functions for UKL and BP
genriesz.UKLGenerator and genriesz.BPGenerator are defined on
\(\{|\alpha| > C\}\), which has two branches (positive and negative). When the
sign of \(\alpha(x)\) is known in advance (e.g., positive for treated, negative
for control in ATE/ATT/DID), pass a branch_fn to avoid ambiguity:
branch = lambda x: int(x[0] == 1.0) # 1 = positive branch (treated), 0 = negative
gen = UKLGenerator(C=1.0, branch_fn=branch).as_generator()
Without branch_fn, the branch is inferred from \(\text{sign}(v)\), where \(v\) is the argument of the inverse derivative map. This is correct only when
\(|\alpha| > C + 1\). A UserWarning is issued if branch_fn is
omitted.
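The ambiguity can be illustrated numerically. The sketch below assumes a UKL-type generator whose derivative is \(\mathrm{sign}(\alpha)\log(|\alpha| - C)\); this functional form is an assumption, picked because it reproduces the \(|\alpha| > C + 1\) condition stated above, and is not a statement about the package's internals.

```python
import numpy as np

C = 1.0

def grad_ukl(alpha):
    # Assumed UKL-type gradient: sign(alpha) * log(|alpha| - C), for |alpha| > C.
    return np.sign(alpha) * np.log(np.abs(alpha) - C)

def inv_grad(v, branch_sign):
    # Invert v = s * log(|alpha| - C) on the branch with sign s.
    return branch_sign * (C + np.exp(branch_sign * v))

alpha_far = 3.0     # |alpha| > C + 1: sign(v) matches sign(alpha)
alpha_near = 1.5    # C < |alpha| < C + 1: sign(v) is flipped

v_far, v_near = grad_ukl(alpha_far), grad_ukl(alpha_near)

by_sign_far = inv_grad(v_far, np.sign(v_far))      # recovers 3.0
by_sign_near = inv_grad(v_near, np.sign(v_near))   # lands on the negative branch
by_branch_near = inv_grad(v_near, 1.0)             # recovers 1.5
```

With the branch supplied explicitly, the inversion is correct on both sides of \(C + 1\).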
Estimators, cross fitting, and outcome models
The high-level function genriesz.grr_functional() can report multiple estimators
at once via estimators=(...):
"ra": regression adjustment (plug-in)
"rw": Riesz weighting (weighting only)
"arw": augmented Riesz weighting (doubly robust)
"tmle": targeted maximum likelihood estimation (one-step fluctuation)
Cross fitting
Set cross_fit=True (default) to enable K-fold cross fitting. The number of
folds is controlled by folds (default 5).
With cross fitting:
the data are split into folds training and test splits,
all nuisance models (Riesz representer and outcome regression) are fitted on the training fold and evaluated on the test fold,
estimates are computed from the full-sample out-of-fold predictions.
This removes the overfitting bias that would arise if the nuisance models were evaluated on the same data used to fit them.
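The cross-fitting pattern can be sketched generically with a ridge regression standing in for a nuisance model. This is a simplified illustration, not the package's routine: ridge_fit and out_of_fold_predictions are hypothetical helpers, and the random fold assignment is an assumption.

```python
import numpy as np

def ridge_fit(Phi, y, lam):
    p = Phi.shape[1]
    return np.linalg.solve(Phi.T @ Phi + lam * np.eye(p), Phi.T @ y)

def out_of_fold_predictions(Phi, y, folds=5, lam=1.0, seed=0):
    """Fit on each training split, predict on the held-out split."""
    n = len(y)
    rng = np.random.default_rng(seed)
    fold_id = rng.permutation(n) % folds         # balanced random fold labels
    pred = np.empty(n)
    for k in range(folds):
        test = fold_id == k
        beta = ridge_fit(Phi[~test], y[~test], lam)
        pred[test] = Phi[test] @ beta            # never evaluated on training rows
    return pred, fold_id

rng = np.random.default_rng(1)
Phi = np.hstack([np.ones((200, 1)), rng.normal(size=(200, 3))])
y = Phi @ np.array([1.0, 2.0, -1.0, 0.5]) + rng.normal(scale=0.1, size=200)
pred, fold_id = out_of_fold_predictions(Phi, y)
```

Every observation receives exactly one prediction, produced by a model that did not see it during fitting.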
Outcome models
For RA, ARW, and TMLE you need an outcome regression \(\hat\gamma\). Control it
via outcome_models:
"shared": use the same basis (and penalty settings) as the Riesz model. When riesz_method='grr' (the default), the already-fitted Riesz basis is reused directly; no second fit is performed. This is important for stochastic bases (RBF, Nyström, KNN) where a separate fit would draw different random features.
"separate": use a user-provided outcome_basis.
"both": fit both and report both versions.
"none": skip outcome modeling (then only "rw" is available).
The outcome link function is specified by outcome_link ("identity" or
"logit"). TMLE infers the likelihood from this link:
outcome_link="identity" ⟹ Gaussian targeting
outcome_link="logit" ⟹ Bernoulli targeting
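How the RA, RW, and ARW estimators named in this section combine \(\hat\gamma\) and \(\hat\alpha\) can be sketched for the ATE on simulated data. The example uses oracle nuisances in a randomised design with known propensity 0.5 instead of fitted models, and the textbook formulas shown are an assumption about the general form; the package's exact computations (and TMLE, which is omitted) may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20000
Z = rng.normal(size=n)
D = rng.integers(0, 2, size=n).astype(float)    # randomised, P(D=1) = 0.5
Y = 2.0 * D + Z + rng.normal(size=n)            # true ATE = 2

# Oracle nuisances, used here in place of fitted models.
gamma = lambda d, z: 2.0 * d + z                # outcome regression E[Y | D, Z]
alpha = D / 0.5 - (1.0 - D) / 0.5               # ATE Riesz representer

m_gamma = gamma(1.0, Z) - gamma(0.0, Z)         # m(X, gamma) for the ATE
resid = Y - gamma(D, Z)

theta_ra = m_gamma.mean()                       # regression adjustment (plug-in)
theta_rw = (alpha * Y).mean()                   # Riesz weighting
arw_scores = m_gamma + alpha * resid            # augmented Riesz weighting scores
theta_arw = arw_scores.mean()
se_arw = arw_scores.std(ddof=1) / np.sqrt(n)    # plug-in standard error
```

RW relies only on \(\hat\alpha\), RA only on \(\hat\gamma\), and ARW adds a weighted residual correction to RA, which is the source of its double robustness.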
Regularization: \(\ell_p\)
For the Riesz model, set:
riesz_penalty="l2" for ridge,
riesz_penalty="l1" for lasso,
riesz_penalty="lp" with riesz_p_norm=p for general \(p \ge 1\),
shorthand: riesz_penalty="l1.5" is equivalent to "lp" with riesz_p_norm=1.5.
The outcome model (linear or logistic regression on a basis) supports the same
interface via outcome_penalty and outcome_p_norm.
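When the basis dimension \(p\) exceeds the number of training observations \(n\), the ridge outcome regression automatically switches to a dual (Woodbury kernel-ridge) solve. In the standard form of this identity, with \(\Phi\) the \(n \times p\) design matrix and \(\lambda\) the ridge penalty (up to the package's scaling convention),
\[\hat\beta = \Phi^\top (\Phi\Phi^\top + \lambda I_n)^{-1} Y,\]
which costs \(O(n^3)\) instead of \(O(p^3)\). The switch applies to
outcome_penalty="l2" ("ridge").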
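The equivalence between the primal and dual ridge solutions is easy to verify numerically. The NumPy sketch below uses the standard identity \((\Phi^\top\Phi + \lambda I_p)^{-1}\Phi^\top Y = \Phi^\top(\Phi\Phi^\top + \lambda I_n)^{-1}Y\); the unscaled penalty \(\lambda\) is an assumption and may not match the package's convention.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, lam = 40, 300, 0.5
Phi = rng.normal(size=(n, p))
y = rng.normal(size=n)

# Primal ridge: solve a p x p system, O(p^3).
beta_primal = np.linalg.solve(Phi.T @ Phi + lam * np.eye(p), Phi.T @ y)

# Dual ridge: solve an n x n system, O(n^3), then map back to p coefficients.
beta_dual = Phi.T @ np.linalg.solve(Phi @ Phi.T + lam * np.eye(n), y)
```

Both routes give the same coefficients; the dual one only ever factorises an \(n \times n\) matrix.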
Diagnostics and balance checks
After estimation, result.diagnostics is populated with several quantities.
See diagnostics for a detailed guide.
Alpha tail statistics (all estimands)
alpha_abs_mean: \(n^{-1}\sum_i |\hat\alpha(X_i)|\)
alpha_abs_p95: 95th percentile of \(|\hat\alpha(X_i)|\)
alpha_abs_max: \(\max_i |\hat\alpha(X_i)|\)
Large values (relative to the outcome scale) suggest that the estimated Riesz
representer has heavy tails. This inflates the variance of the RW and ARW estimators
and may indicate that riesz_lam should be increased or the basis reconsidered.
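These tail statistics are simple to reproduce from an array of fitted representer values, which can help when comparing runs by hand. The helper below is a hypothetical sketch; in particular, the percentile interpolation rule used by the package is not documented here and NumPy's default is assumed.

```python
import numpy as np

def alpha_tail_stats(alpha_hat):
    a = np.abs(np.asarray(alpha_hat, dtype=float))
    return {
        "alpha_abs_mean": a.mean(),
        "alpha_abs_p95": np.percentile(a, 95),
        "alpha_abs_max": a.max(),
    }

alpha_hat = np.array([2.0, -2.0, 1.5, -40.0])   # one extreme weight
stats = alpha_tail_stats(alpha_hat)
```

A single extreme value dominates both the mean and the maximum here, which is the pattern the heavy-tail warning describes.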
Covariate balance (ATE, ATT, and DID)
max_abs_smd_unweighted: maximum absolute standardised mean difference before weighting
max_abs_smd_weighted: maximum absolute SMD after weighting
ess_treated, ess_control: Kish effective sample sizes
A graphical balance check (Love plot) is available via genriesz.FunctionalEstimate.love_plot():
fig, ax = result.love_plot()
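For reference, the balance quantities can be computed by hand as in the sketch below. It uses one common SMD convention (difference in weighted means divided by the pooled unweighted standard deviation) and the Kish formula \((\sum_i w_i)^2 / \sum_i w_i^2\); whether the package uses exactly this SMD denominator is an assumption, and abs_smd and kish_ess are hypothetical helpers.

```python
import numpy as np

def abs_smd(x, d, w=None):
    """Absolute standardised mean difference of covariate x between arms.

    The denominator pools the unweighted group variances (one common convention).
    """
    w = np.ones_like(x) if w is None else w
    m1 = np.average(x[d == 1], weights=w[d == 1])
    m0 = np.average(x[d == 0], weights=w[d == 0])
    pooled_sd = np.sqrt((x[d == 1].var() + x[d == 0].var()) / 2.0)
    return abs(m1 - m0) / pooled_sd

def kish_ess(w):
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    return w.sum() ** 2 / (w**2).sum()

x = np.array([0.0, 2.0, 1.0, 3.0])
d = np.array([1, 1, 0, 0])
w = np.array([1.0, 3.0, 3.0, 1.0])       # illustrative balancing weights

smd_unweighted = abs_smd(x, d)           # treated mean 1 vs control mean 2
smd_weighted = abs_smd(x, d, w)          # both weighted means equal 1.5
ess_treated = kish_ess(w[d == 1])
ess_control = kish_ess(w[d == 0])
```

In this toy example the weights remove the imbalance entirely, at the cost of reducing each arm's effective sample size from 2 to 1.6.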