Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ python benchmarks/run_benchmarks.py --all

# Run specific estimator
python benchmarks/run_benchmarks.py --estimator callaway
python benchmarks/run_benchmarks.py --estimator multiperiod
```

See `docs/benchmarks.rst` for full methodology and validation results.
Expand Down
6 changes: 4 additions & 2 deletions METHODOLOGY_REVIEW.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,10 @@ Each estimator in diff-diff should be periodically reviewed to ensure:
fixed to use interaction sub-VCV instead of full regression VCV.

**Outstanding Concerns:**
- No R comparison benchmarks yet (unlike DifferenceInDifferences and CallawaySantAnna which
have formal R benchmark tests). Consider adding `benchmarks/R/multiperiod_benchmark.R`.
- ~~No R comparison benchmarks yet~~ — **Resolved**: R comparison benchmark added via
`benchmarks/R/benchmark_multiperiod.R` using `fixest::feols(outcome ~ treated * time_f | unit)`.
Results match R exactly: ATT diff < 1e-11, SE diff 0.0%, period effects correlation 1.0.
Validated at small (200 units) and 1k scales.
- Default SE is HC1 (not cluster-robust at unit level as fixest uses). Cluster-robust
available via `cluster` parameter but not the default.
- Endpoint binning for distant event times not yet implemented.
Expand Down
201 changes: 201 additions & 0 deletions benchmarks/R/benchmark_multiperiod.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
#!/usr/bin/env Rscript
# Benchmark: MultiPeriodDiD event study (R `fixest` package)
#
# Usage:
# Rscript benchmark_multiperiod.R --data path/to/data.csv --output path/to/results.json \
# --n-pre 4 --n-post 4

library(fixest)
library(jsonlite)
library(data.table)

# Parse command line arguments
args <- commandArgs(trailingOnly = TRUE)

# Parse command-line flags into a config list.
#
# Flags:
#   --data <path>              input CSV (required)
#   --output <path>            output JSON (required)
#   --cluster <col>            column to cluster SEs on (default "unit")
#   --n-pre <int>              number of pre-treatment periods (required)
#   --n-post <int>             number of post-treatment periods (required)
#   --reference-period <int>   reference period (default: n_pre, i.e. the
#                              last pre-treatment period)
#
# Unknown tokens are skipped. Stops with an error when a required flag is
# missing, or when a flag appears without a value (previously `args[i + 1]`
# indexed past the end of the vector and silently produced NA).
parse_args <- function(args) {
  result <- list(
    data = NULL,
    output = NULL,
    cluster = "unit",
    n_pre = NULL,
    n_post = NULL,
    reference_period = NULL
  )

  # Value following the flag at position i; error instead of a silent NA
  # when the flag is the last token on the command line.
  flag_value <- function(i) {
    if (i + 1 > length(args)) {
      stop(sprintf("Missing value for flag: %s", args[i]), call. = FALSE)
    }
    args[i + 1]
  }

  i <- 1
  while (i <= length(args)) {
    flag <- args[i]
    if (flag == "--data") {
      result$data <- flag_value(i)
      i <- i + 2
    } else if (flag == "--output") {
      result$output <- flag_value(i)
      i <- i + 2
    } else if (flag == "--cluster") {
      result$cluster <- flag_value(i)
      i <- i + 2
    } else if (flag == "--n-pre") {
      result$n_pre <- as.integer(flag_value(i))
      i <- i + 2
    } else if (flag == "--n-post") {
      result$n_post <- as.integer(flag_value(i))
      i <- i + 2
    } else if (flag == "--reference-period") {
      result$reference_period <- as.integer(flag_value(i))
      i <- i + 2
    } else {
      # Skip unrecognized tokens so extra flags don't break the script.
      i <- i + 1
    }
  }

  if (is.null(result$data) || is.null(result$output)) {
    stop("Usage: Rscript benchmark_multiperiod.R --data <path> --output <path> --n-pre <int> --n-post <int>")
  }
  if (is.null(result$n_pre) || is.null(result$n_post)) {
    stop("--n-pre and --n-post are required")
  }

  # Default reference period: last pre-period
  if (is.null(result$reference_period)) {
    result$reference_period <- result$n_pre
  }

  result
}

# Resolve CLI flags (stops with a usage message when required flags are absent).
config <- parse_args(args)

# Load data (data.table::fread returns a data.table).
message(sprintf("Loading data from: %s", config$data))
data <- fread(config$data)

ref_period <- config$reference_period
message(sprintf("Reference period: %d", ref_period))
message(sprintf("n_pre: %d, n_post: %d", config$n_pre, config$n_post))

# Create factor for time with reference level. relevel() makes ref_period the
# baseline level, so feols drops it and every treated:time_fK coefficient is
# measured relative to the reference period. Column is added in place via
# data.table's `:=`.
data[, time_f := relevel(factor(time), ref = as.character(ref_period))]

# Run benchmark: time only the estimation step.
message("Running MultiPeriodDiD estimation (fixest::feols)...")
start_time <- Sys.time()

# Regression: outcome ~ treated * time_f | unit, clustered SEs
# With | unit, fixest absorbs unit fixed effects. The unit-invariant 'treated'
# main effect is collinear with unit FE and is absorbed automatically.
# Interaction coefficients treated:time_fK remain identified.
cluster_formula <- as.formula(paste0("~", config$cluster))
model <- feols(outcome ~ treated * time_f | unit, data = data, cluster = cluster_formula)

estimation_time <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))

# Extract all coefficients and SEs (SEs/VCV reflect the clustering above).
coefs <- coef(model)
ses <- se(model)
vcov_mat <- vcov(model)

# Extract interaction coefficients (treated:time_fK for each non-reference K)
# by prefix match on coefficient names; the reference level was dropped by
# feols, so only non-reference periods appear here.
interaction_mask <- grepl("^treated:time_f", names(coefs))
interaction_names <- names(coefs)[interaction_mask]
interaction_coefs <- coefs[interaction_mask]
interaction_ses <- ses[interaction_mask]

message(sprintf("Found %d interaction coefficients", length(interaction_names)))

# Build the per-period effects list, one entry per non-reference period.
# Coefficient names have the form "treated:time_fK" where K is the period
# value; event_time is the period expressed relative to the reference period.
# (The previous version also computed an unused `all_periods` vector and grew
# the list with an index loop; both replaced with vectorized/lapply forms.)
period_vals <- as.integer(sub("treated:time_f", "", interaction_names))

period_effects <- lapply(seq_along(interaction_names), function(i) {
  list(
    period = period_vals[i],
    event_time = period_vals[i] - ref_period,
    att = unname(interaction_coefs[i]),
    se = unname(interaction_ses[i])
  )
})

# Compute average ATT across post-periods with a covariance-aware SE:
# Var(mean) = 1' V 1 / k^2, where V is the post-period interaction sub-VCV.
# Post-periods are those whose period value exceeds n_pre.
# (Previously the post-period names were accumulated by growing a vector
# with c() inside a loop; replaced with vectorized subsetting.)
interaction_periods <- as.integer(sub("treated:time_f", "", interaction_names))
post_period_names <- interaction_names[interaction_periods > config$n_pre]

n_post_periods <- length(post_period_names)
message(sprintf("Post-period interaction coefficients: %d", n_post_periods))

if (n_post_periods > 0) {
  avg_att <- mean(coefs[post_period_names])
  vcov_sub <- vcov_mat[post_period_names, post_period_names, drop = FALSE]
  avg_se <- sqrt(sum(vcov_sub) / n_post_periods^2)
  # NaN guard: match registry convention (REGISTRY.md lines 179-183)
  if (is.finite(avg_se) && avg_se > 0) {
    # Residual degrees of freedom, hoisted: shared by pt() and both qt() calls.
    df_resid <- model$nobs - length(coefs)
    avg_t <- avg_att / avg_se
    avg_pval <- 2 * pt(abs(avg_t), df = df_resid, lower.tail = FALSE)
    half_width <- qt(0.975, df = df_resid) * avg_se
    avg_ci_lower <- avg_att - half_width
    avg_ci_upper <- avg_att + half_width
  } else {
    avg_t <- NA
    avg_pval <- NA
    avg_ci_lower <- NA
    avg_ci_upper <- NA
  }
} else {
  # No post-period coefficients found: report NA throughout.
  avg_att <- NA
  avg_se <- NA
  avg_pval <- NA
  avg_ci_lower <- NA
  avg_ci_upper <- NA
}

message(sprintf("Average ATT: %.6f", avg_att))
message(sprintf("Average SE: %.6f", avg_se))

# Format output: one named list serialized to JSON. Field names mirror the
# Python benchmark (benchmarks/python/benchmark_multiperiod.py) so the two
# result files can be compared directly.
results <- list(
  estimator = "fixest::feols (multiperiod)",
  cluster = config$cluster,

  # Average treatment effect
  att = avg_att,
  se = avg_se,
  pvalue = avg_pval,
  ci_lower = avg_ci_lower,
  ci_upper = avg_ci_upper,

  # Reference period
  reference_period = ref_period,

  # Period-level effects
  period_effects = period_effects,

  # Timing
  timing = list(
    estimation_seconds = estimation_time,
    total_seconds = estimation_time
  ),

  # Metadata
  metadata = list(
    r_version = R.version.string,
    fixest_version = as.character(packageVersion("fixest")),
    n_units = length(unique(data$unit)),
    n_periods = length(unique(data$time)),
    n_obs = nrow(data),
    n_pre = config$n_pre,
    n_post = config$n_post
  )
)

# Write output. auto_unbox = TRUE serializes length-1 vectors as JSON scalars
# (matching the Python side); digits = 10 keeps enough precision for the
# sub-1e-11 ATT comparisons.
message(sprintf("Writing results to: %s", config$output))
write_json(results, config$output, auto_unbox = TRUE, pretty = TRUE, digits = 10)

message(sprintf("Completed in %.3f seconds", estimation_time))
162 changes: 162 additions & 0 deletions benchmarks/python/benchmark_multiperiod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#!/usr/bin/env python3
"""
Benchmark: MultiPeriodDiD event study (diff-diff MultiPeriodDiD).

Usage:
python benchmark_multiperiod.py --data path/to/data.csv --output path/to/results.json \
--n-pre 4 --n-post 4
"""

import argparse
import json
import os
import sys
from pathlib import Path

# IMPORTANT: Parse --backend and set environment variable BEFORE importing diff_diff
# This ensures the backend configuration is respected by all modules
def _get_backend_from_args():
"""Parse --backend argument without importing diff_diff."""
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--backend", default="auto", choices=["auto", "python", "rust"])
args, _ = parser.parse_known_args()
return args.backend

# Resolve any explicit backend request and export it BEFORE diff_diff is
# imported; per the note above, the env var must be set ahead of the import
# for the backend configuration to be respected by all modules.
_requested_backend = _get_backend_from_args()
if _requested_backend in ("python", "rust"):
    os.environ["DIFF_DIFF_BACKEND"] = _requested_backend

# NOW import diff_diff and other dependencies (will see the env var)
import pandas as pd

# Add parent to path for imports (repo root is three levels up from this file)
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from diff_diff import MultiPeriodDiD, HAS_RUST_BACKEND
from benchmarks.python.utils import Timer


def parse_args():
    """Define and evaluate the benchmark's command-line interface."""
    cli = argparse.ArgumentParser(description="Benchmark MultiPeriodDiD estimator")
    cli.add_argument("--data", required=True, help="Path to input CSV data")
    cli.add_argument("--output", required=True, help="Path to output JSON results")
    cli.add_argument(
        "--cluster", default="unit", help="Column to cluster standard errors on"
    )
    cli.add_argument(
        "--n-pre", type=int, required=True, help="Number of pre-treatment periods"
    )
    cli.add_argument(
        "--n-post", type=int, required=True, help="Number of post-treatment periods"
    )
    cli.add_argument(
        "--reference-period",
        type=int,
        default=None,
        help="Reference period (default: last pre-period = n_pre)",
    )
    cli.add_argument(
        "--backend",
        default="auto",
        choices=["auto", "python", "rust"],
        help="Backend to use: auto (default), python (pure Python), rust (Rust backend)",
    )
    return cli.parse_args()


def get_actual_backend() -> str:
    """Report which backend diff_diff actually loaded, per HAS_RUST_BACKEND."""
    if HAS_RUST_BACKEND:
        return "rust"
    return "python"


def main():
    """Run the MultiPeriodDiD benchmark end-to-end.

    Loads the CSV named by ``--data``, fits the event-study regression with
    unit fixed effects absorbed, and writes a JSON results file to
    ``--output``. Returns the results dict that was written.
    """
    args = parse_args()

    # The backend was pinned via DIFF_DIFF_BACKEND before diff_diff was
    # imported; here we only report which one actually loaded.
    actual_backend = get_actual_backend()
    print(f"Using backend: {actual_backend}")

    print(f"Loading data from: {args.data}")
    df = pd.read_csv(args.data)

    # Periods strictly greater than n_pre are post-treatment; the reference
    # period defaults to the last pre-treatment period (== n_pre), matching
    # the R benchmark's convention.
    all_periods = sorted(df["time"].unique())
    n_pre = args.n_pre
    post_periods = [p for p in all_periods if p > n_pre]
    ref_period = n_pre if args.reference_period is None else args.reference_period

    print(f"All periods: {all_periods}")
    print(f"Post periods: {post_periods}")
    print(f"Reference period: {ref_period}")

    print("Running MultiPeriodDiD estimation...")
    estimator = MultiPeriodDiD(robust=True, cluster=args.cluster)

    with Timer() as timer:
        est = estimator.fit(
            df,
            outcome="outcome",
            treatment="treated",
            time="time",
            post_periods=post_periods,
            reference_period=ref_period,
            absorb=["unit"],
        )

    elapsed = timer.elapsed

    # One record per estimated period; the reference period carries no effect
    # and is absent from est.period_effects.
    period_effects = [
        {
            "period": int(period),
            "event_time": int(period - ref_period),
            "att": float(effect.effect),
            "se": float(effect.se),
        }
        for period, effect in sorted(est.period_effects.items())
    ]

    # Field names mirror the R benchmark output so results diff cleanly.
    output = {
        "estimator": "diff_diff.MultiPeriodDiD",
        "backend": actual_backend,
        "cluster": args.cluster,
        # Average treatment effect (across post-periods)
        "att": float(est.avg_att),
        "se": float(est.avg_se),
        "pvalue": float(est.avg_p_value),
        "ci_lower": float(est.avg_conf_int[0]),
        "ci_upper": float(est.avg_conf_int[1]),
        # Reference period
        "reference_period": int(ref_period),
        # Period-level effects
        "period_effects": period_effects,
        # Timing
        "timing": {
            "estimation_seconds": elapsed,
            "total_seconds": elapsed,
        },
        # Metadata
        "metadata": {
            "n_units": int(df["unit"].nunique()),
            "n_periods": int(df["time"].nunique()),
            "n_obs": len(df),
            "n_pre": n_pre,
            "n_post": len(post_periods),
        },
    }

    print(f"Writing results to: {args.output}")
    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(destination, "w") as fh:
        json.dump(output, fh, indent=2)

    print(f"ATT: {est.avg_att:.6f}")
    print(f"SE: {est.avg_se:.6f}")
    print(f"Completed in {elapsed:.3f} seconds")
    return output


if __name__ == "__main__":
main()
Loading