#!/usr/bin/env python3
"""
Enhanced benchmark to compare performance between if-elif chains and match-case pattern matching
for AST node inspection in pytest's assertion rewriter, with hot path analysis.
"""
import ast
import timeit
import statistics
import time
from typing import List, Tuple, Dict
import sys
import gc
import json
from datetime import datetime
# Sample Python module sources used as benchmark inputs. Each value is parsed
# with ast.parse() and the resulting tree is handed to both header-scanning
# implementations; the source text itself is never executed.
TEST_TEMPLATES = {
    "simple": '''
"""Simple module docstring."""
def test_function():
    assert True
''',
    "with_future": '''
"""Module with future imports."""
from __future__ import annotations
from __future__ import print_function
def test_function():
    assert True
''',
    "no_docstring": '''
from __future__ import annotations
import sys
def test_function():
    assert True
''',
    "complex": '''
"""Complex module with multiple imports and functions."""
from __future__ import annotations
from __future__ import division
from __future__ import print_function
import os
import sys
from typing import List, Dict
class TestClass:
    def test_method(self):
        assert True
def test_function():
    assert 1 + 1 == 2
''',
    # The docstring contains the opt-out marker, so both implementations
    # return early with found_disabled=True.
    "rewrite_disabled": '''
"""Module with PYTEST_DONT_REWRITE marker."""
from __future__ import annotations
def test_function():
    assert True
''',
    "many_imports": '''
"""Module with many future imports."""
from __future__ import annotations
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from __future__ import generators
from __future__ import nested_scopes
import os
import sys
def test_function():
    assert True
'''
    # NOTE(review): these ten extra lines land *after* `def test_function`.
    # Both implementations stop at the first statement that is neither a
    # docstring nor a __future__ import (`import os` here), so the lines never
    # affect the result. Python requires __future__ imports at the top of a
    # module, so compiling this source would fail; ast.parse() does not
    # enforce that rule, which is why the template still parses.
    + '\n'.join(['from __future__ import annotations'] * 10),
}
def original_implementation(mod: ast.Module) -> Tuple[int, bool, bool]:
    """Original if-elif implementation from pytest.

    Walks the leading statements of ``mod``, consuming at most one string
    expression statement (treated as the module docstring) and any absolute
    ``from __future__ import ...`` statements, and stops at the first other
    statement. This is the baseline being benchmarked, so its statement
    structure is intentionally left exactly as written.

    Args:
        mod: Parsed module whose body is scanned.

    Returns:
        ``(pos, found_disabled, expect_docstring)`` where ``pos`` is the number
        of leading statements consumed (when the docstring disables rewriting,
        the index of that docstring), ``found_disabled`` is True if the
        docstring contains ``"PYTEST_DONT_REWRITE"``, and ``expect_docstring``
        is True while no docstring has been consumed. An empty module body
        returns ``(0, False, False)``.
    """
    # Empty module: nothing to scan.
    if not mod.body:
        return 0, False, False
    # NOTE(review): ast.Module on modern Python has no `docstring` attribute,
    # so this presumably always yields None (kept to mirror pytest's code).
    doc = getattr(mod, "docstring", None)
    expect_docstring = doc is None
    found_disabled = False
    pos = 0
    for item in mod.body:
        # The first string-constant expression statement counts as the docstring.
        if (
            expect_docstring
            and isinstance(item, ast.Expr)
            and isinstance(item.value, ast.Constant)
            and isinstance(item.value.value, str)
        ):
            doc = item.value.value
            if "PYTEST_DONT_REWRITE" in doc:
                found_disabled = True
                # Return immediately; pos is not advanced past the docstring.
                return pos, found_disabled, expect_docstring
            expect_docstring = False
        elif (
            isinstance(item, ast.ImportFrom)
            and item.level == 0
            and item.module == "__future__"
        ):
            # Absolute `from __future__ import ...`: step over it.
            pass
        else:
            # Any other statement ends the header scan.
            break
        pos += 1
    return pos, found_disabled, expect_docstring
