diff --git a/src/afs_scawful/cli.py b/src/afs_scawful/cli.py index 184ccd3..d14dcff 100644 --- a/src/afs_scawful/cli.py +++ b/src/afs_scawful/cli.py @@ -3,12 +3,16 @@ from __future__ import annotations import argparse +import asyncio +import json from pathlib import Path from typing import Iterable -from .registry import index_datasets, build_dataset_registry, write_dataset_registry +from .registry import build_dataset_registry, index_datasets, write_dataset_registry from .resource_index import ResourceIndexer from .paths import resolve_datasets_root, resolve_index_root +from .training import TrainingSample +from .validators import default_validators def _datasets_index_command(args: argparse.Namespace) -> int: @@ -43,6 +47,45 @@ def _resources_index_command(args: argparse.Namespace) -> int: return 0 +async def _run_validators(sample: TrainingSample, validators) -> list[tuple[str, object]]: + results: list[tuple[str, object]] = [] + for validator in validators: + if validator.can_validate(sample): + result = await validator.validate(sample) + results.append((validator.name, result)) + return results + + +def _validators_list_command(args: argparse.Namespace) -> int: + validators = default_validators() + for validator in validators: + print(f"{validator.name}\t{validator.domain}") + return 0 + + +def _validators_run_command(args: argparse.Namespace) -> int: + sample_path = Path(args.sample).expanduser().resolve() + payload = json.loads(sample_path.read_text(encoding="utf-8")) + sample = TrainingSample.from_dict(payload) + + validators = default_validators() + if args.name: + validators = [v for v in validators if v.name in args.name] + + results = asyncio.run(_run_validators(sample, validators)) + if not results: + print("(no validators)") + return 1 + + overall_ok = True + for name, result in results: + status = "ok" if result.valid else "fail" + if not result.valid: + overall_ok = False + print(f"{name}\t{status}\t{result.score:.2f}") + return 0 if overall_ok else 1 + + def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(prog="afs_scawful") subparsers = parser.add_subparsers(dest="command") @@ -77,6 +120,21 @@ def build_parser() -> argparse.ArgumentParser: resources_index.add_argument("--output", help="Output index path.") resources_index.set_defaults(func=_resources_index_command) + validators_parser = subparsers.add_parser("validators", help="Validation tools.") + validators_sub = validators_parser.add_subparsers(dest="validators_command") + + validators_list = validators_sub.add_parser("list", help="List validators.") + validators_list.set_defaults(func=_validators_list_command) + + validators_run = validators_sub.add_parser("run", help="Validate a sample JSON.") + validators_run.add_argument("sample", help="Path to sample JSON.") + validators_run.add_argument( + "--name", + action="append", + help="Validator name to run (repeatable).", + ) + validators_run.set_defaults(func=_validators_run_command) + return parser @@ -92,6 +150,9 @@ def main(argv: Iterable[str] | None = None) -> int: if args.command == "resources" and not getattr(args, "resources_command", None): parser.print_help() return 1 + if args.command == "validators" and not getattr(args, "validators_command", None): + parser.print_help() + return 1 return args.func(args) diff --git a/src/afs_scawful/training.py b/src/afs_scawful/training.py new file mode 100644 index 0000000..5a6db96 --- /dev/null +++ b/src/afs_scawful/training.py @@ -0,0 +1,73 @@ +"""Training sample data models for AFS Scawful.""" + +from __future__ import annotations + +import json +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + + +@dataclass +class TrainingSample: + instruction: str + input: str + output: str + domain: str + source: str = "" + sample_id: str = "" + timestamp: str = "" + metadata: dict[str, Any] = field(default_factory=dict) + kg_entities: list[str] = field(default_factory=list) + kg_validated: bool = False + + def __post_init__(self) -> None: + if not self.sample_id: + self.sample_id = str(uuid.uuid4()) + if not self.timestamp: + self.timestamp = datetime.now().isoformat() + + def to_dict(self) -> dict[str, Any]: + return { + "instruction": self.instruction, + "input": self.input, + "output": self.output, + "domain": self.domain, + "source": self.source, + "sample_id": self.sample_id, + "timestamp": self.timestamp, + "metadata": self.metadata, + "kg_entities": self.kg_entities, + "kg_validated": self.kg_validated, + } + + def to_jsonl_entry(self) -> str: + payload = { + "instruction": self.instruction, + "output": self.output, + } + if self.input: + payload["input"] = self.input + payload["_metadata"] = { + "sample_id": self.sample_id, + "domain": self.domain, + "source": self.source, + "timestamp": self.timestamp, + } + return json.dumps(payload) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "TrainingSample": + return cls( + instruction=data.get("instruction", ""), + input=data.get("input", ""), + output=data.get("output", ""), + domain=data.get("domain", ""), + source=data.get("source", ""), + sample_id=data.get("sample_id", ""), + timestamp=data.get("timestamp", ""), + metadata=data.get("metadata", {}) or {}, + kg_entities=data.get("kg_entities", []) or [], + kg_validated=bool(data.get("kg_validated", False)), + ) diff --git a/src/afs_scawful/validators/__init__.py b/src/afs_scawful/validators/__init__.py new file mode 100644 index 0000000..e0f4c42 --- /dev/null +++ b/src/afs_scawful/validators/__init__.py @@ -0,0 +1,27 @@ +"""Validator registry for AFS Scawful.""" + +from .asar_validator import AsarValidator +from .asm_validator import AsmValidator +from .base import CompositeValidator, ValidationResult, Validator +from .cpp_validator import CppValidator +from .kg_validator import KGValidator + +__all__ = [ + "AsarValidator", + "AsmValidator", + "CppValidator", + "CompositeValidator", + "KGValidator", + "ValidationResult", + "Validator", + "default_validators", +] + + +def default_validators() -> list[Validator]: + return [ + AsmValidator(), + AsarValidator(), + CppValidator(), + KGValidator(), + ] diff --git a/src/afs_scawful/validators/asar_validator.py b/src/afs_scawful/validators/asar_validator.py new file mode 100644 index 0000000..80ba599 --- /dev/null +++ b/src/afs_scawful/validators/asar_validator.py @@ -0,0 +1,127 @@ +"""Asar Validator for verifying 65816 assembly code. + +Uses the actual 'asar' binary to assemble code snippets against a dummy ROM. +This provides 100% accurate syntax and label validation. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import shutil +import tempfile +from pathlib import Path +from typing import Optional + +from ..training import TrainingSample +from .base import ValidationResult, Validator + +logger = logging.getLogger(__name__) + +def _resolve_env_path(env_var: str) -> Path | None: + value = os.environ.get(env_var) + if not value: + return None + return Path(value).expanduser().resolve() + + +def _default_asar_path() -> Path: + env = _resolve_env_path("AFS_ASAR_PATH") + if env: + return env + found = shutil.which("asar") + if found: + return Path(found) + return Path("asar") + + +def _default_rom_path() -> Path: + env = _resolve_env_path("AFS_ASAR_ROM") + if env: + return env + candidate = Path.home() / "src" / "training" / "roms" / "dummy.sfc" + if candidate.exists(): + return candidate + return Path.home() / ".context" / "training" / "dummy.sfc" + + +class AsarValidator(Validator): + """Validates assembly code by running it through Asar.""" + + def __init__(self, asar_path: Path | None = None, rom_path: Path | None = None): + super().__init__("AsarValidator", "asm") + self.asar_path = asar_path or _default_asar_path() + self.rom_path = rom_path or _default_rom_path() + + if not self.asar_path.exists(): + logger.warning("Asar binary not found at %s", self.asar_path) + if not self.rom_path.exists(): + logger.warning("Dummy ROM not found at %s", self.rom_path) + + async def validate(self, sample: TrainingSample) -> ValidationResult: + """Run asar on the sample output code.""" + if not self.asar_path.exists() or not self.rom_path.exists(): + return ValidationResult( + valid=True, + score=0.5, + warnings=["Asar validator skipped: binary or ROM missing"], + ) + + # Extract code (simple heuristic: look for code blocks or use full output) + code = self._extract_code(sample.output) + if not code: + return ValidationResult(valid=False, score=0.0, errors=["No code found"]) + + with tempfile.TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) + source_file = tmp_path / "test.asm" + rom_file = tmp_path / "test.sfc" + + # Copy dummy ROM to temp (to avoid modifying the original) + shutil.copy(self.rom_path, rom_file) + + # Wrap code in a safe patch structure + # We assume the code is a snippet, so we hook it into free space + wrapped_code = ( + "lorom\n" + "org $008000\n" # Hook into start of ROM + f"{code}\n" + ) + + source_file.write_text(wrapped_code) + + # Run asar + proc = await asyncio.create_subprocess_exec( + str(self.asar_path), + str(source_file), + str(rom_file), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + stdout, stderr = await proc.communicate() + + if proc.returncode == 0: + return ValidationResult(valid=True, score=1.0) + else: + error_msg = stderr.decode() + stdout.decode() + # Clean up error message + lines = [l for l in error_msg.split('\n') if "error:" in l.lower()] + return ValidationResult( + valid=False, + score=0.0, + errors=lines[:3] or ["Asar failed to assemble"], + ) + + def _extract_code(self, text: str) -> str: + """Extract ASM code from markdown block or raw text.""" + if "```asm" in text: + parts = text.split("```asm") + if len(parts) > 1: + return parts[1].split("```")[0].strip() + if "```" in text: + parts = text.split("```") + if len(parts) > 1: + return parts[1].strip() + return text # Assume raw code if no blocks diff --git a/src/afs_scawful/validators/asm_validator.py b/src/afs_scawful/validators/asm_validator.py new file mode 100644 index 0000000..385389a --- /dev/null +++ b/src/afs_scawful/validators/asm_validator.py @@ -0,0 +1,342 @@ +"""ASM Validator for 65816 assembly training samples. + +Validates: +- Instruction mnemonics +- Addressing modes +- Register usage +- Memory addressing patterns +- SNES-specific constructs +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Optional + +from ..training import TrainingSample +from .base import ValidationResult, Validator + + +@dataclass +class InstructionInfo: + """Information about a 65816 instruction.""" + + mnemonic: str + addressing_modes: list[str] + description: str + + +class AsmValidator(Validator): + """Validator for 65816 assembly code in training samples.""" + + # Valid 65816 instruction mnemonics + VALID_MNEMONICS = { + # Load/Store + "LDA", "LDX", "LDY", "STA", "STX", "STY", "STZ", + # Transfer + "TAX", "TAY", "TXA", "TYA", "TXS", "TSX", "TCD", "TDC", "TCS", "TSC", "TXY", "TYX", + # Stack + "PHA", "PHP", "PHX", "PHY", "PHB", "PHD", "PHK", + "PLA", "PLP", "PLX", "PLY", "PLB", "PLD", + "PEA", "PEI", "PER", + # Arithmetic + "ADC", "SBC", "INC", "INX", "INY", "DEC", "DEX", "DEY", + # Comparison + "CMP", "CPX", "CPY", + # Logical + "AND", "ORA", "EOR", "BIT", + # Shift/Rotate + "ASL", "LSR", "ROL", "ROR", + # Branch + "BCC", "BCS", "BEQ", "BMI", "BNE", "BPL", "BVC", "BVS", "BRA", "BRL", + # Jump + "JMP", "JML", "JSR", "JSL", "RTS", "RTL", "RTI", + # Flags + "CLC", "CLD", "CLI", "CLV", "SEC", "SED", "SEI", + "REP", "SEP", + # Processor + "NOP", "WDM", "STP", "WAI", "XBA", "XCE", + # Block Move + "MVP", "MVN", + # Misc + "BRK", "COP", "WDM", + # 65C816 specific + "TRB", "TSB", + } + + # Valid addressing mode patterns + ADDRESSING_PATTERNS = { + "immediate_8": r"#\$[0-9A-Fa-f]{1,2}", # #$XX + "immediate_16": r"#\$[0-9A-Fa-f]{3,4}", # #$XXXX + "immediate_symbol": r"#[A-Za-z_]\w*", # #SYMBOL + "direct_page": r"\$[0-9A-Fa-f]{1,2}(?!\w)", # $XX (not followed by more hex) + "absolute": r"\$[0-9A-Fa-f]{4}(?!\w)", # $XXXX + "long": r"\$[0-9A-Fa-f]{6}", # $XXXXXX + "indexed_x": r",\s*[Xx]", # ,X + "indexed_y": r",\s*[Yy]", # ,Y + "indirect": r"\([^)]+\)", # (...) + "stack_relative": r"\$[0-9A-Fa-f]{1,2},\s*[Ss]", # $XX,S + "accumulator": r"[Aa](?:\s|$)", # A + "label": r"[A-Za-z_]\w*", # Labels + } + + # SNES-specific registers and addresses + SNES_REGISTERS = { + # PPU Registers + "INIDISP", "OBSEL", "OAMADDL", "OAMADDH", "OAMDATA", + "BGMODE", "MOSAIC", "BG1SC", "BG2SC", "BG3SC", "BG4SC", + "BG12NBA", "BG34NBA", "BG1HOFS", "BG1VOFS", "BG2HOFS", "BG2VOFS", + "BG3HOFS", "BG3VOFS", "BG4HOFS", "BG4VOFS", + "VMAIN", "VMADDL", "VMADDH", "VMDATAL", "VMDATAH", + "M7SEL", "M7A", "M7B", "M7C", "M7D", "M7X", "M7Y", + "CGADD", "CGDATA", "W12SEL", "W34SEL", "WOBJSEL", + "WH0", "WH1", "WH2", "WH3", "WBGLOG", "WOBJLOG", + "TM", "TS", "TMW", "TSW", "CGWSEL", "CGADSUB", + "COLDATA", "SETINI", + # APU Registers + "APUIO0", "APUIO1", "APUIO2", "APUIO3", + # DMA Registers + "MDMAEN", "HDMAEN", "MEMSEL", + # CPU Registers + "NMITIMEN", "WRIO", "WRMPYA", "WRMPYB", "WRDIVL", "WRDIVH", + "WRDIVB", "HTIMEL", "HTIMEH", "VTIMEL", "VTIMEH", + "RDNMI", "TIMEUP", "HVBJOY", "RDIO", "RDDIVL", "RDDIVH", + "RDMPYL", "RDMPYH", "JOY1L", "JOY1H", "JOY2L", "JOY2H", + "JOY3L", "JOY3H", "JOY4L", "JOY4H", + } + + # Common ALTTP-specific labels + ALTTP_LABELS = { + "Module", "Submodule", "Link", "Player", "Sprite", + "WRAM", "SRAM", "VRAM", "OAM", "CGRAM", + } + + VALID_DOMAINS = {"asm", "hack_curated"} + + def __init__(self, strict: bool = False): + """Initialize ASM validator. + + Args: + strict: If True, apply stricter validation rules + """ + super().__init__("AsmValidator", "asm") + self.strict = strict + + def can_validate(self, sample: TrainingSample) -> bool: + """Allow ASM validation for curated hack samples too.""" + return sample.domain in self.VALID_DOMAINS or sample.domain.startswith("asm") + + async def validate(self, sample: TrainingSample) -> ValidationResult: + """Validate 65816 assembly in the sample output.""" + errors: list[str] = [] + warnings: list[str] = [] + details: dict = { + "instructions_found": 0, + "valid_instructions": 0, + "invalid_instructions": [], + "snes_registers_used": [], + "addressing_modes": [], + } + + # Extract code from output + code = sample.output + + # Parse instructions + instructions = self._extract_instructions(code) + details["instructions_found"] = len(instructions) + + if len(instructions) == 0: + warnings.append("No assembly instructions found in output") + return ValidationResult( + valid=True, + score=0.5, + warnings=warnings, + details=details, + ) + + # Validate each instruction + for line_num, instr in instructions: + result = self._validate_instruction(instr) + if result.valid: + details["valid_instructions"] += 1 + if result.addressing_mode: + details["addressing_modes"].append(result.addressing_mode) + else: + details["invalid_instructions"].append({ + "line": line_num, + "instruction": instr, + "error": result.error, + }) + if self.strict: + errors.append(f"Line {line_num}: {result.error}") + else: + warnings.append(f"Line {line_num}: {result.error}") + + # Check for SNES registers + for reg in self.SNES_REGISTERS: + if reg in code: + details["snes_registers_used"].append(reg) + + # Calculate score + if details["instructions_found"] > 0: + score = details["valid_instructions"] / details["instructions_found"] + else: + score = 0.5 + + # Boost score if SNES-specific content found + if details["snes_registers_used"]: + score = min(1.0, score + 0.1) + + return ValidationResult( + valid=len(errors) == 0, + score=score, + errors=errors, + warnings=warnings, + details=details, + ) + + def _extract_instructions(self, code: str) -> list[tuple[int, str]]: + """Extract assembly instructions from code. + + Returns: + List of (line_number, instruction) tuples + """ + instructions = [] + lines = code.split("\n") + + for i, line in enumerate(lines, 1): + # Remove comments + if ";" in line: + line = line[:line.index(";")] + + # Remove labels (lines ending with :) + if ":" in line: + # Check if it's a label definition + parts = line.split(":") + if len(parts) > 1: + line = parts[-1] + + # Remove address prefixes like #_008000: + line = re.sub(r"#_[0-9A-Fa-f]+:\s*", "", line) + + line = line.strip() + + if not line: + continue + + # Check if line starts with a valid mnemonic + parts = line.split() + if parts: + mnemonic = parts[0].upper() + if mnemonic in self.VALID_MNEMONICS: + instructions.append((i, line)) + elif re.match(r"[A-Za-z]{2,4}", mnemonic): + # Might be an instruction-like thing + instructions.append((i, line)) + + return instructions + + def _validate_instruction(self, instruction: str) -> "_InstructionValidation": + """Validate a single instruction.""" + parts = instruction.split(None, 1) + if not parts: + return _InstructionValidation(False, "Empty instruction") + + mnemonic = parts[0].upper() + operand = parts[1] if len(parts) > 1 else "" + + # Check mnemonic + if mnemonic not in self.VALID_MNEMONICS: + # Check if it's close to a valid mnemonic (typo detection) + close_matches = [m for m in self.VALID_MNEMONICS + if self._levenshtein_distance(mnemonic, m) <= 1] + if close_matches: + return _InstructionValidation( + False, + f"Unknown mnemonic '{mnemonic}' (did you mean {close_matches[0]}?)" + ) + return _InstructionValidation(False, f"Unknown mnemonic '{mnemonic}'") + + # Validate operand if present + addressing_mode = None + if operand: + addressing_mode = self._detect_addressing_mode(operand) + + return _InstructionValidation(True, None, addressing_mode) + + def _detect_addressing_mode(self, operand: str) -> Optional[str]: + """Detect the addressing mode from the operand.""" + operand = operand.strip() + + # Check patterns in order of specificity + if re.match(r"#", operand): + if re.search(r"#\$[0-9A-Fa-f]{3,4}", operand): + return "immediate_16" + elif re.search(r"#\$[0-9A-Fa-f]{1,2}", operand): + return "immediate_8" + else: + return "immediate_symbol" + + if re.search(r",\s*[Ss]", operand): + return "stack_relative" + + if re.match(r"\([^)]+\)", operand): + if ",X" in operand.upper(): + return "indexed_indirect_x" + elif ",Y" in operand.upper(): + return "indirect_indexed_y" + else: + return "indirect" + + if re.search(r",\s*[Xx]", operand): + return "indexed_x" + + if re.search(r",\s*[Yy]", operand): + return "indexed_y" + + if re.match(r"\$[0-9A-Fa-f]{6}", operand): + return "long" + + if re.match(r"\$[0-9A-Fa-f]{4}", operand): + return "absolute" + + if re.match(r"\$[0-9A-Fa-f]{1,2}(?!\w)", operand): + return "direct_page" + + if re.match(r"[Aa]$", operand): + return "accumulator" + + if re.match(r"[A-Za-z_]\w*", operand): + return "label" + + return None + + def _levenshtein_distance(self, s1: str, s2: str) -> int: + """Calculate Levenshtein distance between two strings.""" + if len(s1) < len(s2): + return self._levenshtein_distance(s2, s1) + + if len(s2) == 0: + return len(s1) + + previous_row = range(len(s2) + 1) + for i, c1 in enumerate(s1): + current_row = [i + 1] + for j, c2 in enumerate(s2): + insertions = previous_row[j + 1] + 1 + deletions = current_row[j] + 1 + substitutions = previous_row[j] + (c1 != c2) + current_row.append(min(insertions, deletions, substitutions)) + previous_row = current_row + + return previous_row[-1] + + +@dataclass +class _InstructionValidation: + """Internal result of validating a single instruction.""" + + valid: bool + error: Optional[str] = None + addressing_mode: Optional[str] = None diff --git a/src/afs_scawful/validators/base.py b/src/afs_scawful/validators/base.py new file mode 100644 index 0000000..91f4367 --- /dev/null +++ b/src/afs_scawful/validators/base.py @@ -0,0 +1,90 @@ +"""Base validator interfaces for AFS Scawful.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +from ..training import TrainingSample + + +@dataclass +class ValidationResult: + valid: bool + score: float + errors: list[str] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + details: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "valid": self.valid, + "score": self.score, + "errors": list(self.errors), + "warnings": list(self.warnings), + "details": dict(self.details), + } + + +class Validator(ABC): + def __init__(self, name: str, domain: str) -> None: + self.name = name + self.domain = domain + + @abstractmethod + async def validate(self, sample: TrainingSample) -> ValidationResult: + raise NotImplementedError + + def can_validate(self, sample: TrainingSample) -> bool: + return sample.domain == self.domain + + async def validate_batch(self, samples: list[TrainingSample]) -> list[ValidationResult]: + results: list[ValidationResult] = [] + for sample in samples: + if self.can_validate(sample): + results.append(await self.validate(sample)) + else: + results.append( + ValidationResult( + valid=True, + score=1.0, + warnings=[f"{self.name} skipped: domain mismatch"], + ) + ) + return results + + +class CompositeValidator(Validator): + def __init__(self, validators: list[Validator]) -> None: + super().__init__("CompositeValidator", "all") + self.validators = validators + + def can_validate(self, sample: TrainingSample) -> bool: + return any(validator.can_validate(sample) for validator in self.validators) + + async def validate(self, sample: TrainingSample) -> ValidationResult: + applicable = [v for v in self.validators if v.can_validate(sample)] + if not applicable: + return ValidationResult(valid=True, score=1.0, warnings=["No applicable validators"]) + + errors: list[str] = [] + warnings: list[str] = [] + details: dict[str, Any] = {} + scores: list[float] = [] + + for validator in applicable: + result = await validator.validate(sample) + errors.extend(result.errors) + warnings.extend(result.warnings) + details[validator.name] = result.to_dict() + scores.append(result.score) + + score = sum(scores) / len(scores) if scores else 1.0 + return ValidationResult( + valid=len(errors) == 0, + score=score, + errors=errors, + warnings=warnings, + details=details, + ) diff --git a/src/afs_scawful/validators/cpp_validator.py b/src/afs_scawful/validators/cpp_validator.py new file mode 100644 index 0000000..a6577dc --- /dev/null +++ b/src/afs_scawful/validators/cpp_validator.py @@ -0,0 +1,340 @@ +"""C++ Validator for training samples. + +Validates: +- Basic syntax checks (brackets, braces, semicolons) +- Keyword usage +- Common patterns +- Optional: Compile check with clang (if available) +""" + +from __future__ import annotations + +import asyncio +import re +import shutil +import tempfile +from pathlib import Path +from typing import Optional + +from ..training import TrainingSample +from .base import ValidationResult, Validator + + +class CppValidator(Validator): + """Validator for C++ code in training samples.""" + + # C++ keywords + KEYWORDS = { + # Storage class + "auto", "register", "static", "extern", "mutable", "thread_local", + # Type specifiers + "void", "bool", "char", "short", "int", "long", "float", "double", + "signed", "unsigned", "wchar_t", "char8_t", "char16_t", "char32_t", + # Type qualifiers + "const", "volatile", "constexpr", "consteval", "constinit", + # Control flow + "if", "else", "switch", "case", "default", "while", "do", "for", + "break", "continue", "return", "goto", + # Declarations + "class", "struct", "union", "enum", "typedef", "using", "namespace", + "template", "typename", "concept", "requires", + # Access specifiers + "public", "private", "protected", + # Other keywords + "virtual", "override", "final", "explicit", "inline", "friend", + "operator", "sizeof", "alignof", "decltype", "typeid", + "new", "delete", "this", "nullptr", "true", "false", + "try", "catch", "throw", "noexcept", + "static_assert", "static_cast", "dynamic_cast", "const_cast", "reinterpret_cast", + "co_await", "co_return", "co_yield", + # Modules (C++20) + "module", "import", "export", + } + + # Common C++ standard library types + STD_TYPES = { + "string", "vector", "map", "unordered_map", "set", "unordered_set", + "list", "deque", "array", "pair", "tuple", "optional", "variant", + "shared_ptr", "unique_ptr", "weak_ptr", "function", "any", + "thread", "mutex", "lock_guard", "unique_lock", "condition_variable", + "future", "promise", "async", "atomic", + "ifstream", "ofstream", "fstream", "stringstream", "ostringstream", + "iostream", "cin", "cout", "cerr", "endl", + "size_t", "ptrdiff_t", "nullptr_t", "byte", + "int8_t", "int16_t", "int32_t", "int64_t", + "uint8_t", "uint16_t", "uint32_t", "uint64_t", + } + + def __init__( + self, + check_compile: bool = False, + compiler: str = "clang++", + strict: bool = False, + ): + """Initialize C++ validator. + + Args: + check_compile: If True, attempt to compile the code + compiler: Compiler to use for compile checks + strict: If True, apply stricter validation + """ + super().__init__("CppValidator", "cpp") + self.check_compile = check_compile + self.compiler = compiler + self.strict = strict + + # Check if compiler is available + self._compiler_available = shutil.which(compiler) is not None + + async def validate(self, sample: TrainingSample) -> ValidationResult: + """Validate C++ code in the sample output.""" + errors: list[str] = [] + warnings: list[str] = [] + details: dict = { + "syntax_issues": [], + "keywords_found": [], + "std_types_found": [], + "bracket_balance": True, + "compile_checked": False, + "compile_result": None, + } + + code = sample.output + + # Basic syntax checks + syntax_result = self._check_syntax(code) + details["syntax_issues"] = syntax_result["issues"] + details["bracket_balance"] = syntax_result["balanced"] + + if not syntax_result["balanced"]: + errors.append("Unbalanced brackets/braces/parentheses") + + for issue in syntax_result["issues"]: + if self.strict: + errors.append(issue) + else: + warnings.append(issue) + + # Check for keywords and types + details["keywords_found"] = self._find_keywords(code) + details["std_types_found"] = self._find_std_types(code) + + # Compile check if enabled and available + if self.check_compile and self._compiler_available: + compile_result = await self._check_compile(code) + details["compile_checked"] = True + details["compile_result"] = compile_result + + if not compile_result["success"]: + if self.strict: + errors.append(f"Compile error: {compile_result['error'][:200]}") + else: + warnings.append(f"Compile warning: {compile_result['error'][:100]}") + + # Calculate score + score = 1.0 + + # Deduct for syntax issues + score -= len(details["syntax_issues"]) * 0.1 + score = max(0.0, score) + + # Deduct for bracket imbalance + if not details["bracket_balance"]: + score -= 0.3 + + # Bonus for using C++ features + if details["keywords_found"]: + score = min(1.0, score + 0.05) + if details["std_types_found"]: + score = min(1.0, score + 0.05) + + # Deduct for compile failure + if details["compile_checked"] and not details["compile_result"]["success"]: + score -= 0.2 + + score = max(0.0, min(1.0, score)) + + return ValidationResult( + valid=len(errors) == 0, + score=score, + errors=errors, + warnings=warnings, + details=details, + ) + + def _check_syntax(self, code: str) -> dict: + """Check basic C++ syntax.""" + issues = [] + balanced = True + + # Check bracket balance + stack = [] + pairs = {"(": ")", "[": "]", "{": "}"} + in_string = False + in_char = False + in_comment = False + in_block_comment = False + + i = 0 + while i < len(code): + c = code[i] + + # Handle comments + if not in_string and not in_char: + if i < len(code) - 1: + two_char = code[i:i+2] + if two_char == "//": + # Skip to end of line + while i < len(code) and code[i] != "\n": + i += 1 + continue + elif two_char == "/*": + in_block_comment = True + i += 2 + continue + elif two_char == "*/" and in_block_comment: + in_block_comment = False + i += 2 + continue + + if in_block_comment: + i += 1 + continue + + # Handle strings + if c == '"' and not in_char and (i == 0 or code[i-1] != '\\'): + in_string = not in_string + elif c == "'" and not in_string and (i == 0 or code[i-1] != '\\'): + in_char = not in_char + + if not in_string and not in_char: + if c in pairs: + stack.append(c) + elif c in pairs.values(): + if not stack: + balanced = False + issues.append(f"Unexpected closing bracket '{c}'") + else: + expected = pairs[stack.pop()] + if c != expected: + balanced = False + issues.append(f"Mismatched brackets: expected '{expected}', got '{c}'") + + i += 1 + + if stack: + balanced = False + issues.append(f"Unclosed brackets: {stack}") + + # Check for common issues + # Missing semicolons after statements (heuristic) + lines = code.split("\n") + for i, line in enumerate(lines): + stripped = line.strip() + + # Skip empty lines, comments, preprocessor + if not stripped or stripped.startswith("//") or stripped.startswith("#"): + continue + + # Skip lines that end with block characters + if stripped.endswith("{") or stripped.endswith("}") or stripped.endswith(":"): + continue + + # Skip lines that are likely continuations + if stripped.endswith(",") or stripped.endswith("\\"): + continue + + # Check for statements that should end with semicolon + # This is a heuristic and may have false positives + statement_patterns = [ + r"return\s+.+[^;]$", # return without semicolon + r"break$", # break without semicolon + r"continue$", # continue without semicolon + ] + + for pattern in statement_patterns: + if re.search(pattern, stripped): + issues.append(f"Line {i+1}: Possibly missing semicolon") + break + + return {"issues": issues, "balanced": balanced} + + def _find_keywords(self, code: str) -> list[str]: + """Find C++ keywords in code.""" + found = [] + # Use word boundaries to find keywords + for keyword in self.KEYWORDS: + if re.search(rf"\b{keyword}\b", code): + found.append(keyword) + return found + + def _find_std_types(self, code: str) -> list[str]: + """Find standard library types in code.""" + found = [] + for type_name in self.STD_TYPES: + # Check for std::type or just type in common contexts + if re.search(rf"std::{type_name}\b", code) or re.search(rf"\b{type_name}<", code): + found.append(type_name) + return found + + async def _check_compile(self, code: str) -> dict: + """Attempt to compile the code.""" + # Create temporary file + with tempfile.NamedTemporaryFile( + mode="w", suffix=".cpp", delete=False + ) as f: + # Add minimal includes for standalone compilation + wrapped_code = """ +#include +#include +#include +#include + +// Sample code below +""" + code + f.write(wrapped_code) + temp_path = Path(f.name) + + try: + # Run compiler with syntax-only check + process = await asyncio.create_subprocess_exec( + self.compiler, + "-fsyntax-only", + "-std=c++17", + "-Wall", + str(temp_path), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=10.0 + ) + + success = process.returncode == 0 + error = stderr.decode("utf-8", errors="replace") if stderr else "" + + return { + "success": success, + "error": error, + "returncode": process.returncode, + } + + except asyncio.TimeoutError: + return { + "success": False, + "error": "Compilation timed out", + "returncode": -1, + } + except Exception as e: + return { + "success": False, + "error": str(e), + "returncode": -1, + } + finally: + # Clean up temp file + try: + temp_path.unlink() + except Exception: + pass diff --git a/src/afs_scawful/validators/kg_validator.py b/src/afs_scawful/validators/kg_validator.py new file mode 100644 index 0000000..2005edf --- /dev/null +++ b/src/afs_scawful/validators/kg_validator.py @@ -0,0 +1,349 @@ +"""Knowledge Graph Validator for training samples. + +Validates: +- Entity presence in knowledge graph +- Relationship consistency +- Cross-reference validity +- Domain alignment +""" + +from __future__ import annotations + +import json +import re +from pathlib import Path +from typing import Any, Optional + +from ..training import TrainingSample +from .base import ValidationResult, Validator + + +def _default_graph_path() -> Path: + candidate = Path.home() / "src" / "context" / "memory" / "knowledge_graph.json" + if candidate.exists(): + return candidate + return Path.home() / ".context" / "memory" / "knowledge_graph.json" + + +class KGValidator(Validator): + """Validator for knowledge graph consistency in training samples.""" + + def __init__( + self, + graph_path: Optional[Path] = None, + strict: bool = False, + min_entity_coverage: float = 0.3, + ): + """Initialize KG validator. + + Args: + graph_path: Path to knowledge graph JSON. Defaults to ~/src/context/memory/knowledge_graph.json + (fallback: ~/.context/memory/knowledge_graph.json). + strict: If True, apply stricter validation (missing entities are errors) + min_entity_coverage: Minimum fraction of mentioned entities that must be in KG + """ + super().__init__("KGValidator", "all") # Applies to all domains + self.graph_path = graph_path or _default_graph_path() + self.strict = strict + self.min_entity_coverage = min_entity_coverage + + # Lazy load graph + self._graph: Optional[dict] = None + self._nodes: dict[str, Any] = {} + self._edges: list[dict[str, Any]] = [] + self._node_names: set[str] = set() + self._routines: set[str] = set() + self._symbols: set[str] = set() + + def _load_graph(self) -> None: + """Load knowledge graph from disk.""" + if self._graph is not None: + return + + if not self.graph_path.exists(): + self._graph = {"nodes": {}, "edges": []} + return + + try: + data = json.loads(self.graph_path.read_text()) + self._graph = data + self._nodes = data.get("nodes", {}) + self._edges = data.get("edges", []) + + # Build lookup sets + for node_id, node_data in self._nodes.items(): + self._node_names.add(node_id.lower()) + + # Extract name from node data + if isinstance(node_data, dict): + name = node_data.get("name", "") + if name: + self._node_names.add(name.lower()) + + # Track routines and symbols specifically + node_type = node_data.get("type", "") + if node_type == "routine": + self._routines.add(name.lower()) + elif node_type == "symbol": + self._symbols.add(name.lower()) + + except Exception: + self._graph = {"nodes": {}, "edges": []} + + def can_validate(self, sample: TrainingSample) -> bool: + """KG validator can validate any sample with kg_entities.""" + return True # Applies to all domains + + async def validate(self, sample: TrainingSample) -> ValidationResult: + """Validate knowledge graph consistency in the sample.""" + self._load_graph() + + errors: list[str] = [] + warnings: list[str] = [] + details: dict = { + "entities_mentioned": [], + "entities_found": [], + "entities_missing": [], + "routines_mentioned": [], + "symbols_mentioned": [], + "relationships_valid": True, + "coverage": 0.0, + } + + # Extract entities from sample + text = f"{sample.instruction} {sample.input} {sample.output}" + mentioned = self._extract_entities(text, sample.domain) + + details["entities_mentioned"] = mentioned + + # Check which entities exist in KG + found = [] + missing = [] + + for entity in mentioned: + entity_lower = entity.lower() + if self._entity_exists(entity_lower): + found.append(entity) + else: + missing.append(entity) + + details["entities_found"] = found + details["entities_missing"] = missing + + # Calculate coverage + if mentioned: + coverage = len(found) / len(mentioned) + else: + coverage = 1.0 # No entities to validate + + details["coverage"] = coverage + + # Check for routine/symbol references in ASM samples + if sample.domain.startswith("asm"): + routines = self._extract_routine_references(sample.output) + symbols = self._extract_symbol_references(sample.output) + + details["routines_mentioned"] = routines + details["symbols_mentioned"] = symbols + + # Check routine validity + for routine in routines: + if routine.lower() not in self._routines and routine.lower() not in self._node_names: + if self.strict: + errors.append(f"Unknown routine: {routine}") + else: + warnings.append(f"Routine not in KG: {routine}") + + # Check kg_entities from sample metadata + if sample.kg_entities: + for entity in sample.kg_entities: + if not self._entity_exists(entity.lower()): + if self.strict: + errors.append(f"Tagged entity not in KG: {entity}") + else: + warnings.append(f"Tagged entity not in KG: {entity}") + + # Validate coverage threshold + if coverage < self.min_entity_coverage and mentioned: + msg = f"Entity coverage {coverage:.1%} below threshold {self.min_entity_coverage:.1%}" + if self.strict: + errors.append(msg) + else: + warnings.append(msg) + + # Calculate score + score = 1.0 + + # Base score on coverage + score = min(1.0, coverage + 0.3) # Coverage contributes up to 0.7 + + # Bonus for having KG entities tagged + if sample.kg_entities and sample.kg_validated: + score = min(1.0, score + 0.1) + + # Penalty for missing entities + if missing: + penalty = len(missing) * 0.05 + score = max(0.3, score - penalty) + + return ValidationResult( + valid=len(errors) == 0, + score=score, + errors=errors, + warnings=warnings, + details=details, + ) + + def _entity_exists(self, entity: str) -> bool: + """Check if an entity exists in the knowledge graph.""" + entity_lower = entity.lower() + + # Direct match + if entity_lower in self._node_names: + return True + + # Check with common prefixes + prefixes = ["alttp:", "oracle-of-secrets:", "project:", "routine:", "symbol:"] + for prefix in prefixes: + if f"{prefix}{entity_lower}" in self._node_names: + return True + # Also check node IDs directly + for node_id in self._nodes: + if node_id.lower().endswith(f":{entity_lower}"): + return True + + return False + + def _extract_entities(self, text: str, domain: str) -> list[str]: + """Extract potential entity references from text.""" + entities = [] + + # Common patterns for entity references + patterns = [ + # Code references like `EntityName` or `RoutineName` + r'`([A-Z][a-zA-Z0-9_]+)`', + # Capitalized terms that look like identifiers + r'\b([A-Z][a-z]+(?:[A-Z][a-z]+)+)\b', # CamelCase + # Routine names (common in ASM) + r'\b(Link_[A-Za-z0-9_]+)\b', + r'\b(Player_[A-Za-z0-9_]+)\b', + r'\b(Sprite_[A-Za-z0-9_]+)\b', + r'\b(Module_[A-Za-z0-9_]+)\b', + # Memory addresses with labels + r'\b([A-Z][A-Za-z0-9]+_[A-Z][A-Za-z0-9]+)\b', + ] + + for pattern in patterns: + matches = re.findall(pattern, text) + entities.extend(matches) + + # Domain-specific extraction + if domain.startswith("asm"): + # Extract ASM-specific references + asm_patterns = [ + r'\b([A-Z][a-z]+_[A-Z][a-z_0-9]+)\b', # Link_HandleSword + r'@([A-Za-z_][A-Za-z0-9_]+)', # @Labels + ] + for pattern in asm_patterns: + matches = re.findall(pattern, text) + entities.extend(matches) + + elif domain == "cpp": + # Extract C++ class/function names + cpp_patterns = [ + r'\bclass\s+([A-Z][a-zA-Z0-9_]+)\b', + r'\b([A-Z][a-z]+(?:[A-Z][a-z]+)+)::\w+', # ClassName::method + ] + for pattern in cpp_patterns: + matches = re.findall(pattern, text) + entities.extend(matches) + + # Deduplicate while preserving order + seen = set() + unique = [] + for e in entities: + if e.lower() not in seen: + seen.add(e.lower()) + unique.append(e) + + return unique + + def _extract_routine_references(self, code: str) -> list[str]: + """Extract routine/label references from ASM code.""" + routines = [] + + # JSR/JSL targets + jsr_pattern = r'\b(?:JSR|JSL|JMP|JML)\s+([A-Za-z_][A-Za-z0-9_]+)\b' + matches = re.findall(jsr_pattern, code, re.IGNORECASE) + routines.extend(matches) + + # BRA/BRL targets + branch_pattern = r'\b(?:BRA|BRL|BEQ|BNE|BCC|BCS|BMI|BPL)\s+([A-Za-z_][A-Za-z0-9_]+)\b' + matches = re.findall(branch_pattern, code, re.IGNORECASE) + routines.extend(matches) + + return list(set(routines)) + + def _extract_symbol_references(self, code: str) -> list[str]: + """Extract symbol/variable references from ASM code.""" + symbols = [] + + # LDA/STA with labels + load_store_pattern = r'\b(?:LDA|LDX|LDY|STA|STX|STY)\s+([A-Za-z_][A-Za-z0-9_]+)\b' + matches = re.findall(load_store_pattern, code, re.IGNORECASE) + symbols.extend(matches) + + # Filter out common non-symbol patterns + filtered = [] + for sym in symbols: + # Skip if it looks like a routine name + if sym.lower() in self._routines: + continue + # Skip common mnemonics that might be captured + if sym.upper() in {'A', 'X', 'Y', 'S'}: + continue + filtered.append(sym) + + return list(set(filtered)) + + def get_related_entities(self, entity: str) -> list[dict[str, Any]]: + """Get entities related to a given entity in the KG.""" + self._load_graph() + + related = [] + entity_lower = entity.lower() + + for edge in self._edges: + source = str(edge.get("source", "")).lower() + target = str(edge.get("target", "")).lower() + relation = edge.get("relation", "") + + if entity_lower in source: + related.append({ + "entity": edge.get("target"), + "relation": relation, + "direction": "outgoing", + }) + elif entity_lower in target: + related.append({ + "entity": edge.get("source"), + "relation": relation, + "direction": "incoming", + }) + + return related + + def suggest_entities(self, partial: str, limit: int = 10) -> list[str]: + """Suggest entity names matching a partial string.""" + self._load_graph() + + partial_lower = partial.lower() + matches = [] + + for node_id in self._nodes: + if partial_lower in node_id.lower(): + matches.append(node_id) + if len(matches) >= limit: + break + + return matches diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 0000000..20e20b3 --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import asyncio + +from afs_scawful.training import TrainingSample +from afs_scawful.validators import AsmValidator, CppValidator + + +def test_asm_validator_basic() -> None: + sample = TrainingSample( + instruction="", + input="", + output="LDA #$01\nSTA $7E0000\n", + domain="asm", + source="test", + ) + result = asyncio.run(AsmValidator().validate(sample)) + assert result.valid + assert result.score > 0.0 + + +def test_cpp_validator_basic() -> None: + sample = TrainingSample( + instruction="", + input="", + output="int main() { return 0; }\n", + domain="cpp", + source="test", + ) + result = asyncio.run(CppValidator(check_compile=False).validate(sample)) + assert result.valid + assert result.score > 0.0