#!/usr/bin/env python3
"""
Advanced test runner with automatic sharding and parallel execution for yaze.
Optimizes test execution time by distributing tests across multiple processes.
"""
import multiprocessing
import json
import subprocess
import time
import argparse
import sys
import os
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass, asdict
from concurrent.futures import ProcessPoolExecutor, as_completed


@dataclass
class TestResult:
    """Container for test execution results."""
    name: str
    status: str  # passed, failed, skipped
    duration: float
    output: str
    shard_id: int


@dataclass
class ShardResult:
    """Results from a single test shard."""
    shard_id: int
    return_code: int
    tests_run: int
    tests_passed: int
    tests_failed: int
    duration: float
    test_results: List[TestResult]


class TestRunner:
    """Advanced test runner with sharding and parallel execution."""

    def __init__(self, test_binary: str, num_shards: Optional[int] = None,
                 cache_dir: Optional[str] = None, verbose: bool = False):
        self.test_binary = Path(test_binary).resolve()
        if not self.test_binary.exists():
            raise FileNotFoundError(f"Test binary not found: {test_binary}")
        self.num_shards = num_shards or min(multiprocessing.cpu_count(), 8)
        self.cache_dir = Path(cache_dir or Path.home() / ".yaze_test_cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.verbose = verbose
        self.test_times = self.load_test_times()

    def load_test_times(self) -> Dict[str, float]:
        """Load historical test execution times from cache."""
        cache_file = self.cache_dir / "test_times.json"
        if cache_file.exists():
            try:
                return json.loads(cache_file.read_text())
            except (json.JSONDecodeError, IOError):
                return {}
        return {}

    def save_test_times(self, test_times: Dict[str, float]):
        """Save test execution times to cache."""
        cache_file = self.cache_dir / "test_times.json"
        # Merge with existing times; the loop variable is named `duration`
        # rather than `time` to avoid shadowing the `time` module.
        existing = self.load_test_times()
        for test, duration in test_times.items():
            # Use exponential moving average for smoothing
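            # e.g., a cached 2.0s with a new 1.0s run -> 0.7*2.0 + 0.3*1.0 = 1.7s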
            if test in existing:
                existing[test] = 0.7 * existing[test] + 0.3 * duration
            else:
                existing[test] = duration
        cache_file.write_text(json.dumps(existing, indent=2))

    def discover_tests(self, filter_pattern: Optional[str] = None) -> List[str]:
        """Discover all tests in the binary."""
        cmd = [str(self.test_binary), "--gtest_list_tests"]
        if filter_pattern:
            cmd.append(f"--gtest_filter={filter_pattern}")
        try:
            result = subprocess.run(cmd, capture_output=True, text=True,
                                    timeout=30, check=False)
        except subprocess.TimeoutExpired:
            print("Warning: Test discovery timed out", file=sys.stderr)
            return []
        if result.returncode != 0:
            print(f"Warning: Test discovery failed: {result.stderr}", file=sys.stderr)
            return []
        # Parse gtest output
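        # --gtest_list_tests output looks like (suites end with a dot, tests
        # are indented, and trailing "#" comments may appear on either):
        #   SuiteName.
        #     TestOne
        #     TestTwo  # GetParam() = 4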
        tests = []
        current_suite = ""
        for line in result.stdout.splitlines():
            line = line.rstrip()
            if not line or line.startswith("Running main()"):
                continue
            if not line.startswith(" "):
                # Test suite name (typed suites carry a "# TypeParam = ..." comment)
                current_suite = line.split("#")[0].strip().rstrip(".")
            elif line.strip():
                # Test case name
                test_name = line.strip()
                # Remove comments (e.g., "  TestName  # GetParam() = ...")
                if "#" in test_name:
                    test_name = test_name.split("#")[0].strip()
                if test_name:
                    tests.append(f"{current_suite}.{test_name}")
        if self.verbose:
            print(f"Discovered {len(tests)} tests")
        return tests

    def create_balanced_shards(self, tests: List[str]) -> List[List[str]]:
        """Create balanced shards based on historical execution times."""
        if not tests:
            return []
        # Sort tests by execution time (longest first), using historical
        # times where available and a default estimate otherwise.
        default_time = 0.1  # 100ms default per test
        sorted_tests = sorted(
            tests,
            key=lambda t: self.test_times.get(t, default_time),
            reverse=True
        )
        # Initialize shards
        num_shards = min(self.num_shards, len(tests))
        shards = [[] for _ in range(num_shards)]
        shard_times = [0.0] * num_shards
        # Distribute tests using greedy bin packing (longest-processing-time first)
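        # e.g., estimated times [5, 4, 3, 2] across 2 shards yield
        #   shard 0: [5, 2] (7s) and shard 1: [4, 3] (7s)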
        for test in sorted_tests:
            # Find shard with minimum total time
            min_shard_idx = shard_times.index(min(shard_times))
            shards[min_shard_idx].append(test)
            shard_times[min_shard_idx] += self.test_times.get(test, default_time)
        # Drop empty shards (possible when cached times are 0.0), keeping the
        # estimated times aligned with the surviving shards.
        paired = [(s, t) for s, t in zip(shards, shard_times) if s]
        shards = [s for s, _ in paired]
        shard_times = [t for _, t in paired]
        if self.verbose:
            print(f"Created {len(shards)} shards:")
            for i, shard in enumerate(shards):
                print(f"  Shard {i}: {len(shard)} tests, "
                      f"estimated {shard_times[i]:.2f}s")
        return shards

    def run_shard(self, shard_id: int, tests: List[str],
                  output_dir: Optional[Path] = None) -> ShardResult:
        """Run a single shard of tests."""
        if not tests:
            return ShardResult(shard_id, 0, 0, 0, 0, 0.0, [])
        filter_str = ":".join(tests)
        output_dir = output_dir or self.cache_dir / "results"
        output_dir.mkdir(parents=True, exist_ok=True)
        # Prepare the command; results are collected via gtest's JSON report.
        json_output = output_dir / f"shard_{shard_id}_results.json"
        cmd = [
            str(self.test_binary),
            f"--gtest_filter={filter_str}",
            f"--gtest_output=json:{json_output}",
            "--gtest_brief=1"
        ]
        # Run tests
        start_time = time.time()
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=600  # 10 minute timeout per shard
            )
            duration = time.time() - start_time
        except subprocess.TimeoutExpired:
            print(f"Shard {shard_id} timed out!", file=sys.stderr)
            return ShardResult(shard_id, -1, len(tests), 0, len(tests),
                               600.0, [])
        # Parse results
        test_results = []
        tests_run = 0
        tests_passed = 0
        tests_failed = 0
        if json_output.exists():
            try:
                with open(json_output) as f:
                    data = json.load(f)
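                # Abridged shape of gtest's JSON report:
                #   {"testsuites": [{"name": "Suite", "testsuite": [
                #       {"name": "Test", "result": "COMPLETED",
                #        "time": "0.002s", "failures": [...]}]}]}
                # A "failures" list appears only for failing tests.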
                for suite in data.get("testsuites", []):
                    for testcase in suite.get("testsuite", []):
                        test_name = f"{suite['name']}.{testcase['name']}"
                        # "result" stays COMPLETED even when a test fails; the
                        # presence of a "failures" list is what signals failure.
                        if testcase.get("result") == "SKIPPED":
                            status = "skipped"
                        elif testcase.get("failures"):
                            status = "failed"
                        else:
                            status = "passed"
                        test_duration = float(str(testcase.get("time", "0")).rstrip("s"))
                        test_results.append(TestResult(
                            name=test_name,
                            status=status,
                            duration=test_duration,
                            output=testcase.get("output", ""),
                            shard_id=shard_id
                        ))
                        tests_run += 1
                        if status == "passed":
                            tests_passed += 1
                        elif status == "failed":
                            tests_failed += 1
            except (json.JSONDecodeError, KeyError, IOError) as e:
                print(f"Warning: Failed to parse results for shard {shard_id}: {e}",
                      file=sys.stderr)
        return ShardResult(
            shard_id=shard_id,
            return_code=result.returncode,
            tests_run=tests_run,
            tests_passed=tests_passed,
            tests_failed=tests_failed,
            duration=duration,
            test_results=test_results
        )

    def run_parallel(self, filter_pattern: Optional[str] = None,
                     output_dir: Optional[str] = None) -> Tuple[int, Dict]:
        """Run tests in parallel shards."""
        # Discover tests
        tests = self.discover_tests(filter_pattern)
        if not tests:
            print("No tests found to run")
            return 0, {}
        print(f"Running {len(tests)} tests in up to {self.num_shards} shards...")
        # Create shards
        shards = self.create_balanced_shards(tests)
        output_path = Path(output_dir) if output_dir else self.cache_dir / "results"
        # Create the results directory up front so the summary write below
        # cannot fail even if every shard crashes before making it.
        output_path.mkdir(parents=True, exist_ok=True)
        # Run shards in parallel
        all_results = []
        start_time = time.time()
        with ProcessPoolExecutor(max_workers=len(shards)) as executor:
            # Submit all shards
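            # Submitting a bound method pickles the TestRunner instance into
            # each worker process; all of its attributes (Path, int, bool,
            # dict) are picklable, so this is safe.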
            futures = {
                executor.submit(self.run_shard, i, shard, output_path): i
                for i, shard in enumerate(shards)
            }
            # Collect results
            for future in as_completed(futures):
                shard_id = futures[future]
                try:
                    result = future.result()
                    all_results.append(result)
                    if self.verbose:
                        print(f"Shard {shard_id} completed: "
                              f"{result.tests_passed}/{result.tests_run} passed "
                              f"in {result.duration:.2f}s")
                except Exception as e:
                    print(f"Shard {shard_id} failed with exception: {e}",
                          file=sys.stderr)
        total_duration = time.time() - start_time
        # Aggregate results
        total_tests = sum(r.tests_run for r in all_results)
        total_passed = sum(r.tests_passed for r in all_results)
        total_failed = sum(r.tests_failed for r in all_results)
        # A timed-out shard reports -1, which max() would mask behind a passing
        # shard's 0, so surface the first nonzero return code instead.
        overall_return_code = next(
            (r.return_code for r in all_results if r.return_code != 0), 0)
        # Update test times cache
        new_times = {}
        for result in all_results:
            for test_result in result.test_results:
                new_times[test_result.name] = test_result.duration
        self.save_test_times(new_times)
        # Generate summary
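        # parallel_efficiency compares aggregate shard runtime to wall-clock
        # time x shard count: ~100% means every shard stayed busy, lower
        # values mean shards sat idle waiting on the slowest one.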
        summary = {
            "total_tests": total_tests,
            "passed": total_passed,
            "failed": total_failed,
            "duration": total_duration,
            "num_shards": len(shards),
            "parallel_efficiency": (sum(r.duration for r in all_results) /
                                    (total_duration * len(shards)) * 100)
                                   if len(shards) > 0 else 0,
            "shards": [asdict(r) for r in all_results]
        }
        # Save summary
        summary_file = output_path / "summary.json"
        summary_file.write_text(json.dumps(summary, indent=2))
        # Print results
        print(f"\n{'=' * 60}")
        print("Test Execution Summary")
        print(f"{'=' * 60}")
        print(f"Total Tests: {total_tests}")
        pass_rate = (total_passed / total_tests * 100) if total_tests else 0.0
        print(f"Passed: {total_passed} ({pass_rate:.1f}%)")
        print(f"Failed: {total_failed}")
        print(f"Duration: {total_duration:.2f}s")
        print(f"Shards Used: {len(shards)}")
        print(f"Efficiency: {summary['parallel_efficiency']:.1f}%")
        if total_failed > 0:
            print("\nFailed Tests:")
            for result in all_results:
                for test_result in result.test_results:
                    if test_result.status == "failed":
                        print(f"  - {test_result.name}")
        return overall_return_code, summary

    def run_with_retry(self, filter_pattern: Optional[str] = None,
                       max_retries: int = 2) -> int:
        """Run tests with automatic retry for flaky tests."""
        failed_tests = set()
        attempt = 0
        while attempt <= max_retries:
            if attempt > 0:
                # Only retry failed tests
                if not failed_tests:
                    break
                filter_pattern = ":".join(failed_tests)
                print(f"\nRetry attempt {attempt} for {len(failed_tests)} failed tests")
            return_code, summary = self.run_parallel(filter_pattern)
            if return_code == 0:
                if attempt > 0:
                    print(f"All tests passed after {attempt} retries")
                return 0
            # Collect failed tests for retry
            failed_tests.clear()
            for shard in summary.get("shards", []):
                for test_result in shard.get("test_results", []):
                    if test_result.get("status") == "failed":
                        failed_tests.add(test_result.get("name"))
            attempt += 1
        print(f"Tests still failing after {max_retries} retries")
        return return_code


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(
        description="Advanced test runner with parallel execution"
    )
    parser.add_argument(
        "test_binary",
        help="Path to the test binary"
    )
    parser.add_argument(
        "--shards",
        type=int,
        help="Number of parallel shards (default: CPU count, capped at 8)"
    )
    parser.add_argument(
        "--filter",
        help="Test filter pattern (gtest format)"
    )
    parser.add_argument(
        "--output-dir",
        help="Directory for test results"
    )
    parser.add_argument(
        "--cache-dir",
        help="Directory for test cache (default: ~/.yaze_test_cache)"
    )
    parser.add_argument(
        "--retry",
        type=int,
        default=0,
        help="Number of retries for failed tests"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose output"
    )
    args = parser.parse_args()
    try:
        runner = TestRunner(
            test_binary=args.test_binary,
            num_shards=args.shards,
            cache_dir=args.cache_dir,
            verbose=args.verbose
        )
        if args.retry > 0:
            return_code = runner.run_with_retry(
                filter_pattern=args.filter,
                max_retries=args.retry
            )
        else:
            return_code, _ = runner.run_parallel(
                filter_pattern=args.filter,
                output_dir=args.output_dir
            )
        sys.exit(return_code)
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()