def match_case_implementation(mod: ast.Module) -> Tuple[int, bool, bool]:
    """New match-case implementation.

    Structural-pattern-matching rewrite of the pytest header scan: consumes at
    most one string expression statement (treated as the module docstring) and
    any absolute ``from __future__ import ...`` statements, stopping at the
    first other statement. It is the subject being benchmarked, so its code is
    intentionally left exactly as written.

    Args:
        mod: Parsed module whose body is scanned.

    Returns:
        ``(pos, found_disabled, expect_docstring)`` where ``pos`` is the number
        of leading statements consumed (when the docstring disables rewriting,
        the index of that docstring), ``found_disabled`` is True if the
        docstring contains ``"PYTEST_DONT_REWRITE"``, and ``expect_docstring``
        is True while no docstring has been consumed. An empty module body
        returns ``(0, False, False)``.
    """
    # Empty module: nothing to scan.
    if not mod.body:
        return 0, False, False
    doc = getattr(mod, "docstring", None)
    expect_docstring = doc is None
    found_disabled = False
    pos = 0
    for item in mod.body:
        match item:
            # Class pattern: an Expr whose value is a Constant holding a str.
            # `str(doc)` binds that string to `doc`; the binding happens
            # before the `if expect_docstring` guard is evaluated.
            case ast.Expr(value=ast.Constant(value=str(doc))) if expect_docstring:
                if "PYTEST_DONT_REWRITE" in doc:
                    found_disabled = True
                    # Return immediately; pos is not advanced past the docstring.
                    return pos, found_disabled, expect_docstring
                expect_docstring = False
            # Absolute `from __future__ import ...` (level 0): step over it.
            case ast.ImportFrom(level=0, module="__future__"):
                pass
            # Any other statement ends the header scan (breaks the for loop).
            case _:
                break
        pos += 1
    return pos, found_disabled, expect_docstring
def warmup_function(func, iterations=1000):
    """Call ``func()`` ``iterations`` times, discarding every result.

    Run before a timed loop so one-off costs of the first calls are not part
    of the measurement. Non-positive counts make no calls; the function
    always returns None.
    """
    calls_left = len(range(iterations))
    while calls_left:
        func()
        calls_left -= 1
def measure_with_warmup(func, iterations=10000, warmup_iterations=5000, disable_gc=True):
    """Time ``iterations`` calls of ``func`` after an untimed warm-up phase.

    Args:
        func: Zero-argument callable to benchmark.
        iterations: Number of timed calls.
        warmup_iterations: Number of untimed calls made first through
            ``warmup_function``.
        disable_gc: When True (the default, mirroring ``timeit``), the cyclic
            garbage collector is switched off for the timed loop so a
            collection pass cannot land inside one measurement but not
            another. The collector's previous state is restored afterwards,
            even if ``func`` raises.

    Returns:
        Elapsed seconds for the timed calls, measured with
        ``time.perf_counter()``.
    """
    # Warmup phase (not timed).
    warmup_function(func, warmup_iterations)

    gc_was_enabled = gc.isenabled()
    if disable_gc:
        gc.disable()
    try:
        start = time.perf_counter()
        for _ in range(iterations):
            func()
        end = time.perf_counter()
    finally:
        # Only re-enable if we were the ones who turned it off.
        if disable_gc and gc_was_enabled:
            gc.enable()
    return end - start
def run_hot_path_analysis(code_template: str, runs: int = 10, iterations_per_run: int = 10000):
    """Time both implementations on one template over several runs.

    The template is parsed once and the same AST is reused for every call, so
    only the header-scanning logic is timed, not parsing. The first run uses
    a longer warm-up (1000 calls) than later runs (100 calls).

    The order in which the two implementations are measured alternates from
    run to run. Measuring one of them always first would add a systematic
    bias, e.g. from CPU frequency ramp-up or cache state left by the previous
    measurement.

    Args:
        code_template: Python source to parse and scan.
        runs: Number of timed runs per implementation.
        iterations_per_run: Calls per timed run.

    Returns:
        ``(original_times, match_times)``: two lists of ``runs`` elapsed times
        in seconds, index-aligned by run.
    """
    tree = ast.parse(code_template)

    def call_original():
        return original_implementation(tree)

    def call_match():
        return match_case_implementation(tree)

    original_times = []
    match_times = []
    for run in range(runs):
        # Force garbage collection between runs for consistency.
        gc.collect()
        warmup = 1000 if run == 0 else 100
        if run % 2 == 0:
            orig_time = measure_with_warmup(call_original, iterations_per_run, warmup_iterations=warmup)
            match_time = measure_with_warmup(call_match, iterations_per_run, warmup_iterations=warmup)
        else:
            match_time = measure_with_warmup(call_match, iterations_per_run, warmup_iterations=warmup)
            orig_time = measure_with_warmup(call_original, iterations_per_run, warmup_iterations=warmup)
        original_times.append(orig_time)
        match_times.append(match_time)
    return original_times, match_times
