#!/usr/bin/env python3
"""
Advanced test runner with automatic sharding and parallel execution for yaze.
Optimizes test execution time by distributing tests across multiple processes.
"""
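
# Example invocation (the script name, binary path, and filter below are
# illustrative, not fixed by this repository):
#   python3 run_tests.py build/bin/yaze_test --shards 8 --filter "RomTest.*" --retry 1 -v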

import argparse
import json
import multiprocessing
import subprocess
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Tuple


@dataclass
class TestResult:
    """Container for test execution results."""
    name: str
    status: str  # passed, failed, skipped
    duration: float
    output: str
    shard_id: int


@dataclass
class ShardResult:
    """Results from a single test shard."""
    shard_id: int
    return_code: int
    tests_run: int
    tests_passed: int
    tests_failed: int
    duration: float
    test_results: List[TestResult]


class TestRunner:
    """Advanced test runner with sharding and parallel execution."""

    def __init__(self, test_binary: str, num_shards: int = None,
                 cache_dir: str = None, verbose: bool = False):
        self.test_binary = Path(test_binary).resolve()
        if not self.test_binary.exists():
            raise FileNotFoundError(f"Test binary not found: {test_binary}")

        self.num_shards = num_shards or min(multiprocessing.cpu_count(), 8)
        self.cache_dir = Path(cache_dir or Path.home() / ".yaze_test_cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.verbose = verbose
        self.test_times = self.load_test_times()

    def load_test_times(self) -> Dict[str, float]:
        """Load historical test execution times from cache."""
        cache_file = self.cache_dir / "test_times.json"
        if cache_file.exists():
            try:
                return json.loads(cache_file.read_text())
            except (json.JSONDecodeError, IOError):
                return {}
        return {}

    def save_test_times(self, test_times: Dict[str, float]):
        """Save test execution times to cache."""
        cache_file = self.cache_dir / "test_times.json"

        # Merge with existing times
        existing = self.load_test_times()
        for test, duration in test_times.items():
            # Use an exponential moving average for smoothing
            if test in existing:
                existing[test] = 0.7 * existing[test] + 0.3 * duration
            else:
                existing[test] = duration

        cache_file.write_text(json.dumps(existing, indent=2))
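
    # Note on the smoothing above: a test cached at 1.00s that newly measures
    # 0.50s is stored as 0.7 * 1.00 + 0.3 * 0.50 = 0.85s, so a single fast or
    # slow outlier run only partially shifts the shard-balancing estimate.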

    def discover_tests(self, filter_pattern: str = None) -> List[str]:
        """Discover all tests in the binary."""
        cmd = [str(self.test_binary), "--gtest_list_tests"]
        if filter_pattern:
            cmd.append(f"--gtest_filter={filter_pattern}")

        try:
            result = subprocess.run(cmd, capture_output=True, text=True,
                                    timeout=30, check=False)
        except subprocess.TimeoutExpired:
            print("Warning: Test discovery timed out", file=sys.stderr)
            return []

        if result.returncode != 0:
            print(f"Warning: Test discovery failed: {result.stderr}", file=sys.stderr)
            return []

        # Parse gtest output
        tests = []
        current_suite = ""
        for line in result.stdout.splitlines():
            line = line.rstrip()
            if not line or line.startswith("Running main()"):
                continue

            if not line.startswith(" "):
                # Test suite name
                current_suite = line.rstrip(".")
            elif line.strip():
                # Test case name
                test_name = line.strip()
                # Remove trailing comments (e.g., "  TestName  # Comment")
                if "#" in test_name:
                    test_name = test_name.split("#")[0].strip()
                if test_name:
                    tests.append(f"{current_suite}.{test_name}")

        if self.verbose:
            print(f"Discovered {len(tests)} tests")

        return tests
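
    # For reference, the --gtest_list_tests output the parser above expects
    # typically looks like this (suite lines flush left ending in ".", test
    # names indented beneath them; names here are illustrative):
    #   RomTest.
    #     LoadsHeader
    #     RejectsTruncatedFile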

    def create_balanced_shards(self, tests: List[str]) -> List[List[str]]:
        """Create balanced shards based on historical execution times."""
        if not tests:
            return []

        # Sort tests by execution time (longest first),
        # using historical times or a default estimate.
        default_time = 0.1  # 100ms default per test
        sorted_tests = sorted(
            tests,
            key=lambda t: self.test_times.get(t, default_time),
            reverse=True
        )

        # Initialize shards
        num_shards = min(self.num_shards, len(tests))
        shards = [[] for _ in range(num_shards)]
        shard_times = [0.0] * num_shards

        # Distribute tests using greedy bin packing
        for test in sorted_tests:
            # Find the shard with the minimum total time
            min_shard_idx = shard_times.index(min(shard_times))
            shards[min_shard_idx].append(test)
            shard_times[min_shard_idx] += self.test_times.get(test, default_time)

        # Remove empty shards
        shards = [s for s in shards if s]

        if self.verbose:
            print(f"Created {len(shards)} shards:")
            for i, shard in enumerate(shards):
                print(f"  Shard {i}: {len(shard)} tests, "
                      f"estimated {shard_times[i]:.2f}s")

        return shards
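
    # Worked example of the greedy balancing above (durations illustrative):
    # with estimated times [5, 4, 3, 3, 2] seconds and two shards, tests are
    # assigned longest-first to the currently lightest shard, giving shards of
    # [5, 3] = 8s and [4, 3, 2] = 9s rather than splitting purely by count.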

    def run_shard(self, shard_id: int, tests: List[str],
                  output_dir: Path = None) -> ShardResult:
        """Run a single shard of tests."""
        if not tests:
            return ShardResult(shard_id, 0, 0, 0, 0, 0.0, [])

        filter_str = ":".join(tests)
        output_dir = output_dir or self.cache_dir / "results"
        output_dir.mkdir(parents=True, exist_ok=True)

        # Prepare command
        json_output = output_dir / f"shard_{shard_id}_results.json"

        cmd = [
            str(self.test_binary),
            f"--gtest_filter={filter_str}",
            f"--gtest_output=json:{json_output}",
            "--gtest_brief=1"
        ]

        # Run tests
        start_time = time.time()
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=600  # 10 minute timeout per shard
            )
            duration = time.time() - start_time
        except subprocess.TimeoutExpired:
            print(f"Shard {shard_id} timed out!", file=sys.stderr)
            return ShardResult(shard_id, -1, len(tests), 0, len(tests),
                               600.0, [])
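
        # The JSON report written via --gtest_output is expected to have
        # roughly this shape (abbreviated; names and values are illustrative):
        #   {"testsuites": [{"name": "Suite",
        #                    "testsuite": [{"name": "Case",
        #                                   "result": "COMPLETED",
        #                                   "time": "0.002s"}]}]}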
        # Parse results
        test_results = []
        tests_run = 0
        tests_passed = 0
        tests_failed = 0

        if json_output.exists():
            try:
                with open(json_output) as f:
                    data = json.load(f)

                for suite in data.get("testsuites", []):
                    for testcase in suite.get("testsuite", []):
                        test_name = f"{suite['name']}.{testcase['name']}"
                        status = "passed" if testcase.get("result") == "COMPLETED" else "failed"
                        test_duration = float(testcase.get("time", "0").rstrip("s"))

                        test_results.append(TestResult(
                            name=test_name,
                            status=status,
                            duration=test_duration,
                            output=testcase.get("output", ""),
                            shard_id=shard_id
                        ))

                        tests_run += 1
                        if status == "passed":
                            tests_passed += 1
                        else:
                            tests_failed += 1

            except (json.JSONDecodeError, KeyError, IOError) as e:
                print(f"Warning: Failed to parse results for shard {shard_id}: {e}",
                      file=sys.stderr)

        return ShardResult(
            shard_id=shard_id,
            return_code=result.returncode,
            tests_run=tests_run,
            tests_passed=tests_passed,
            tests_failed=tests_failed,
            duration=duration,
            test_results=test_results
        )

    def run_parallel(self, filter_pattern: str = None,
                     output_dir: str = None) -> Tuple[int, Dict]:
        """Run tests in parallel shards."""
        # Discover tests
        tests = self.discover_tests(filter_pattern)
        if not tests:
            print("No tests found to run")
            return 0, {}

        print(f"Running {len(tests)} tests in up to {self.num_shards} shards...")

        # Create shards
        shards = self.create_balanced_shards(tests)
        output_path = Path(output_dir) if output_dir else self.cache_dir / "results"
        # Ensure the results directory exists even if no shard gets to create it
        output_path.mkdir(parents=True, exist_ok=True)

        # Run shards in parallel
        all_results = []
        start_time = time.time()

        with ProcessPoolExecutor(max_workers=len(shards)) as executor:
            # Submit all shards
            futures = {
                executor.submit(self.run_shard, i, shard, output_path): i
                for i, shard in enumerate(shards)
            }

            # Collect results
            for future in as_completed(futures):
                shard_id = futures[future]
                try:
                    result = future.result()
                    all_results.append(result)

                    if self.verbose:
                        print(f"Shard {shard_id} completed: "
                              f"{result.tests_passed}/{result.tests_run} passed "
                              f"in {result.duration:.2f}s")
                except Exception as e:
                    print(f"Shard {shard_id} failed with exception: {e}",
                          file=sys.stderr)

        total_duration = time.time() - start_time

        # Aggregate results
        total_tests = sum(r.tests_run for r in all_results)
        total_passed = sum(r.tests_passed for r in all_results)
        total_failed = sum(r.tests_failed for r in all_results)
        max_return_code = max((r.return_code for r in all_results), default=0)

        # Update test times cache
        new_times = {}
        for result in all_results:
            for test_result in result.test_results:
                new_times[test_result.name] = test_result.duration
        self.save_test_times(new_times)

        # Generate summary
        summary = {
            "total_tests": total_tests,
            "passed": total_passed,
            "failed": total_failed,
            "duration": total_duration,
            "num_shards": len(shards),
            "parallel_efficiency": (
                sum(r.duration for r in all_results)
                / (total_duration * len(shards)) * 100
                if len(shards) > 0 else 0
            ),
            "shards": [asdict(r) for r in all_results]
        }
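
        # "parallel_efficiency" compares total per-shard work to the wall-clock
        # capacity used. Illustrative numbers: four shards doing 25s, 30s, 20s
        # and 25s of work (100s total) inside a 30s wall-clock run give
        # 100 / (30 * 4) * 100 ≈ 83%.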

        # Save summary
        summary_file = output_path / "summary.json"
        summary_file.write_text(json.dumps(summary, indent=2))

        # Print results
        print(f"\n{'=' * 60}")
        print("Test Execution Summary")
        print(f"{'=' * 60}")
        print(f"Total Tests: {total_tests}")
        pass_rate = (total_passed / total_tests * 100) if total_tests > 0 else 0.0
        print(f"Passed: {total_passed} ({pass_rate:.1f}%)")
        print(f"Failed: {total_failed}")
        print(f"Duration: {total_duration:.2f}s")
        print(f"Shards Used: {len(shards)}")
        print(f"Efficiency: {summary['parallel_efficiency']:.1f}%")

        if total_failed > 0:
            print("\nFailed Tests:")
            for result in all_results:
                for test_result in result.test_results:
                    if test_result.status == "failed":
                        print(f"  - {test_result.name}")

        return max_return_code, summary

    def run_with_retry(self, filter_pattern: str = None,
                       max_retries: int = 2) -> int:
        """Run tests with automatic retry for flaky tests."""
        failed_tests = set()
        attempt = 0
        return_code = 0

        while attempt <= max_retries:
            if attempt > 0:
                # Only retry failed tests
                if not failed_tests:
                    break
                filter_pattern = ":".join(failed_tests)
                print(f"\nRetry attempt {attempt} for {len(failed_tests)} failed tests")

            return_code, summary = self.run_parallel(filter_pattern)

            if return_code == 0:
                if attempt > 0:
                    print(f"All tests passed after {attempt} retry attempt(s)")
                return 0

            # Collect failed tests for retry
            failed_tests.clear()
            for shard in summary.get("shards", []):
                for test_result in shard.get("test_results", []):
                    if test_result.get("status") == "failed":
                        failed_tests.add(test_result.get("name"))

            attempt += 1

        print(f"Tests still failing after {max_retries} retries")
        return return_code


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(
        description="Advanced test runner with parallel execution"
    )
    parser.add_argument(
        "test_binary",
        help="Path to the test binary"
    )
    parser.add_argument(
        "--shards",
        type=int,
        help="Number of parallel shards (default: CPU count, capped at 8)"
    )
    parser.add_argument(
        "--filter",
        help="Test filter pattern (gtest format)"
    )
    parser.add_argument(
        "--output-dir",
        help="Directory for test results"
    )
    parser.add_argument(
        "--cache-dir",
        help="Directory for test cache (default: ~/.yaze_test_cache)"
    )
    parser.add_argument(
        "--retry",
        type=int,
        default=0,
        help="Number of retries for failed tests"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose output"
    )

    args = parser.parse_args()

    try:
        runner = TestRunner(
            test_binary=args.test_binary,
            num_shards=args.shards,
            cache_dir=args.cache_dir,
            verbose=args.verbose
        )

        if args.retry > 0:
            return_code = runner.run_with_retry(
                filter_pattern=args.filter,
                max_retries=args.retry
            )
        else:
            return_code, _ = runner.run_parallel(
                filter_pattern=args.filter,
                output_dir=args.output_dir
            )

        sys.exit(return_code)

    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()