core: add training model and validators

scawful
2025-12-30 13:37:28 -05:00
parent 3de9c302ce
commit a6fd2591dd
9 changed files with 1442 additions and 1 deletion

View File

@@ -3,12 +3,16 @@
from __future__ import annotations
import argparse
import asyncio
import json
from pathlib import Path
from typing import Iterable
from .registry import build_dataset_registry, index_datasets, write_dataset_registry
from .resource_index import ResourceIndexer
from .paths import resolve_datasets_root, resolve_index_root
from .training import TrainingSample
from .validators import default_validators
def _datasets_index_command(args: argparse.Namespace) -> int:
@@ -43,6 +47,45 @@ def _resources_index_command(args: argparse.Namespace) -> int:
return 0
async def _run_validators(sample: TrainingSample, validators) -> list[tuple[str, object]]:
results: list[tuple[str, object]] = []
for validator in validators:
if validator.can_validate(sample):
result = await validator.validate(sample)
results.append((validator.name, result))
return results
def _validators_list_command(args: argparse.Namespace) -> int:
validators = default_validators()
for validator in validators:
print(f"{validator.name}\t{validator.domain}")
return 0
def _validators_run_command(args: argparse.Namespace) -> int:
sample_path = Path(args.sample).expanduser().resolve()
payload = json.loads(sample_path.read_text(encoding="utf-8"))
sample = TrainingSample.from_dict(payload)
validators = default_validators()
if args.name:
validators = [v for v in validators if v.name in args.name]
results = asyncio.run(_run_validators(sample, validators))
if not results:
print("(no validators)")
return 1
overall_ok = True
for name, result in results:
status = "ok" if result.valid else "fail"
if not result.valid:
overall_ok = False
print(f"{name}\t{status}\t{result.score:.2f}")
return 0 if overall_ok else 1
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog="afs_scawful")
subparsers = parser.add_subparsers(dest="command")
@@ -77,6 +120,21 @@ def build_parser() -> argparse.ArgumentParser:
resources_index.add_argument("--output", help="Output index path.")
resources_index.set_defaults(func=_resources_index_command)
validators_parser = subparsers.add_parser("validators", help="Validation tools.")
validators_sub = validators_parser.add_subparsers(dest="validators_command")
validators_list = validators_sub.add_parser("list", help="List validators.")
validators_list.set_defaults(func=_validators_list_command)
validators_run = validators_sub.add_parser("run", help="Validate a sample JSON.")
validators_run.add_argument("sample", help="Path to sample JSON.")
validators_run.add_argument(
"--name",
action="append",
help="Validator name to run (repeatable).",
)
validators_run.set_defaults(func=_validators_run_command)
return parser
@@ -92,6 +150,9 @@ def main(argv: Iterable[str] | None = None) -> int:
if args.command == "resources" and not getattr(args, "resources_command", None):
parser.print_help()
return 1
if args.command == "validators" and not getattr(args, "validators_command", None):
parser.print_help()
return 1
return args.func(args)
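Usage sketch for the new validators subcommands (not part of the diff). This assumes the CLI module above is importable as afs_scawful.cli, which the diff does not name; adjust the import if main() lives elsewhere.

# Minimal sketch: drive the "validators" subcommands programmatically.
import json
from pathlib import Path

from afs_scawful.cli import main  # assumed module path; not shown in this diff

sample_path = Path("sample.json")
sample_path.write_text(json.dumps({
    "instruction": "Explain this routine.",
    "input": "",
    "output": "LDA #$01\nSTA $7E0000\n",
    "domain": "asm",
}), encoding="utf-8")

# Equivalent to: afs_scawful validators list
main(["validators", "list"])

# Equivalent to: afs_scawful validators run sample.json --name AsmValidator
exit_code = main(["validators", "run", str(sample_path), "--name", "AsmValidator"])
print("ok" if exit_code == 0 else "failed")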

View File

@@ -0,0 +1,73 @@
"""Training sample data models for AFS Scawful."""
from __future__ import annotations
import json
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class TrainingSample:
instruction: str
input: str
output: str
domain: str
source: str = ""
sample_id: str = ""
timestamp: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
kg_entities: list[str] = field(default_factory=list)
kg_validated: bool = False
def __post_init__(self) -> None:
if not self.sample_id:
self.sample_id = str(uuid.uuid4())
if not self.timestamp:
self.timestamp = datetime.now().isoformat()
def to_dict(self) -> dict[str, Any]:
return {
"instruction": self.instruction,
"input": self.input,
"output": self.output,
"domain": self.domain,
"source": self.source,
"sample_id": self.sample_id,
"timestamp": self.timestamp,
"metadata": self.metadata,
"kg_entities": self.kg_entities,
"kg_validated": self.kg_validated,
}
def to_jsonl_entry(self) -> str:
payload = {
"instruction": self.instruction,
"output": self.output,
}
if self.input:
payload["input"] = self.input
payload["_metadata"] = {
"sample_id": self.sample_id,
"domain": self.domain,
"source": self.source,
"timestamp": self.timestamp,
}
return json.dumps(payload)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TrainingSample":
return cls(
instruction=data.get("instruction", ""),
input=data.get("input", ""),
output=data.get("output", ""),
domain=data.get("domain", ""),
source=data.get("source", ""),
sample_id=data.get("sample_id", ""),
timestamp=data.get("timestamp", ""),
metadata=data.get("metadata", {}) or {},
kg_entities=data.get("kg_entities", []) or [],
kg_validated=bool(data.get("kg_validated", False)),
)
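A minimal round-trip sketch for TrainingSample (not part of the diff), using only the fields and methods defined above:

from afs_scawful.training import TrainingSample

sample = TrainingSample(
    instruction="Write a routine that clears the accumulator.",
    input="",
    output="LDA #$00\nRTS\n",
    domain="asm",
    source="manual",
)

# sample_id and timestamp are filled in by __post_init__ when left empty.
as_dict = sample.to_dict()
restored = TrainingSample.from_dict(as_dict)
assert restored.sample_id == sample.sample_id

# to_jsonl_entry() emits the instruction/output pair plus a _metadata block,
# and omits "input" when it is empty.
print(sample.to_jsonl_entry())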

View File

@@ -0,0 +1,27 @@
"""Validator registry for AFS Scawful."""
from .asar_validator import AsarValidator
from .asm_validator import AsmValidator
from .base import CompositeValidator, ValidationResult, Validator
from .cpp_validator import CppValidator
from .kg_validator import KGValidator
__all__ = [
"AsarValidator",
"AsmValidator",
"CppValidator",
"CompositeValidator",
"KGValidator",
"ValidationResult",
"Validator",
"default_validators",
]
def default_validators() -> list[Validator]:
return [
AsmValidator(),
AsarValidator(),
CppValidator(),
KGValidator(),
]
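Sketch of the registry in use (not part of the diff): default_validators() returns one instance of each validator, and can_validate() decides which of them apply to a given sample's domain.

from afs_scawful.training import TrainingSample
from afs_scawful.validators import default_validators

sample = TrainingSample(instruction="", input="", output="RTS\n", domain="asm")

# Note: constructing AsarValidator may log a warning if the asar binary or
# dummy ROM is absent; that does not affect can_validate().
applicable = [v for v in default_validators() if v.can_validate(sample)]
print([v.name for v in applicable])
# Expected: AsmValidator and AsarValidator (domain "asm") plus KGValidator,
# which accepts every domain; CppValidator is filtered out.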

View File

@@ -0,0 +1,127 @@
"""Asar Validator for verifying 65816 assembly code.
Uses the actual 'asar' binary to assemble code snippets against a dummy ROM.
This provides 100% accurate syntax and label validation.
"""
from __future__ import annotations
import asyncio
import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional
from ..training import TrainingSample
from .base import ValidationResult, Validator
logger = logging.getLogger(__name__)
def _resolve_env_path(env_var: str) -> Path | None:
value = os.environ.get(env_var)
if not value:
return None
return Path(value).expanduser().resolve()
def _default_asar_path() -> Path:
env = _resolve_env_path("AFS_ASAR_PATH")
if env:
return env
found = shutil.which("asar")
if found:
return Path(found)
return Path("asar")
def _default_rom_path() -> Path:
env = _resolve_env_path("AFS_ASAR_ROM")
if env:
return env
candidate = Path.home() / "src" / "training" / "roms" / "dummy.sfc"
if candidate.exists():
return candidate
return Path.home() / ".context" / "training" / "dummy.sfc"
class AsarValidator(Validator):
"""Validates assembly code by running it through Asar."""
def __init__(self, asar_path: Path | None = None, rom_path: Path | None = None):
super().__init__("AsarValidator", "asm")
self.asar_path = asar_path or _default_asar_path()
self.rom_path = rom_path or _default_rom_path()
if not self.asar_path.exists():
logger.warning("Asar binary not found at %s", self.asar_path)
if not self.rom_path.exists():
logger.warning("Dummy ROM not found at %s", self.rom_path)
async def validate(self, sample: TrainingSample) -> ValidationResult:
"""Run asar on the sample output code."""
if not self.asar_path.exists() or not self.rom_path.exists():
return ValidationResult(
valid=True,
score=0.5,
warnings=["Asar validator skipped: binary or ROM missing"],
)
# Extract code (simple heuristic: look for code blocks or use full output)
code = self._extract_code(sample.output)
if not code:
return ValidationResult(valid=False, score=0.0, errors=["No code found"])
with tempfile.TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir)
source_file = tmp_path / "test.asm"
rom_file = tmp_path / "test.sfc"
# Copy dummy ROM to temp (to avoid modifying the original)
shutil.copy(self.rom_path, rom_file)
# Wrap code in a safe patch structure
# We assume the code is a snippet, so we hook it into free space
wrapped_code = (
"lorom\n"
"org $008000\n" # Hook into start of ROM
f"{code}\n"
)
source_file.write_text(wrapped_code)
# Run asar
proc = await asyncio.create_subprocess_exec(
str(self.asar_path),
str(source_file),
str(rom_file),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode == 0:
return ValidationResult(valid=True, score=1.0)
else:
error_msg = stderr.decode() + stdout.decode()
# Clean up error message
lines = [l for l in error_msg.split('\n') if "error:" in l.lower()]
return ValidationResult(
valid=False,
score=0.0,
errors=lines[:3] or ["Asar failed to assemble"],
)
def _extract_code(self, text: str) -> str:
"""Extract ASM code from markdown block or raw text."""
if "```asm" in text:
parts = text.split("```asm")
if len(parts) > 1:
return parts[1].split("```")[0].strip()
if "```" in text:
parts = text.split("```")
if len(parts) > 1:
return parts[1].strip()
return text # Assume raw code if no blocks
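Sketch of running AsarValidator directly (not part of the diff). The paths below are illustrative; without a real asar binary and dummy ROM the validator returns a skipped result (valid=True, score=0.5) rather than failing.

import asyncio
from pathlib import Path

from afs_scawful.training import TrainingSample
from afs_scawful.validators import AsarValidator

validator = AsarValidator(
    asar_path=Path("/usr/local/bin/asar"),        # or set AFS_ASAR_PATH
    rom_path=Path.home() / "roms" / "dummy.sfc",  # or set AFS_ASAR_ROM
)

sample = TrainingSample(
    instruction="",
    input="",
    output="```asm\nLDA #$01\nSTA $7E0000\nRTS\n```",
    domain="asm",
)

result = asyncio.run(validator.validate(sample))
print(result.valid, result.score, result.errors or result.warnings)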

View File

@@ -0,0 +1,342 @@
"""ASM Validator for 65816 assembly training samples.
Validates:
- Instruction mnemonics
- Addressing modes
- Register usage
- Memory addressing patterns
- SNES-specific constructs
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Optional
from ..training import TrainingSample
from .base import ValidationResult, Validator
@dataclass
class InstructionInfo:
"""Information about a 65816 instruction."""
mnemonic: str
addressing_modes: list[str]
description: str
class AsmValidator(Validator):
"""Validator for 65816 assembly code in training samples."""
# Valid 65816 instruction mnemonics
VALID_MNEMONICS = {
# Load/Store
"LDA", "LDX", "LDY", "STA", "STX", "STY", "STZ",
# Transfer
"TAX", "TAY", "TXA", "TYA", "TXS", "TSX", "TCD", "TDC", "TCS", "TSC", "TXY", "TYX",
# Stack
"PHA", "PHP", "PHX", "PHY", "PHB", "PHD", "PHK",
"PLA", "PLP", "PLX", "PLY", "PLB", "PLD",
"PEA", "PEI", "PER",
# Arithmetic
"ADC", "SBC", "INC", "INX", "INY", "DEC", "DEX", "DEY",
# Comparison
"CMP", "CPX", "CPY",
# Logical
"AND", "ORA", "EOR", "BIT",
# Shift/Rotate
"ASL", "LSR", "ROL", "ROR",
# Branch
"BCC", "BCS", "BEQ", "BMI", "BNE", "BPL", "BVC", "BVS", "BRA", "BRL",
# Jump
"JMP", "JML", "JSR", "JSL", "RTS", "RTL", "RTI",
# Flags
"CLC", "CLD", "CLI", "CLV", "SEC", "SED", "SEI",
"REP", "SEP",
# Processor
"NOP", "WDM", "STP", "WAI", "XBA", "XCE",
# Block Move
"MVP", "MVN",
# Misc
"BRK", "COP", "WDM",
# 65C816 specific
"TRB", "TSB",
}
# Valid addressing mode patterns
ADDRESSING_PATTERNS = {
"immediate_8": r"#\$[0-9A-Fa-f]{1,2}", # #$XX
"immediate_16": r"#\$[0-9A-Fa-f]{3,4}", # #$XXXX
"immediate_symbol": r"#[A-Za-z_]\w*", # #SYMBOL
"direct_page": r"\$[0-9A-Fa-f]{1,2}(?!\w)", # $XX (not followed by more hex)
"absolute": r"\$[0-9A-Fa-f]{4}(?!\w)", # $XXXX
"long": r"\$[0-9A-Fa-f]{6}", # $XXXXXX
"indexed_x": r",\s*[Xx]", # ,X
"indexed_y": r",\s*[Yy]", # ,Y
"indirect": r"\([^)]+\)", # (...)
"stack_relative": r"\$[0-9A-Fa-f]{1,2},\s*[Ss]", # $XX,S
"accumulator": r"[Aa](?:\s|$)", # A
"label": r"[A-Za-z_]\w*", # Labels
}
# SNES-specific registers and addresses
SNES_REGISTERS = {
# PPU Registers
"INIDISP", "OBSEL", "OAMADDL", "OAMADDH", "OAMDATA",
"BGMODE", "MOSAIC", "BG1SC", "BG2SC", "BG3SC", "BG4SC",
"BG12NBA", "BG34NBA", "BG1HOFS", "BG1VOFS", "BG2HOFS", "BG2VOFS",
"BG3HOFS", "BG3VOFS", "BG4HOFS", "BG4VOFS",
"VMAIN", "VMADDL", "VMADDH", "VMDATAL", "VMDATAH",
"M7SEL", "M7A", "M7B", "M7C", "M7D", "M7X", "M7Y",
"CGADD", "CGDATA", "W12SEL", "W34SEL", "WOBJSEL",
"WH0", "WH1", "WH2", "WH3", "WBGLOG", "WOBJLOG",
"TM", "TS", "TMW", "TSW", "CGWSEL", "CGADSUB",
"COLDATA", "SETINI",
# APU Registers
"APUIO0", "APUIO1", "APUIO2", "APUIO3",
# DMA Registers
"MDMAEN", "HDMAEN", "MEMSEL",
# CPU Registers
"NMITIMEN", "WRIO", "WRMPYA", "WRMPYB", "WRDIVL", "WRDIVH",
"WRDIVB", "HTIMEL", "HTIMEH", "VTIMEL", "VTIMEH",
"RDNMI", "TIMEUP", "HVBJOY", "RDIO", "RDDIVL", "RDDIVH",
"RDMPYL", "RDMPYH", "JOY1L", "JOY1H", "JOY2L", "JOY2H",
"JOY3L", "JOY3H", "JOY4L", "JOY4H",
}
# Common ALTTP-specific labels
ALTTP_LABELS = {
"Module", "Submodule", "Link", "Player", "Sprite",
"WRAM", "SRAM", "VRAM", "OAM", "CGRAM",
}
VALID_DOMAINS = {"asm", "hack_curated"}
def __init__(self, strict: bool = False):
"""Initialize ASM validator.
Args:
strict: If True, apply stricter validation rules
"""
super().__init__("AsmValidator", "asm")
self.strict = strict
def can_validate(self, sample: TrainingSample) -> bool:
"""Allow ASM validation for curated hack samples too."""
return sample.domain in self.VALID_DOMAINS or sample.domain.startswith("asm")
async def validate(self, sample: TrainingSample) -> ValidationResult:
"""Validate 65816 assembly in the sample output."""
errors: list[str] = []
warnings: list[str] = []
details: dict = {
"instructions_found": 0,
"valid_instructions": 0,
"invalid_instructions": [],
"snes_registers_used": [],
"addressing_modes": [],
}
# Extract code from output
code = sample.output
# Parse instructions
instructions = self._extract_instructions(code)
details["instructions_found"] = len(instructions)
if len(instructions) == 0:
warnings.append("No assembly instructions found in output")
return ValidationResult(
valid=True,
score=0.5,
warnings=warnings,
details=details,
)
# Validate each instruction
for line_num, instr in instructions:
result = self._validate_instruction(instr)
if result.valid:
details["valid_instructions"] += 1
if result.addressing_mode:
details["addressing_modes"].append(result.addressing_mode)
else:
details["invalid_instructions"].append({
"line": line_num,
"instruction": instr,
"error": result.error,
})
if self.strict:
errors.append(f"Line {line_num}: {result.error}")
else:
warnings.append(f"Line {line_num}: {result.error}")
# Check for SNES registers
for reg in self.SNES_REGISTERS:
if reg in code:
details["snes_registers_used"].append(reg)
# Calculate score
if details["instructions_found"] > 0:
score = details["valid_instructions"] / details["instructions_found"]
else:
score = 0.5
# Boost score if SNES-specific content found
if details["snes_registers_used"]:
score = min(1.0, score + 0.1)
return ValidationResult(
valid=len(errors) == 0,
score=score,
errors=errors,
warnings=warnings,
details=details,
)
def _extract_instructions(self, code: str) -> list[tuple[int, str]]:
"""Extract assembly instructions from code.
Returns:
List of (line_number, instruction) tuples
"""
instructions = []
lines = code.split("\n")
for i, line in enumerate(lines, 1):
# Remove comments
if ";" in line:
line = line[:line.index(";")]
# Remove labels (lines ending with :)
if ":" in line:
# Check if it's a label definition
parts = line.split(":")
if len(parts) > 1:
line = parts[-1]
# Remove address prefixes like #_008000:
line = re.sub(r"#_[0-9A-Fa-f]+:\s*", "", line)
line = line.strip()
if not line:
continue
# Check if line starts with a valid mnemonic
parts = line.split()
if parts:
mnemonic = parts[0].upper()
if mnemonic in self.VALID_MNEMONICS:
instructions.append((i, line))
elif re.match(r"[A-Za-z]{2,4}", mnemonic):
# Might be an instruction-like thing
instructions.append((i, line))
return instructions
def _validate_instruction(self, instruction: str) -> "_InstructionValidation":
"""Validate a single instruction."""
parts = instruction.split(None, 1)
if not parts:
return _InstructionValidation(False, "Empty instruction")
mnemonic = parts[0].upper()
operand = parts[1] if len(parts) > 1 else ""
# Check mnemonic
if mnemonic not in self.VALID_MNEMONICS:
# Check if it's close to a valid mnemonic (typo detection)
close_matches = [m for m in self.VALID_MNEMONICS
if self._levenshtein_distance(mnemonic, m) <= 1]
if close_matches:
return _InstructionValidation(
False,
f"Unknown mnemonic '{mnemonic}' (did you mean {close_matches[0]}?)"
)
return _InstructionValidation(False, f"Unknown mnemonic '{mnemonic}'")
# Validate operand if present
addressing_mode = None
if operand:
addressing_mode = self._detect_addressing_mode(operand)
return _InstructionValidation(True, None, addressing_mode)
def _detect_addressing_mode(self, operand: str) -> Optional[str]:
"""Detect the addressing mode from the operand."""
operand = operand.strip()
# Check patterns in order of specificity
if re.match(r"#", operand):
if re.search(r"#\$[0-9A-Fa-f]{3,4}", operand):
return "immediate_16"
elif re.search(r"#\$[0-9A-Fa-f]{1,2}", operand):
return "immediate_8"
else:
return "immediate_symbol"
if re.search(r",\s*[Ss]", operand):
return "stack_relative"
if re.match(r"\([^)]+\)", operand):
if ",X" in operand.upper():
return "indexed_indirect_x"
elif ",Y" in operand.upper():
return "indirect_indexed_y"
else:
return "indirect"
if re.search(r",\s*[Xx]", operand):
return "indexed_x"
if re.search(r",\s*[Yy]", operand):
return "indexed_y"
if re.match(r"\$[0-9A-Fa-f]{6}", operand):
return "long"
if re.match(r"\$[0-9A-Fa-f]{4}", operand):
return "absolute"
if re.match(r"\$[0-9A-Fa-f]{1,2}(?!\w)", operand):
return "direct_page"
if re.match(r"[Aa]$", operand):
return "accumulator"
if re.match(r"[A-Za-z_]\w*", operand):
return "label"
return None
def _levenshtein_distance(self, s1: str, s2: str) -> int:
"""Calculate Levenshtein distance between two strings."""
if len(s1) < len(s2):
return self._levenshtein_distance(s2, s1)
if len(s2) == 0:
return len(s1)
previous_row = range(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
@dataclass
class _InstructionValidation:
"""Internal result of validating a single instruction."""
valid: bool
error: Optional[str] = None
addressing_mode: Optional[str] = None
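A quick look at AsmValidator's heuristics (not part of the diff): the first sample is clean 65816 code, the second contains a typo ("LDQ") that the mnemonic check flags with a did-you-mean suggestion in the warnings when strict is off.

import asyncio

from afs_scawful.training import TrainingSample
from afs_scawful.validators import AsmValidator

validator = AsmValidator(strict=False)

clean = TrainingSample(instruction="", input="",
                       output="REP #$30\nLDA #$1234\nSTA $7E0010\nRTS\n",
                       domain="asm")
typo = TrainingSample(instruction="", input="",
                      output="LDQ #$01\nSTA $7E0000\n",
                      domain="asm")

for label, sample in (("clean", clean), ("typo", typo)):
    result = asyncio.run(validator.validate(sample))
    print(label, result.valid, round(result.score, 2), result.warnings[:1])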

View File

@@ -0,0 +1,90 @@
"""Base validator interfaces for AFS Scawful."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from ..training import TrainingSample
@dataclass
class ValidationResult:
valid: bool
score: float
errors: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
details: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
return {
"valid": self.valid,
"score": self.score,
"errors": list(self.errors),
"warnings": list(self.warnings),
"details": dict(self.details),
}
class Validator(ABC):
def __init__(self, name: str, domain: str) -> None:
self.name = name
self.domain = domain
@abstractmethod
async def validate(self, sample: TrainingSample) -> ValidationResult:
raise NotImplementedError
def can_validate(self, sample: TrainingSample) -> bool:
return sample.domain == self.domain
async def validate_batch(self, samples: list[TrainingSample]) -> list[ValidationResult]:
results: list[ValidationResult] = []
for sample in samples:
if self.can_validate(sample):
results.append(await self.validate(sample))
else:
results.append(
ValidationResult(
valid=True,
score=1.0,
warnings=[f"{self.name} skipped: domain mismatch"],
)
)
return results
class CompositeValidator(Validator):
def __init__(self, validators: list[Validator]) -> None:
super().__init__("CompositeValidator", "all")
self.validators = validators
def can_validate(self, sample: TrainingSample) -> bool:
return any(validator.can_validate(sample) for validator in self.validators)
async def validate(self, sample: TrainingSample) -> ValidationResult:
applicable = [v for v in self.validators if v.can_validate(sample)]
if not applicable:
return ValidationResult(valid=True, score=1.0, warnings=["No applicable validators"])
errors: list[str] = []
warnings: list[str] = []
details: dict[str, Any] = {}
scores: list[float] = []
for validator in applicable:
result = await validator.validate(sample)
errors.extend(result.errors)
warnings.extend(result.warnings)
details[validator.name] = result.to_dict()
scores.append(result.score)
score = sum(scores) / len(scores) if scores else 1.0
return ValidationResult(
valid=len(errors) == 0,
score=score,
errors=errors,
warnings=warnings,
details=details,
)
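Sketch of composing validators (not part of the diff): CompositeValidator averages the scores of every validator whose can_validate() accepts the sample and merges their errors and warnings into one ValidationResult.

import asyncio

from afs_scawful.training import TrainingSample
from afs_scawful.validators import AsmValidator, CppValidator, CompositeValidator

composite = CompositeValidator([AsmValidator(), CppValidator()])

sample = TrainingSample(
    instruction="",
    input="",
    output="LDA #$01\nRTS\n",
    domain="asm",
)

result = asyncio.run(composite.validate(sample))
# Only AsmValidator applies here, so its score carries through unchanged;
# per-validator breakdowns are kept under result.details[<validator name>].
print(result.valid, result.score, list(result.details))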

View File

@@ -0,0 +1,340 @@
"""C++ Validator for training samples.
Validates:
- Basic syntax checks (brackets, braces, semicolons)
- Keyword usage
- Common patterns
- Optional: Compile check with clang (if available)
"""
from __future__ import annotations
import asyncio
import re
import shutil
import tempfile
from pathlib import Path
from typing import Optional
from ..training import TrainingSample
from .base import ValidationResult, Validator
class CppValidator(Validator):
"""Validator for C++ code in training samples."""
# C++ keywords
KEYWORDS = {
# Storage class
"auto", "register", "static", "extern", "mutable", "thread_local",
# Type specifiers
"void", "bool", "char", "short", "int", "long", "float", "double",
"signed", "unsigned", "wchar_t", "char8_t", "char16_t", "char32_t",
# Type qualifiers
"const", "volatile", "constexpr", "consteval", "constinit",
# Control flow
"if", "else", "switch", "case", "default", "while", "do", "for",
"break", "continue", "return", "goto",
# Declarations
"class", "struct", "union", "enum", "typedef", "using", "namespace",
"template", "typename", "concept", "requires",
# Access specifiers
"public", "private", "protected",
# Other keywords
"virtual", "override", "final", "explicit", "inline", "friend",
"operator", "sizeof", "alignof", "decltype", "typeid",
"new", "delete", "this", "nullptr", "true", "false",
"try", "catch", "throw", "noexcept",
"static_assert", "static_cast", "dynamic_cast", "const_cast", "reinterpret_cast",
"co_await", "co_return", "co_yield",
# Modules (C++20)
"module", "import", "export",
}
# Common C++ standard library types
STD_TYPES = {
"string", "vector", "map", "unordered_map", "set", "unordered_set",
"list", "deque", "array", "pair", "tuple", "optional", "variant",
"shared_ptr", "unique_ptr", "weak_ptr", "function", "any",
"thread", "mutex", "lock_guard", "unique_lock", "condition_variable",
"future", "promise", "async", "atomic",
"ifstream", "ofstream", "fstream", "stringstream", "ostringstream",
"iostream", "cin", "cout", "cerr", "endl",
"size_t", "ptrdiff_t", "nullptr_t", "byte",
"int8_t", "int16_t", "int32_t", "int64_t",
"uint8_t", "uint16_t", "uint32_t", "uint64_t",
}
def __init__(
self,
check_compile: bool = False,
compiler: str = "clang++",
strict: bool = False,
):
"""Initialize C++ validator.
Args:
check_compile: If True, attempt to compile the code
compiler: Compiler to use for compile checks
strict: If True, apply stricter validation
"""
super().__init__("CppValidator", "cpp")
self.check_compile = check_compile
self.compiler = compiler
self.strict = strict
# Check if compiler is available
self._compiler_available = shutil.which(compiler) is not None
async def validate(self, sample: TrainingSample) -> ValidationResult:
"""Validate C++ code in the sample output."""
errors: list[str] = []
warnings: list[str] = []
details: dict = {
"syntax_issues": [],
"keywords_found": [],
"std_types_found": [],
"bracket_balance": True,
"compile_checked": False,
"compile_result": None,
}
code = sample.output
# Basic syntax checks
syntax_result = self._check_syntax(code)
details["syntax_issues"] = syntax_result["issues"]
details["bracket_balance"] = syntax_result["balanced"]
if not syntax_result["balanced"]:
errors.append("Unbalanced brackets/braces/parentheses")
for issue in syntax_result["issues"]:
if self.strict:
errors.append(issue)
else:
warnings.append(issue)
# Check for keywords and types
details["keywords_found"] = self._find_keywords(code)
details["std_types_found"] = self._find_std_types(code)
# Compile check if enabled and available
if self.check_compile and self._compiler_available:
compile_result = await self._check_compile(code)
details["compile_checked"] = True
details["compile_result"] = compile_result
if not compile_result["success"]:
if self.strict:
errors.append(f"Compile error: {compile_result['error'][:200]}")
else:
warnings.append(f"Compile warning: {compile_result['error'][:100]}")
# Calculate score
score = 1.0
# Deduct for syntax issues
score -= len(details["syntax_issues"]) * 0.1
score = max(0.0, score)
# Deduct for bracket imbalance
if not details["bracket_balance"]:
score -= 0.3
# Bonus for using C++ features
if details["keywords_found"]:
score = min(1.0, score + 0.05)
if details["std_types_found"]:
score = min(1.0, score + 0.05)
# Deduct for compile failure
if details["compile_checked"] and not details["compile_result"]["success"]:
score -= 0.2
score = max(0.0, min(1.0, score))
return ValidationResult(
valid=len(errors) == 0,
score=score,
errors=errors,
warnings=warnings,
details=details,
)
def _check_syntax(self, code: str) -> dict:
"""Check basic C++ syntax."""
issues = []
balanced = True
# Check bracket balance
stack = []
pairs = {"(": ")", "[": "]", "{": "}"}
in_string = False
in_char = False
in_comment = False
in_block_comment = False
i = 0
while i < len(code):
c = code[i]
# Handle comments
if not in_string and not in_char:
if i < len(code) - 1:
two_char = code[i:i+2]
if two_char == "//":
# Skip to end of line
while i < len(code) and code[i] != "\n":
i += 1
continue
elif two_char == "/*":
in_block_comment = True
i += 2
continue
elif two_char == "*/" and in_block_comment:
in_block_comment = False
i += 2
continue
if in_block_comment:
i += 1
continue
# Handle strings
if c == '"' and not in_char and (i == 0 or code[i-1] != '\\'):
in_string = not in_string
elif c == "'" and not in_string and (i == 0 or code[i-1] != '\\'):
in_char = not in_char
if not in_string and not in_char:
if c in pairs:
stack.append(c)
elif c in pairs.values():
if not stack:
balanced = False
issues.append(f"Unexpected closing bracket '{c}'")
else:
expected = pairs[stack.pop()]
if c != expected:
balanced = False
issues.append(f"Mismatched brackets: expected '{expected}', got '{c}'")
i += 1
if stack:
balanced = False
issues.append(f"Unclosed brackets: {stack}")
# Check for common issues
# Missing semicolons after statements (heuristic)
lines = code.split("\n")
for i, line in enumerate(lines):
stripped = line.strip()
# Skip empty lines, comments, preprocessor
if not stripped or stripped.startswith("//") or stripped.startswith("#"):
continue
# Skip lines that end with block characters
if stripped.endswith("{") or stripped.endswith("}") or stripped.endswith(":"):
continue
# Skip lines that are likely continuations
if stripped.endswith(",") or stripped.endswith("\\"):
continue
# Check for statements that should end with semicolon
# This is a heuristic and may have false positives
statement_patterns = [
r"return\s+.+[^;]$", # return without semicolon
r"break$", # break without semicolon
r"continue$", # continue without semicolon
]
for pattern in statement_patterns:
if re.search(pattern, stripped):
issues.append(f"Line {i+1}: Possibly missing semicolon")
break
return {"issues": issues, "balanced": balanced}
def _find_keywords(self, code: str) -> list[str]:
"""Find C++ keywords in code."""
found = []
# Use word boundaries to find keywords
for keyword in self.KEYWORDS:
if re.search(rf"\b{keyword}\b", code):
found.append(keyword)
return found
def _find_std_types(self, code: str) -> list[str]:
"""Find standard library types in code."""
found = []
for type_name in self.STD_TYPES:
# Check for std::type or just type in common contexts
if re.search(rf"std::{type_name}\b", code) or re.search(rf"\b{type_name}<", code):
found.append(type_name)
return found
async def _check_compile(self, code: str) -> dict:
"""Attempt to compile the code."""
# Create temporary file
with tempfile.NamedTemporaryFile(
mode="w", suffix=".cpp", delete=False
) as f:
# Add minimal includes for standalone compilation
wrapped_code = """
#include <cstdint>
#include <string>
#include <vector>
#include <memory>
// Sample code below
""" + code
f.write(wrapped_code)
temp_path = Path(f.name)
try:
# Run compiler with syntax-only check
process = await asyncio.create_subprocess_exec(
self.compiler,
"-fsyntax-only",
"-std=c++17",
"-Wall",
str(temp_path),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await asyncio.wait_for(
process.communicate(), timeout=10.0
)
success = process.returncode == 0
error = stderr.decode("utf-8", errors="replace") if stderr else ""
return {
"success": success,
"error": error,
"returncode": process.returncode,
}
except asyncio.TimeoutError:
return {
"success": False,
"error": "Compilation timed out",
"returncode": -1,
}
except Exception as e:
return {
"success": False,
"error": str(e),
"returncode": -1,
}
finally:
# Clean up temp file
try:
temp_path.unlink()
except Exception:
pass
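Sketch of CppValidator usage (not part of the diff). check_compile=True only takes effect when the configured compiler (clang++ by default) is on PATH; otherwise only the bracket and keyword heuristics run.

import asyncio

from afs_scawful.training import TrainingSample
from afs_scawful.validators import CppValidator

validator = CppValidator(check_compile=False, strict=False)

sample = TrainingSample(
    instruction="",
    input="",
    output="#include <vector>\nint sum(const std::vector<int>& v) {\n  int t = 0;\n  for (int x : v) t += x;\n  return t;\n}\n",
    domain="cpp",
)

result = asyncio.run(validator.validate(sample))
print(result.valid, round(result.score, 2), result.details["std_types_found"])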

View File

@@ -0,0 +1,349 @@
"""Knowledge Graph Validator for training samples.
Validates:
- Entity presence in knowledge graph
- Relationship consistency
- Cross-reference validity
- Domain alignment
"""
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Any, Optional
from ..training import TrainingSample
from .base import ValidationResult, Validator
def _default_graph_path() -> Path:
candidate = Path.home() / "src" / "context" / "memory" / "knowledge_graph.json"
if candidate.exists():
return candidate
return Path.home() / ".context" / "memory" / "knowledge_graph.json"
class KGValidator(Validator):
"""Validator for knowledge graph consistency in training samples."""
def __init__(
self,
graph_path: Optional[Path] = None,
strict: bool = False,
min_entity_coverage: float = 0.3,
):
"""Initialize KG validator.
Args:
graph_path: Path to knowledge graph JSON. Defaults to ~/src/context/memory/knowledge_graph.json
(fallback: ~/.context/memory/knowledge_graph.json).
strict: If True, apply stricter validation (missing entities are errors)
min_entity_coverage: Minimum fraction of mentioned entities that must be in KG
"""
super().__init__("KGValidator", "all") # Applies to all domains
self.graph_path = graph_path or _default_graph_path()
self.strict = strict
self.min_entity_coverage = min_entity_coverage
# Lazy load graph
self._graph: Optional[dict] = None
self._nodes: dict[str, Any] = {}
self._edges: list[dict[str, Any]] = []
self._node_names: set[str] = set()
self._routines: set[str] = set()
self._symbols: set[str] = set()
def _load_graph(self) -> None:
"""Load knowledge graph from disk."""
if self._graph is not None:
return
if not self.graph_path.exists():
self._graph = {"nodes": {}, "edges": []}
return
try:
data = json.loads(self.graph_path.read_text())
self._graph = data
self._nodes = data.get("nodes", {})
self._edges = data.get("edges", [])
# Build lookup sets
for node_id, node_data in self._nodes.items():
self._node_names.add(node_id.lower())
# Extract name from node data
if isinstance(node_data, dict):
name = node_data.get("name", "")
if name:
self._node_names.add(name.lower())
# Track routines and symbols specifically
node_type = node_data.get("type", "")
if node_type == "routine":
self._routines.add(name.lower())
elif node_type == "symbol":
self._symbols.add(name.lower())
except Exception:
self._graph = {"nodes": {}, "edges": []}
def can_validate(self, sample: TrainingSample) -> bool:
"""KG validator can validate any sample with kg_entities."""
return True # Applies to all domains
async def validate(self, sample: TrainingSample) -> ValidationResult:
"""Validate knowledge graph consistency in the sample."""
self._load_graph()
errors: list[str] = []
warnings: list[str] = []
details: dict = {
"entities_mentioned": [],
"entities_found": [],
"entities_missing": [],
"routines_mentioned": [],
"symbols_mentioned": [],
"relationships_valid": True,
"coverage": 0.0,
}
# Extract entities from sample
text = f"{sample.instruction} {sample.input} {sample.output}"
mentioned = self._extract_entities(text, sample.domain)
details["entities_mentioned"] = mentioned
# Check which entities exist in KG
found = []
missing = []
for entity in mentioned:
entity_lower = entity.lower()
if self._entity_exists(entity_lower):
found.append(entity)
else:
missing.append(entity)
details["entities_found"] = found
details["entities_missing"] = missing
# Calculate coverage
if mentioned:
coverage = len(found) / len(mentioned)
else:
coverage = 1.0 # No entities to validate
details["coverage"] = coverage
# Check for routine/symbol references in ASM samples
if sample.domain.startswith("asm"):
routines = self._extract_routine_references(sample.output)
symbols = self._extract_symbol_references(sample.output)
details["routines_mentioned"] = routines
details["symbols_mentioned"] = symbols
# Check routine validity
for routine in routines:
if routine.lower() not in self._routines and routine.lower() not in self._node_names:
if self.strict:
errors.append(f"Unknown routine: {routine}")
else:
warnings.append(f"Routine not in KG: {routine}")
# Check kg_entities from sample metadata
if sample.kg_entities:
for entity in sample.kg_entities:
if not self._entity_exists(entity.lower()):
if self.strict:
errors.append(f"Tagged entity not in KG: {entity}")
else:
warnings.append(f"Tagged entity not in KG: {entity}")
# Validate coverage threshold
if coverage < self.min_entity_coverage and mentioned:
msg = f"Entity coverage {coverage:.1%} below threshold {self.min_entity_coverage:.1%}"
if self.strict:
errors.append(msg)
else:
warnings.append(msg)
# Calculate score: base 0.3 plus coverage, capped at 1.0
score = min(1.0, coverage + 0.3) # Coverage contributes up to 0.7
# Bonus for having KG entities tagged
if sample.kg_entities and sample.kg_validated:
score = min(1.0, score + 0.1)
# Penalty for missing entities
if missing:
penalty = len(missing) * 0.05
score = max(0.3, score - penalty)
return ValidationResult(
valid=len(errors) == 0,
score=score,
errors=errors,
warnings=warnings,
details=details,
)
def _entity_exists(self, entity: str) -> bool:
"""Check if an entity exists in the knowledge graph."""
entity_lower = entity.lower()
# Direct match
if entity_lower in self._node_names:
return True
# Check with common prefixes
prefixes = ["alttp:", "oracle-of-secrets:", "project:", "routine:", "symbol:"]
for prefix in prefixes:
if f"{prefix}{entity_lower}" in self._node_names:
return True
# Also check node IDs directly
for node_id in self._nodes:
if node_id.lower().endswith(f":{entity_lower}"):
return True
return False
def _extract_entities(self, text: str, domain: str) -> list[str]:
"""Extract potential entity references from text."""
entities = []
# Common patterns for entity references
patterns = [
# Code references like `EntityName` or `RoutineName`
r'`([A-Z][a-zA-Z0-9_]+)`',
# Capitalized terms that look like identifiers
r'\b([A-Z][a-z]+(?:[A-Z][a-z]+)+)\b', # CamelCase
# Routine names (common in ASM)
r'\b(Link_[A-Za-z0-9_]+)\b',
r'\b(Player_[A-Za-z0-9_]+)\b',
r'\b(Sprite_[A-Za-z0-9_]+)\b',
r'\b(Module_[A-Za-z0-9_]+)\b',
# Memory addresses with labels
r'\b([A-Z][A-Za-z0-9]+_[A-Z][A-Za-z0-9]+)\b',
]
for pattern in patterns:
matches = re.findall(pattern, text)
entities.extend(matches)
# Domain-specific extraction
if domain.startswith("asm"):
# Extract ASM-specific references
asm_patterns = [
r'\b([A-Z][a-z]+_[A-Z][a-z_0-9]+)\b', # Link_HandleSword
r'@([A-Za-z_][A-Za-z0-9_]+)', # @Labels
]
for pattern in asm_patterns:
matches = re.findall(pattern, text)
entities.extend(matches)
elif domain == "cpp":
# Extract C++ class/function names
cpp_patterns = [
r'\bclass\s+([A-Z][a-zA-Z0-9_]+)\b',
r'\b([A-Z][a-z]+(?:[A-Z][a-z]+)+)::\w+', # ClassName::method
]
for pattern in cpp_patterns:
matches = re.findall(pattern, text)
entities.extend(matches)
# Deduplicate while preserving order
seen = set()
unique = []
for e in entities:
if e.lower() not in seen:
seen.add(e.lower())
unique.append(e)
return unique
def _extract_routine_references(self, code: str) -> list[str]:
"""Extract routine/label references from ASM code."""
routines = []
# JSR/JSL targets
jsr_pattern = r'\b(?:JSR|JSL|JMP|JML)\s+([A-Za-z_][A-Za-z0-9_]+)\b'
matches = re.findall(jsr_pattern, code, re.IGNORECASE)
routines.extend(matches)
# BRA/BRL targets
branch_pattern = r'\b(?:BRA|BRL|BEQ|BNE|BCC|BCS|BMI|BPL)\s+([A-Za-z_][A-Za-z0-9_]+)\b'
matches = re.findall(branch_pattern, code, re.IGNORECASE)
routines.extend(matches)
return list(set(routines))
def _extract_symbol_references(self, code: str) -> list[str]:
"""Extract symbol/variable references from ASM code."""
symbols = []
# LDA/STA with labels
load_store_pattern = r'\b(?:LDA|LDX|LDY|STA|STX|STY)\s+([A-Za-z_][A-Za-z0-9_]+)\b'
matches = re.findall(load_store_pattern, code, re.IGNORECASE)
symbols.extend(matches)
# Filter out common non-symbol patterns
filtered = []
for sym in symbols:
# Skip if it looks like a routine name
if sym.lower() in self._routines:
continue
# Skip common mnemonics that might be captured
if sym.upper() in {'A', 'X', 'Y', 'S'}:
continue
filtered.append(sym)
return list(set(filtered))
def get_related_entities(self, entity: str) -> list[dict[str, Any]]:
"""Get entities related to a given entity in the KG."""
self._load_graph()
related = []
entity_lower = entity.lower()
for edge in self._edges:
source = str(edge.get("source", "")).lower()
target = str(edge.get("target", "")).lower()
relation = edge.get("relation", "")
if entity_lower in source:
related.append({
"entity": edge.get("target"),
"relation": relation,
"direction": "outgoing",
})
elif entity_lower in target:
related.append({
"entity": edge.get("source"),
"relation": relation,
"direction": "incoming",
})
return related
def suggest_entities(self, partial: str, limit: int = 10) -> list[str]:
"""Suggest entity names matching a partial string."""
self._load_graph()
partial_lower = partial.lower()
matches = []
for node_id in self._nodes:
if partial_lower in node_id.lower():
matches.append(node_id)
if len(matches) >= limit:
break
return matches
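Sketch of KGValidator against a tiny hand-written graph (not part of the diff). The graph path and node layout below are illustrative; the real default lives under ~/src/context/memory/knowledge_graph.json.

import asyncio
import json
import tempfile
from pathlib import Path

from afs_scawful.training import TrainingSample
from afs_scawful.validators import KGValidator

graph = {
    "nodes": {
        "routine:Link_HandleSword": {"name": "Link_HandleSword", "type": "routine"},
    },
    "edges": [],
}

with tempfile.TemporaryDirectory() as tmp:
    graph_path = Path(tmp) / "knowledge_graph.json"
    graph_path.write_text(json.dumps(graph))

    validator = KGValidator(graph_path=graph_path, strict=False)
    sample = TrainingSample(
        instruction="Describe Link_HandleSword.",
        input="",
        output="JSL Link_HandleSword\nRTS\n",
        domain="asm",
        kg_entities=["Link_HandleSword"],
    )

    result = asyncio.run(validator.validate(sample))
    # Link_HandleSword resolves both as a mentioned entity and as a tagged
    # kg_entity, so coverage is 1.0 and no warnings are emitted for it.
    print(result.valid, round(result.score, 2), result.details["entities_found"])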

tests/test_validators.py Normal file
View File

@@ -0,0 +1,32 @@
from __future__ import annotations
import asyncio
from afs_scawful.training import TrainingSample
from afs_scawful.validators import AsmValidator, CppValidator
def test_asm_validator_basic() -> None:
sample = TrainingSample(
instruction="",
input="",
output="LDA #$01\nSTA $7E0000\n",
domain="asm",
source="test",
)
result = asyncio.run(AsmValidator().validate(sample))
assert result.valid
assert result.score > 0.0
def test_cpp_validator_basic() -> None:
sample = TrainingSample(
instruction="",
input="",
output="int main() { return 0; }\n",
domain="cpp",
source="test",
)
result = asyncio.run(CppValidator(check_compile=False).validate(sample))
assert result.valid
assert result.score > 0.0