def analyze_convergence(times: List[float], name: str) -> Dict:
    """Summarise a series of per-run timings.

    Returns a dict with the series ``name``; the first (``initial``) and last
    (``final``) timing; ``mean``, ``stdev``, ``min`` and ``max``; the
    ``convergence_rate`` (percent improvement from the first run to the best
    run); and ``stabilized_at_run`` (the 1-based index of the first run within
    5% of the best time). Statistics that need data fall back to 0 for an
    empty series; stdev, convergence and stabilization stay 0 with fewer than
    two runs.
    """
    has_data = bool(times)
    fastest = min(times) if has_data else 0
    summary = {
        'name': name,
        'initial': times[0] if has_data else 0,
        'final': times[-1] if has_data else 0,
        'mean': statistics.mean(times) if has_data else 0,
        'stdev': statistics.stdev(times) if len(times) > 1 else 0,
        'min': fastest,
        'max': max(times) if has_data else 0,
        'convergence_rate': 0,
        'stabilized_at_run': 0,
    }
    if len(times) < 2:
        return summary

    first = times[0]
    if first > 0:
        # Percent improvement from the first run to the best run.
        summary['convergence_rate'] = (first - fastest) / first * 100

    # First run (1-based) whose time is within 5% of the best time.
    cutoff = fastest * 1.05
    summary['stabilized_at_run'] = next(
        (run_no for run_no, elapsed in enumerate(times, start=1) if elapsed <= cutoff),
        0,
    )
    return summary
def plot_ascii_graph(original_times: List[float], match_times: List[float], width: int = 60):
    """Print an ASCII bar chart comparing per-run times of both implementations.

    Bars are scaled relative to the fastest plotted time: the fastest run gets
    an empty bar and the slowest a bar ``width`` characters long, so bar
    length shows how much slower a run was than the best one.

    Runs are paired positionally with ``zip``. If the two lists differ in
    length, only the common prefix is plotted, instead of raising IndexError.
    Nothing is printed when either list is empty.
    """
    if not original_times or not match_times:
        return

    pairs = list(zip(original_times, match_times))
    plotted = [t for pair in pairs for t in pair]
    lowest, highest = min(plotted), max(plotted)
    # A zero span (all times equal) would divide by zero; every bar is empty then.
    scale = width / (highest - lowest) if highest > lowest else 1

    print("\n Performance Over Runs (shorter is better)")
    print(" " + "─" * (width + 2))
    for run_no, (orig_t, match_t) in enumerate(pairs, start=1):
        orig_bar = '█' * int((orig_t - lowest) * scale)
        match_bar = '▓' * int((match_t - lowest) * scale)
        print(f" Run {run_no:2d}:")
        # Labels are padded to the same width so the axis lines up.
        print(f" Orig  │{orig_bar} {orig_t:.4f}s")
        print(f" Match │{match_bar} {match_t:.4f}s")
        if run_no < len(pairs):
            print("       │")
def verify_correctness():
    """Check that both implementations agree on every entry in TEST_TEMPLATES.

    Prints one status line per template, plus both result tuples when they
    differ. Every template is checked even after a mismatch. Returns True
    only if all templates produced identical results.
    """
    print("Verifying correctness...")
    mismatch_count = 0
    for label, source in TEST_TEMPLATES.items():
        module = ast.parse(source)
        baseline = original_implementation(module)
        candidate = match_case_implementation(module)
        if baseline == candidate:
            print(f" ✓ {label}: Results match {baseline}")
            continue
        mismatch_count += 1
        print(f" ❌ {label}: Results differ!")
        print(f" Original: {baseline}")
        print(f" Match: {candidate}")
    return mismatch_count == 0
def _pct_diff(value: float, baseline: float) -> float:
    """Return how much ``value`` differs from ``baseline``, in percent.

    Negative means ``value`` is smaller (faster). Returns 0.0 for a zero
    baseline so a degenerate timing cannot raise ZeroDivisionError.
    """
    if baseline == 0:
        return 0.0
    return (value - baseline) / baseline * 100


def _safe_stdev(samples: List[float]) -> float:
    """Return the sample standard deviation, or 0.0 for fewer than two samples."""
    return statistics.stdev(samples) if len(samples) > 1 else 0.0


def main():
    """Main benchmark runner with hot path analysis.

    Checks that both implementations agree on every template, then times them
    repeatedly. Prints per-launch run tables, convergence figures, an
    aggregate summary and a recommendation. Returns early, without timing
    anything, if the correctness check fails.
    """
    print("=" * 80)
    print("AST Pattern Matching Hot Path Performance Analysis")
    print(f"Python version: {sys.version}")
    print(f"Timestamp: {datetime.now().isoformat()}")
    print("=" * 80)

    # NOTE: this module itself contains `match` statements, so on Python < 3.10
    # it fails to compile before main() can run. This check therefore cannot
    # trigger; it only documents the requirement.
    if sys.version_info < (3, 10):
        print("⚠️ Warning: This benchmark requires Python 3.10+ for match-case syntax")
        return

    # Configuration. A "launch" is one pass of the outer loop below; all
    # launches run inside this same interpreter process (no new process is
    # started), so later launches see already-warm caches.
    num_launches = 5
    runs_per_launch = 10  # Timed runs per test case within each launch
    iterations_per_run = 50000  # Calls per timed run

    print("\nConfiguration:")
    print(f" • Launches: {num_launches}")
    print(f" • Runs per launch: {runs_per_launch}")
    print(f" • Iterations per run: {iterations_per_run:,}")
    print(f" • Total iterations: {num_launches * runs_per_launch * iterations_per_run:,}")

    # Timing implementations that disagree would be meaningless.
    if not verify_correctness():
        print("\n⚠️ Correctness check failed! Results may not be comparable.")
        return

    all_results = []
    for launch in range(num_launches):
        print("\n" + "=" * 80)
        print(f"Launch {launch + 1}/{num_launches}")
        print("=" * 80)
        launch_results = {}
        for test_name, code in TEST_TEMPLATES.items():
            print(f"\n📊 Test case: {test_name}")
            print(f" Code length: {len(code)} chars")

            original_times, match_times = run_hot_path_analysis(
                code, runs_per_launch, iterations_per_run
            )
            orig_analysis = analyze_convergence(original_times, "Original")
            match_analysis = analyze_convergence(match_times, "Match-case")
            launch_results[test_name] = {
                'original_times': original_times,
                'match_times': match_times,
                'orig_analysis': orig_analysis,
                'match_analysis': match_analysis,
            }

            # Hoisted out of the table loop: each implementation's best run.
            best_orig = min(original_times)
            best_match = min(match_times)

            print(f"\n Run-by-run timing (seconds for {iterations_per_run:,} iterations):")
            print(f" {'Run':<6} {'Original':<12} {'Match-case':<12} {'Diff %':<10} {'Winner'}")
            print(f" {'-'*6} {'-'*12} {'-'*12} {'-'*10} {'-'*10}")
            for run_no, (orig_t, match_t) in enumerate(zip(original_times, match_times), start=1):
                diff_pct = _pct_diff(match_t, orig_t)
                if match_t < orig_t:
                    winner = "Match 🚀"
                elif match_t > orig_t:
                    winner = "Original 🐢"
                else:
                    winner = "Tie"
                # Both markers are two characters wide so the columns stay aligned.
                orig_marker = " *" if orig_t == best_orig else "  "
                match_marker = " *" if match_t == best_match else "  "
                print(f" {run_no:<6} {orig_t:<10.4f}{orig_marker} {match_t:<10.4f}{match_marker} "
                      f"{diff_pct:>+9.2f}% {winner}")

            print("\n Convergence Analysis:")
            print(f" • Original: Initial={orig_analysis['initial']:.4f}s, "
                  f"Best={orig_analysis['min']:.4f}s, "
                  f"Convergence={orig_analysis['convergence_rate']:.1f}%, "
                  f"Stable at run #{orig_analysis['stabilized_at_run']}")
            print(f" • Match: Initial={match_analysis['initial']:.4f}s, "
                  f"Best={match_analysis['min']:.4f}s, "
                  f"Convergence={match_analysis['convergence_rate']:.1f}%, "
                  f"Stable at run #{match_analysis['stabilized_at_run']}")

            # best_diff is match-case's time relative to the original's, so a
            # positive value means match-case was that many percent slower.
            best_diff = _pct_diff(best_match, best_orig)
            if best_diff < 0:
                print(f"\n 🏆 Best Performance: Match-case is {abs(best_diff):.2f}% faster")
            elif best_diff > 0:
                print(f"\n 🏆 Best Performance: Match-case is {best_diff:.2f}% slower")
            else:
                print("\n 🏆 Best Performance: Tie")

            # Only plot for the first launch to save space.
            if launch == 0:
                plot_ascii_graph(original_times, match_times, width=40)
        all_results.append(launch_results)

    # Overall summary across all launches.
    print("\n" + "=" * 80)
    print("OVERALL SUMMARY ACROSS ALL LAUNCHES")
    print("=" * 80)
    aggregate_results = {}
    for test_name in TEST_TEMPLATES:
        all_orig_times = []
        all_match_times = []
        for launch_result in all_results:
            all_orig_times.extend(launch_result[test_name]['original_times'])
            all_match_times.extend(launch_result[test_name]['match_times'])
        aggregate_results[test_name] = {
            'orig_mean': statistics.mean(all_orig_times),
            'orig_stdev': _safe_stdev(all_orig_times),
            'orig_min': min(all_orig_times),
            'orig_max': max(all_orig_times),
            'match_mean': statistics.mean(all_match_times),
            'match_stdev': _safe_stdev(all_match_times),
            'match_min': min(all_match_times),
            'match_max': max(all_match_times),
        }

    print("\n📈 Performance Statistics (all times in seconds):")
    print(f"{'Test Case':<20} {'Implementation':<15} {'Mean ± StdDev':<20} {'Min':<10} {'Max':<10}")
    print("-" * 75)
    for test_name, stats in aggregate_results.items():
        print(f"{test_name:<20} {'Original':<15} "
              f"{stats['orig_mean']:.4f} ± {stats['orig_stdev']:.4f} "
              f"{stats['orig_min']:.4f} {stats['orig_max']:.4f}")
        print(f"{'':<20} {'Match-case':<15} "
              f"{stats['match_mean']:.4f} ± {stats['match_stdev']:.4f} "
              f"{stats['match_min']:.4f} {stats['match_max']:.4f}")
        diff_pct = _pct_diff(stats['match_mean'], stats['orig_mean'])
        if diff_pct < 0:
            verdict = f"✅ Match {abs(diff_pct):.1f}% faster"
        else:
            verdict = f"⚠️ Match {diff_pct:.1f}% slower"
        print(f"{'':<20} {'Verdict:':<15} {verdict}")
        print()

    # Hot path insights.
    print("\n" + "=" * 80)
    print("🔥 HOT PATH INSIGHTS")
    print("=" * 80)
    print("\nKey Observations:")

    # Average convergence rates over every (launch, test case) pair.
    avg_orig_convergence = []
    avg_match_convergence = []
    for launch_result in all_results:
        for data in launch_result.values():
            avg_orig_convergence.append(data['orig_analysis']['convergence_rate'])
            avg_match_convergence.append(data['match_analysis']['convergence_rate'])
    print("\n1. Warm-up Effect:")
    print(f" • Original avg convergence: {statistics.mean(avg_orig_convergence):.2f}%")
    print(f" • Match-case avg convergence: {statistics.mean(avg_match_convergence):.2f}%")

    print("\n2. Stability:")
    stability_comparison = [
        (test_name, stats['orig_stdev'], stats['match_stdev'])
        for test_name, stats in aggregate_results.items()
    ]
    # Largest gap in run-to-run spread first.
    stability_comparison.sort(key=lambda x: abs(x[1] - x[2]), reverse=True)
    print(" Most variable performance differences:")
    for name, orig_std, match_std in stability_comparison[:3]:
        more_stable = "Original" if orig_std < match_std else "Match-case"
        print(f" • {name}: {more_stable} is more stable "
              f"(σ_orig={orig_std:.4f}, σ_match={match_std:.4f})")

    # Final recommendation, based on the sum of per-test-case mean times.
    print("\n" + "=" * 80)
    print("📋 RECOMMENDATIONS")
    print("=" * 80)
    total_orig = sum(stats['orig_mean'] for stats in aggregate_results.values())
    total_match = sum(stats['match_mean'] for stats in aggregate_results.values())
    overall_diff = _pct_diff(total_match, total_orig)
    print("\nOverall Performance: ", end="")
    if overall_diff < -5:
        print(f"Match-case is {abs(overall_diff):.1f}% faster ✅")
        print("\n✅ ADOPT: Match-case shows consistent performance improvement")
        print(" • Better hot path performance")
        print(" • More readable and maintainable code")
        print(" • Future Python versions likely to optimize further")
    elif overall_diff < 5:
        # Name the subject so the reader knows which side is faster.
        print(f"Match-case is {abs(overall_diff):.1f}% {'faster' if overall_diff < 0 else 'slower'} 🤔")
        print("\n🤔 NEUTRAL: Performance difference is negligible")
        print(" • Decision should be based on code style preferences")
        print(" • Match-case offers better readability")
        print(" • Consider team familiarity with pattern matching")
    else:
        print(f"Match-case is {overall_diff:.1f}% slower ⚠️")
        print("\n⚠️ CAUTION: Match-case shows performance regression")
        print(" • Consider keeping original implementation")
        print(" • May improve in future Python versions")
        print(" • Test with your specific Python version and workload")
    print(f"\nTested on Python {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}")
    print("Note: Results vary by Python version, CPU, and system load")
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()