core: add training model and validators
@@ -3,12 +3,16 @@
 from __future__ import annotations
 
 import argparse
+import asyncio
+import json
 from pathlib import Path
 from typing import Iterable
 
-from .registry import index_datasets, build_dataset_registry, write_dataset_registry
+from .registry import build_dataset_registry, index_datasets, write_dataset_registry
 from .resource_index import ResourceIndexer
 from .paths import resolve_datasets_root, resolve_index_root
+from .training import TrainingSample
+from .validators import default_validators
 
 
 def _datasets_index_command(args: argparse.Namespace) -> int:
@@ -43,6 +47,45 @@ def _resources_index_command(args: argparse.Namespace) -> int:
     return 0
 
 
+async def _run_validators(sample: TrainingSample, validators) -> list[tuple[str, object]]:
+    results: list[tuple[str, object]] = []
+    for validator in validators:
+        if validator.can_validate(sample):
+            result = await validator.validate(sample)
+            results.append((validator.name, result))
+    return results
+
+
+def _validators_list_command(args: argparse.Namespace) -> int:
+    validators = default_validators()
+    for validator in validators:
+        print(f"{validator.name}\t{validator.domain}")
+    return 0
+
+
+def _validators_run_command(args: argparse.Namespace) -> int:
+    sample_path = Path(args.sample).expanduser().resolve()
+    payload = json.loads(sample_path.read_text(encoding="utf-8"))
+    sample = TrainingSample.from_dict(payload)
+
+    validators = default_validators()
+    if args.name:
+        validators = [v for v in validators if v.name in args.name]
+
+    results = asyncio.run(_run_validators(sample, validators))
+    if not results:
+        print("(no validators)")
+        return 1
+
+    overall_ok = True
+    for name, result in results:
+        status = "ok" if result.valid else "fail"
+        if not result.valid:
+            overall_ok = False
+        print(f"{name}\t{status}\t{result.score:.2f}")
+    return 0 if overall_ok else 1
+
+
 def build_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(prog="afs_scawful")
     subparsers = parser.add_subparsers(dest="command")
@@ -77,6 +120,21 @@ def build_parser() -> argparse.ArgumentParser:
     resources_index.add_argument("--output", help="Output index path.")
     resources_index.set_defaults(func=_resources_index_command)
 
+    validators_parser = subparsers.add_parser("validators", help="Validation tools.")
+    validators_sub = validators_parser.add_subparsers(dest="validators_command")
+
+    validators_list = validators_sub.add_parser("list", help="List validators.")
+    validators_list.set_defaults(func=_validators_list_command)
+
+    validators_run = validators_sub.add_parser("run", help="Validate a sample JSON.")
+    validators_run.add_argument("sample", help="Path to sample JSON.")
+    validators_run.add_argument(
+        "--name",
+        action="append",
+        help="Validator name to run (repeatable).",
+    )
+    validators_run.set_defaults(func=_validators_run_command)
+
     return parser
@@ -92,6 +150,9 @@ def main(argv: Iterable[str] | None = None) -> int:
     if args.command == "resources" and not getattr(args, "resources_command", None):
         parser.print_help()
         return 1
+    if args.command == "validators" and not getattr(args, "validators_command", None):
+        parser.print_help()
+        return 1
     return args.func(args)
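The new subcommands can also be driven programmatically through main() (a minimal sketch; it assumes the CLI module patched above is importable as afs_scawful.cli, and sample.json is a hypothetical path):

    from afs_scawful.cli import main  # assumed module path for the file patched above

    # Equivalent to: afs_scawful validators list
    main(["validators", "list"])  # prints one "name<TAB>domain" row per default validator

    # Equivalent to: afs_scawful validators run sample.json --name AsmValidator
    # Returns 0 only when every selected validator reports valid, 1 otherwise.
    exit_code = main(["validators", "run", "sample.json", "--name", "AsmValidator"])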
src/afs_scawful/training.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""Training sample data models for AFS Scawful."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingSample:
|
||||
instruction: str
|
||||
input: str
|
||||
output: str
|
||||
domain: str
|
||||
source: str = ""
|
||||
sample_id: str = ""
|
||||
timestamp: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
kg_entities: list[str] = field(default_factory=list)
|
||||
kg_validated: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.sample_id:
|
||||
self.sample_id = str(uuid.uuid4())
|
||||
if not self.timestamp:
|
||||
self.timestamp = datetime.now().isoformat()
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"instruction": self.instruction,
|
||||
"input": self.input,
|
||||
"output": self.output,
|
||||
"domain": self.domain,
|
||||
"source": self.source,
|
||||
"sample_id": self.sample_id,
|
||||
"timestamp": self.timestamp,
|
||||
"metadata": self.metadata,
|
||||
"kg_entities": self.kg_entities,
|
||||
"kg_validated": self.kg_validated,
|
||||
}
|
||||
|
||||
def to_jsonl_entry(self) -> str:
|
||||
payload = {
|
||||
"instruction": self.instruction,
|
||||
"output": self.output,
|
||||
}
|
||||
if self.input:
|
||||
payload["input"] = self.input
|
||||
payload["_metadata"] = {
|
||||
"sample_id": self.sample_id,
|
||||
"domain": self.domain,
|
||||
"source": self.source,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
return json.dumps(payload)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "TrainingSample":
|
||||
return cls(
|
||||
instruction=data.get("instruction", ""),
|
||||
input=data.get("input", ""),
|
||||
output=data.get("output", ""),
|
||||
domain=data.get("domain", ""),
|
||||
source=data.get("source", ""),
|
||||
sample_id=data.get("sample_id", ""),
|
||||
timestamp=data.get("timestamp", ""),
|
||||
metadata=data.get("metadata", {}) or {},
|
||||
kg_entities=data.get("kg_entities", []) or [],
|
||||
kg_validated=bool(data.get("kg_validated", False)),
|
||||
)
|
||||
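For illustration, a sample round-trips through these helpers like so (a minimal sketch; the field values are made up):

    from afs_scawful.training import TrainingSample

    sample = TrainingSample.from_dict({
        "instruction": "Explain what LDA #$01 does.",
        "output": "It loads the immediate value $01 into the accumulator.",
        "domain": "asm",
    })
    print(sample.sample_id)         # UUID filled in by __post_init__
    print(sample.to_jsonl_entry())  # instruction/output plus an _metadata block; empty input is omitted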
src/afs_scawful/validators/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
"""Validator registry for AFS Scawful."""
|
||||
|
||||
from .asar_validator import AsarValidator
|
||||
from .asm_validator import AsmValidator
|
||||
from .base import CompositeValidator, ValidationResult, Validator
|
||||
from .cpp_validator import CppValidator
|
||||
from .kg_validator import KGValidator
|
||||
|
||||
__all__ = [
|
||||
"AsarValidator",
|
||||
"AsmValidator",
|
||||
"CppValidator",
|
||||
"CompositeValidator",
|
||||
"KGValidator",
|
||||
"ValidationResult",
|
||||
"Validator",
|
||||
"default_validators",
|
||||
]
|
||||
|
||||
|
||||
def default_validators() -> list[Validator]:
|
||||
return [
|
||||
AsmValidator(),
|
||||
AsarValidator(),
|
||||
CppValidator(),
|
||||
KGValidator(),
|
||||
]
|
||||
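A minimal sketch of how this registry is meant to be consumed, pairing can_validate with validate (validate is async, hence asyncio.run):

    import asyncio

    from afs_scawful.training import TrainingSample
    from afs_scawful.validators import default_validators

    sample = TrainingSample(instruction="", input="", output="LDA #$01", domain="asm")
    for validator in default_validators():
        if validator.can_validate(sample):  # AsmValidator, AsarValidator, and KGValidator match "asm"
            result = asyncio.run(validator.validate(sample))
            print(validator.name, result.valid, round(result.score, 2))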
src/afs_scawful/validators/asar_validator.py (new file, 127 lines)
@@ -0,0 +1,127 @@
"""Asar Validator for verifying 65816 assembly code.
|
||||
|
||||
Uses the actual 'asar' binary to assemble code snippets against a dummy ROM.
|
||||
This provides 100% accurate syntax and label validation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from ..training import TrainingSample
|
||||
from .base import ValidationResult, Validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _resolve_env_path(env_var: str) -> Path | None:
|
||||
value = os.environ.get(env_var)
|
||||
if not value:
|
||||
return None
|
||||
return Path(value).expanduser().resolve()
|
||||
|
||||
|
||||
def _default_asar_path() -> Path:
|
||||
env = _resolve_env_path("AFS_ASAR_PATH")
|
||||
if env:
|
||||
return env
|
||||
found = shutil.which("asar")
|
||||
if found:
|
||||
return Path(found)
|
||||
return Path("asar")
|
||||
|
||||
|
||||
def _default_rom_path() -> Path:
|
||||
env = _resolve_env_path("AFS_ASAR_ROM")
|
||||
if env:
|
||||
return env
|
||||
candidate = Path.home() / "src" / "training" / "roms" / "dummy.sfc"
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return Path.home() / ".context" / "training" / "dummy.sfc"
|
||||
|
||||
|
||||
class AsarValidator(Validator):
|
||||
"""Validates assembly code by running it through Asar."""
|
||||
|
||||
def __init__(self, asar_path: Path | None = None, rom_path: Path | None = None):
|
||||
super().__init__("AsarValidator", "asm")
|
||||
self.asar_path = asar_path or _default_asar_path()
|
||||
self.rom_path = rom_path or _default_rom_path()
|
||||
|
||||
if not self.asar_path.exists():
|
||||
logger.warning("Asar binary not found at %s", self.asar_path)
|
||||
if not self.rom_path.exists():
|
||||
logger.warning("Dummy ROM not found at %s", self.rom_path)
|
||||
|
||||
async def validate(self, sample: TrainingSample) -> ValidationResult:
|
||||
"""Run asar on the sample output code."""
|
||||
if not self.asar_path.exists() or not self.rom_path.exists():
|
||||
return ValidationResult(
|
||||
valid=True,
|
||||
score=0.5,
|
||||
warnings=["Asar validator skipped: binary or ROM missing"],
|
||||
)
|
||||
|
||||
# Extract code (simple heuristic: look for code blocks or use full output)
|
||||
code = self._extract_code(sample.output)
|
||||
if not code:
|
||||
return ValidationResult(valid=False, score=0.0, errors=["No code found"])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmp_path = Path(tmpdir)
|
||||
source_file = tmp_path / "test.asm"
|
||||
rom_file = tmp_path / "test.sfc"
|
||||
|
||||
# Copy dummy ROM to temp (to avoid modifying the original)
|
||||
shutil.copy(self.rom_path, rom_file)
|
||||
|
||||
# Wrap code in a safe patch structure
|
||||
# We assume the code is a snippet, so we hook it into free space
|
||||
wrapped_code = (
|
||||
"lorom\n"
|
||||
"org $008000\n" # Hook into start of ROM
|
||||
f"{code}\n"
|
||||
)
|
||||
|
||||
source_file.write_text(wrapped_code)
|
||||
|
||||
# Run asar
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
str(self.asar_path),
|
||||
str(source_file),
|
||||
str(rom_file),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
if proc.returncode == 0:
|
||||
return ValidationResult(valid=True, score=1.0)
|
||||
else:
|
||||
error_msg = stderr.decode() + stdout.decode()
|
||||
# Clean up error message
|
||||
lines = [l for l in error_msg.split('\n') if "error:" in l.lower()]
|
||||
return ValidationResult(
|
||||
valid=False,
|
||||
score=0.0,
|
||||
errors=lines[:3] or ["Asar failed to assemble"],
|
||||
)
|
||||
|
||||
def _extract_code(self, text: str) -> str:
|
||||
"""Extract ASM code from markdown block or raw text."""
|
||||
if "```asm" in text:
|
||||
parts = text.split("```asm")
|
||||
if len(parts) > 1:
|
||||
return parts[1].split("```")[0].strip()
|
||||
if "```" in text:
|
||||
parts = text.split("```")
|
||||
if len(parts) > 1:
|
||||
return parts[1].strip()
|
||||
return text # Assume raw code if no blocks
|
||||
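Explicit constructor arguments take precedence over the AFS_ASAR_PATH / AFS_ASAR_ROM environment variables and the built-in fallbacks, so a pinned setup looks like this (a sketch; both paths are hypothetical):

    from pathlib import Path

    from afs_scawful.validators import AsarValidator

    validator = AsarValidator(
        asar_path=Path("~/tools/asar/asar").expanduser(),  # hypothetical install location
        rom_path=Path("~/roms/dummy.sfc").expanduser(),    # any throwaway ROM image works
    )

Note that when either file is missing, validate() degrades gracefully: it returns valid=True with score 0.5 and a skip warning rather than failing the sample.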
src/afs_scawful/validators/asm_validator.py (new file, 342 lines)
@@ -0,0 +1,342 @@
"""ASM Validator for 65816 assembly training samples.
|
||||
|
||||
Validates:
|
||||
- Instruction mnemonics
|
||||
- Addressing modes
|
||||
- Register usage
|
||||
- Memory addressing patterns
|
||||
- SNES-specific constructs
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from ..training import TrainingSample
|
||||
from .base import ValidationResult, Validator
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstructionInfo:
|
||||
"""Information about a 65816 instruction."""
|
||||
|
||||
mnemonic: str
|
||||
addressing_modes: list[str]
|
||||
description: str
|
||||
|
||||
|
||||
class AsmValidator(Validator):
|
||||
"""Validator for 65816 assembly code in training samples."""
|
||||
|
||||
# Valid 65816 instruction mnemonics
|
||||
VALID_MNEMONICS = {
|
||||
# Load/Store
|
||||
"LDA", "LDX", "LDY", "STA", "STX", "STY", "STZ",
|
||||
# Transfer
|
||||
"TAX", "TAY", "TXA", "TYA", "TXS", "TSX", "TCD", "TDC", "TCS", "TSC", "TXY", "TYX",
|
||||
# Stack
|
||||
"PHA", "PHP", "PHX", "PHY", "PHB", "PHD", "PHK",
|
||||
"PLA", "PLP", "PLX", "PLY", "PLB", "PLD",
|
||||
"PEA", "PEI", "PER",
|
||||
# Arithmetic
|
||||
"ADC", "SBC", "INC", "INX", "INY", "DEC", "DEX", "DEY",
|
||||
# Comparison
|
||||
"CMP", "CPX", "CPY",
|
||||
# Logical
|
||||
"AND", "ORA", "EOR", "BIT",
|
||||
# Shift/Rotate
|
||||
"ASL", "LSR", "ROL", "ROR",
|
||||
# Branch
|
||||
"BCC", "BCS", "BEQ", "BMI", "BNE", "BPL", "BVC", "BVS", "BRA", "BRL",
|
||||
# Jump
|
||||
"JMP", "JML", "JSR", "JSL", "RTS", "RTL", "RTI",
|
||||
# Flags
|
||||
"CLC", "CLD", "CLI", "CLV", "SEC", "SED", "SEI",
|
||||
"REP", "SEP",
|
||||
# Processor
|
||||
"NOP", "WDM", "STP", "WAI", "XBA", "XCE",
|
||||
# Block Move
|
||||
"MVP", "MVN",
|
||||
# Misc
|
||||
"BRK", "COP", "WDM",
|
||||
# 65C816 specific
|
||||
"TRB", "TSB",
|
||||
}
|
||||
|
||||
# Valid addressing mode patterns
|
||||
ADDRESSING_PATTERNS = {
|
||||
"immediate_8": r"#\$[0-9A-Fa-f]{1,2}", # #$XX
|
||||
"immediate_16": r"#\$[0-9A-Fa-f]{3,4}", # #$XXXX
|
||||
"immediate_symbol": r"#[A-Za-z_]\w*", # #SYMBOL
|
||||
"direct_page": r"\$[0-9A-Fa-f]{1,2}(?!\w)", # $XX (not followed by more hex)
|
||||
"absolute": r"\$[0-9A-Fa-f]{4}(?!\w)", # $XXXX
|
||||
"long": r"\$[0-9A-Fa-f]{6}", # $XXXXXX
|
||||
"indexed_x": r",\s*[Xx]", # ,X
|
||||
"indexed_y": r",\s*[Yy]", # ,Y
|
||||
"indirect": r"\([^)]+\)", # (...)
|
||||
"stack_relative": r"\$[0-9A-Fa-f]{1,2},\s*[Ss]", # $XX,S
|
||||
"accumulator": r"[Aa](?:\s|$)", # A
|
||||
"label": r"[A-Za-z_]\w*", # Labels
|
||||
}
|
||||
|
||||
# SNES-specific registers and addresses
|
||||
SNES_REGISTERS = {
|
||||
# PPU Registers
|
||||
"INIDISP", "OBSEL", "OAMADDL", "OAMADDH", "OAMDATA",
|
||||
"BGMODE", "MOSAIC", "BG1SC", "BG2SC", "BG3SC", "BG4SC",
|
||||
"BG12NBA", "BG34NBA", "BG1HOFS", "BG1VOFS", "BG2HOFS", "BG2VOFS",
|
||||
"BG3HOFS", "BG3VOFS", "BG4HOFS", "BG4VOFS",
|
||||
"VMAIN", "VMADDL", "VMADDH", "VMDATAL", "VMDATAH",
|
||||
"M7SEL", "M7A", "M7B", "M7C", "M7D", "M7X", "M7Y",
|
||||
"CGADD", "CGDATA", "W12SEL", "W34SEL", "WOBJSEL",
|
||||
"WH0", "WH1", "WH2", "WH3", "WBGLOG", "WOBJLOG",
|
||||
"TM", "TS", "TMW", "TSW", "CGWSEL", "CGADSUB",
|
||||
"COLDATA", "SETINI",
|
||||
# APU Registers
|
||||
"APUIO0", "APUIO1", "APUIO2", "APUIO3",
|
||||
# DMA Registers
|
||||
"MDMAEN", "HDMAEN", "MEMSEL",
|
||||
# CPU Registers
|
||||
"NMITIMEN", "WRIO", "WRMPYA", "WRMPYB", "WRDIVL", "WRDIVH",
|
||||
"WRDIVB", "HTIMEL", "HTIMEH", "VTIMEL", "VTIMEH",
|
||||
"RDNMI", "TIMEUP", "HVBJOY", "RDIO", "RDDIVL", "RDDIVH",
|
||||
"RDMPYL", "RDMPYH", "JOY1L", "JOY1H", "JOY2L", "JOY2H",
|
||||
"JOY3L", "JOY3H", "JOY4L", "JOY4H",
|
||||
}
|
||||
|
||||
# Common ALTTP-specific labels
|
||||
ALTTP_LABELS = {
|
||||
"Module", "Submodule", "Link", "Player", "Sprite",
|
||||
"WRAM", "SRAM", "VRAM", "OAM", "CGRAM",
|
||||
}
|
||||
|
||||
VALID_DOMAINS = {"asm", "hack_curated"}
|
||||
|
||||
def __init__(self, strict: bool = False):
|
||||
"""Initialize ASM validator.
|
||||
|
||||
Args:
|
||||
strict: If True, apply stricter validation rules
|
||||
"""
|
||||
super().__init__("AsmValidator", "asm")
|
||||
self.strict = strict
|
||||
|
||||
def can_validate(self, sample: TrainingSample) -> bool:
|
||||
"""Allow ASM validation for curated hack samples too."""
|
||||
return sample.domain in self.VALID_DOMAINS or sample.domain.startswith("asm")
|
||||
|
||||
async def validate(self, sample: TrainingSample) -> ValidationResult:
|
||||
"""Validate 65816 assembly in the sample output."""
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
details: dict = {
|
||||
"instructions_found": 0,
|
||||
"valid_instructions": 0,
|
||||
"invalid_instructions": [],
|
||||
"snes_registers_used": [],
|
||||
"addressing_modes": [],
|
||||
}
|
||||
|
||||
# Extract code from output
|
||||
code = sample.output
|
||||
|
||||
# Parse instructions
|
||||
instructions = self._extract_instructions(code)
|
||||
details["instructions_found"] = len(instructions)
|
||||
|
||||
if len(instructions) == 0:
|
||||
warnings.append("No assembly instructions found in output")
|
||||
return ValidationResult(
|
||||
valid=True,
|
||||
score=0.5,
|
||||
warnings=warnings,
|
||||
details=details,
|
||||
)
|
||||
|
||||
# Validate each instruction
|
||||
for line_num, instr in instructions:
|
||||
result = self._validate_instruction(instr)
|
||||
if result.valid:
|
||||
details["valid_instructions"] += 1
|
||||
if result.addressing_mode:
|
||||
details["addressing_modes"].append(result.addressing_mode)
|
||||
else:
|
||||
details["invalid_instructions"].append({
|
||||
"line": line_num,
|
||||
"instruction": instr,
|
||||
"error": result.error,
|
||||
})
|
||||
if self.strict:
|
||||
errors.append(f"Line {line_num}: {result.error}")
|
||||
else:
|
||||
warnings.append(f"Line {line_num}: {result.error}")
|
||||
|
||||
# Check for SNES registers
|
||||
for reg in self.SNES_REGISTERS:
|
||||
if reg in code:
|
||||
details["snes_registers_used"].append(reg)
|
||||
|
||||
# Calculate score
|
||||
if details["instructions_found"] > 0:
|
||||
score = details["valid_instructions"] / details["instructions_found"]
|
||||
else:
|
||||
score = 0.5
|
||||
|
||||
# Boost score if SNES-specific content found
|
||||
if details["snes_registers_used"]:
|
||||
score = min(1.0, score + 0.1)
|
||||
|
||||
return ValidationResult(
|
||||
valid=len(errors) == 0,
|
||||
score=score,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
details=details,
|
||||
)
|
||||
|
||||
def _extract_instructions(self, code: str) -> list[tuple[int, str]]:
|
||||
"""Extract assembly instructions from code.
|
||||
|
||||
Returns:
|
||||
List of (line_number, instruction) tuples
|
||||
"""
|
||||
instructions = []
|
||||
lines = code.split("\n")
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
# Remove comments
|
||||
if ";" in line:
|
||||
line = line[:line.index(";")]
|
||||
|
||||
# Remove labels (lines ending with :)
|
||||
if ":" in line:
|
||||
# Check if it's a label definition
|
||||
parts = line.split(":")
|
||||
if len(parts) > 1:
|
||||
line = parts[-1]
|
||||
|
||||
# Remove address prefixes like #_008000:
|
||||
line = re.sub(r"#_[0-9A-Fa-f]+:\s*", "", line)
|
||||
|
||||
line = line.strip()
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Check if line starts with a valid mnemonic
|
||||
parts = line.split()
|
||||
if parts:
|
||||
mnemonic = parts[0].upper()
|
||||
if mnemonic in self.VALID_MNEMONICS:
|
||||
instructions.append((i, line))
|
||||
elif re.match(r"[A-Za-z]{2,4}", mnemonic):
|
||||
# Might be an instruction-like thing
|
||||
instructions.append((i, line))
|
||||
|
||||
return instructions
|
||||
|
||||
def _validate_instruction(self, instruction: str) -> "_InstructionValidation":
|
||||
"""Validate a single instruction."""
|
||||
parts = instruction.split(None, 1)
|
||||
if not parts:
|
||||
return _InstructionValidation(False, "Empty instruction")
|
||||
|
||||
mnemonic = parts[0].upper()
|
||||
operand = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
# Check mnemonic
|
||||
if mnemonic not in self.VALID_MNEMONICS:
|
||||
# Check if it's close to a valid mnemonic (typo detection)
|
||||
close_matches = [m for m in self.VALID_MNEMONICS
|
||||
if self._levenshtein_distance(mnemonic, m) <= 1]
|
||||
if close_matches:
|
||||
return _InstructionValidation(
|
||||
False,
|
||||
f"Unknown mnemonic '{mnemonic}' (did you mean {close_matches[0]}?)"
|
||||
)
|
||||
return _InstructionValidation(False, f"Unknown mnemonic '{mnemonic}'")
|
||||
|
||||
# Validate operand if present
|
||||
addressing_mode = None
|
||||
if operand:
|
||||
addressing_mode = self._detect_addressing_mode(operand)
|
||||
|
||||
return _InstructionValidation(True, None, addressing_mode)
|
||||
|
||||
def _detect_addressing_mode(self, operand: str) -> Optional[str]:
|
||||
"""Detect the addressing mode from the operand."""
|
||||
operand = operand.strip()
|
||||
|
||||
# Check patterns in order of specificity
|
||||
if re.match(r"#", operand):
|
||||
if re.search(r"#\$[0-9A-Fa-f]{3,4}", operand):
|
||||
return "immediate_16"
|
||||
elif re.search(r"#\$[0-9A-Fa-f]{1,2}", operand):
|
||||
return "immediate_8"
|
||||
else:
|
||||
return "immediate_symbol"
|
||||
|
||||
if re.search(r",\s*[Ss]", operand):
|
||||
return "stack_relative"
|
||||
|
||||
if re.match(r"\([^)]+\)", operand):
|
||||
if ",X" in operand.upper():
|
||||
return "indexed_indirect_x"
|
||||
elif ",Y" in operand.upper():
|
||||
return "indirect_indexed_y"
|
||||
else:
|
||||
return "indirect"
|
||||
|
||||
if re.search(r",\s*[Xx]", operand):
|
||||
return "indexed_x"
|
||||
|
||||
if re.search(r",\s*[Yy]", operand):
|
||||
return "indexed_y"
|
||||
|
||||
if re.match(r"\$[0-9A-Fa-f]{6}", operand):
|
||||
return "long"
|
||||
|
||||
if re.match(r"\$[0-9A-Fa-f]{4}", operand):
|
||||
return "absolute"
|
||||
|
||||
if re.match(r"\$[0-9A-Fa-f]{1,2}(?!\w)", operand):
|
||||
return "direct_page"
|
||||
|
||||
if re.match(r"[Aa]$", operand):
|
||||
return "accumulator"
|
||||
|
||||
if re.match(r"[A-Za-z_]\w*", operand):
|
||||
return "label"
|
||||
|
||||
return None
|
||||
|
||||
def _levenshtein_distance(self, s1: str, s2: str) -> int:
|
||||
"""Calculate Levenshtein distance between two strings."""
|
||||
if len(s1) < len(s2):
|
||||
return self._levenshtein_distance(s2, s1)
|
||||
|
||||
if len(s2) == 0:
|
||||
return len(s1)
|
||||
|
||||
previous_row = range(len(s2) + 1)
|
||||
for i, c1 in enumerate(s1):
|
||||
current_row = [i + 1]
|
||||
for j, c2 in enumerate(s2):
|
||||
insertions = previous_row[j + 1] + 1
|
||||
deletions = current_row[j] + 1
|
||||
substitutions = previous_row[j] + (c1 != c2)
|
||||
current_row.append(min(insertions, deletions, substitutions))
|
||||
previous_row = current_row
|
||||
|
||||
return previous_row[-1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _InstructionValidation:
|
||||
"""Internal result of validating a single instruction."""
|
||||
|
||||
valid: bool
|
||||
error: Optional[str] = None
|
||||
addressing_mode: Optional[str] = None
|
||||
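Run standalone, the scoring reduces to the fraction of recognized instructions, nudged up when SNES registers appear in the code; strict=True turns unrecognized mnemonics into errors instead of warnings. A minimal sketch:

    import asyncio

    from afs_scawful.training import TrainingSample
    from afs_scawful.validators import AsmValidator

    sample = TrainingSample(
        instruction="",
        input="",
        output="REP #$20\nLDA #$1234\nSTA VMADDL\n",  # VMADDL triggers the SNES-register bonus
        domain="asm",
    )
    result = asyncio.run(AsmValidator(strict=True).validate(sample))
    print(result.valid, result.score)  # True 1.0: all three mnemonics are recognized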
src/afs_scawful/validators/base.py (new file, 90 lines)
@@ -0,0 +1,90 @@
"""Base validator interfaces for AFS Scawful."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..training import TrainingSample
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
valid: bool
|
||||
score: float
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
details: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"valid": self.valid,
|
||||
"score": self.score,
|
||||
"errors": list(self.errors),
|
||||
"warnings": list(self.warnings),
|
||||
"details": dict(self.details),
|
||||
}
|
||||
|
||||
|
||||
class Validator(ABC):
|
||||
def __init__(self, name: str, domain: str) -> None:
|
||||
self.name = name
|
||||
self.domain = domain
|
||||
|
||||
@abstractmethod
|
||||
async def validate(self, sample: TrainingSample) -> ValidationResult:
|
||||
raise NotImplementedError
|
||||
|
||||
def can_validate(self, sample: TrainingSample) -> bool:
|
||||
return sample.domain == self.domain
|
||||
|
||||
async def validate_batch(self, samples: list[TrainingSample]) -> list[ValidationResult]:
|
||||
results: list[ValidationResult] = []
|
||||
for sample in samples:
|
||||
if self.can_validate(sample):
|
||||
results.append(await self.validate(sample))
|
||||
else:
|
||||
results.append(
|
||||
ValidationResult(
|
||||
valid=True,
|
||||
score=1.0,
|
||||
warnings=[f"{self.name} skipped: domain mismatch"],
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
class CompositeValidator(Validator):
|
||||
def __init__(self, validators: list[Validator]) -> None:
|
||||
super().__init__("CompositeValidator", "all")
|
||||
self.validators = validators
|
||||
|
||||
def can_validate(self, sample: TrainingSample) -> bool:
|
||||
return any(validator.can_validate(sample) for validator in self.validators)
|
||||
|
||||
async def validate(self, sample: TrainingSample) -> ValidationResult:
|
||||
applicable = [v for v in self.validators if v.can_validate(sample)]
|
||||
if not applicable:
|
||||
return ValidationResult(valid=True, score=1.0, warnings=["No applicable validators"])
|
||||
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
details: dict[str, Any] = {}
|
||||
scores: list[float] = []
|
||||
|
||||
for validator in applicable:
|
||||
result = await validator.validate(sample)
|
||||
errors.extend(result.errors)
|
||||
warnings.extend(result.warnings)
|
||||
details[validator.name] = result.to_dict()
|
||||
scores.append(result.score)
|
||||
|
||||
score = sum(scores) / len(scores) if scores else 1.0
|
||||
return ValidationResult(
|
||||
valid=len(errors) == 0,
|
||||
score=score,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
details=details,
|
||||
)
|
||||
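A short sketch of the composite in action: it averages the scores of the applicable validators and merges their errors, warnings, and per-validator details:

    import asyncio

    from afs_scawful.training import TrainingSample
    from afs_scawful.validators import CompositeValidator, default_validators

    composite = CompositeValidator(default_validators())
    sample = TrainingSample(instruction="", input="", output="int main() { return 0; }", domain="cpp")
    result = asyncio.run(composite.validate(sample))
    print(result.valid, result.score)
    print(list(result.details))  # details keyed by validator name, here ["CppValidator", "KGValidator"]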
src/afs_scawful/validators/cpp_validator.py (new file, 340 lines)
@@ -0,0 +1,340 @@
"""C++ Validator for training samples.
|
||||
|
||||
Validates:
|
||||
- Basic syntax checks (brackets, braces, semicolons)
|
||||
- Keyword usage
|
||||
- Common patterns
|
||||
- Optional: Compile check with clang (if available)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from ..training import TrainingSample
|
||||
from .base import ValidationResult, Validator
|
||||
|
||||
|
||||
class CppValidator(Validator):
|
||||
"""Validator for C++ code in training samples."""
|
||||
|
||||
# C++ keywords
|
||||
KEYWORDS = {
|
||||
# Storage class
|
||||
"auto", "register", "static", "extern", "mutable", "thread_local",
|
||||
# Type specifiers
|
||||
"void", "bool", "char", "short", "int", "long", "float", "double",
|
||||
"signed", "unsigned", "wchar_t", "char8_t", "char16_t", "char32_t",
|
||||
# Type qualifiers
|
||||
"const", "volatile", "constexpr", "consteval", "constinit",
|
||||
# Control flow
|
||||
"if", "else", "switch", "case", "default", "while", "do", "for",
|
||||
"break", "continue", "return", "goto",
|
||||
# Declarations
|
||||
"class", "struct", "union", "enum", "typedef", "using", "namespace",
|
||||
"template", "typename", "concept", "requires",
|
||||
# Access specifiers
|
||||
"public", "private", "protected",
|
||||
# Other keywords
|
||||
"virtual", "override", "final", "explicit", "inline", "friend",
|
||||
"operator", "sizeof", "alignof", "decltype", "typeid",
|
||||
"new", "delete", "this", "nullptr", "true", "false",
|
||||
"try", "catch", "throw", "noexcept",
|
||||
"static_assert", "static_cast", "dynamic_cast", "const_cast", "reinterpret_cast",
|
||||
"co_await", "co_return", "co_yield",
|
||||
# Modules (C++20)
|
||||
"module", "import", "export",
|
||||
}
|
||||
|
||||
# Common C++ standard library types
|
||||
STD_TYPES = {
|
||||
"string", "vector", "map", "unordered_map", "set", "unordered_set",
|
||||
"list", "deque", "array", "pair", "tuple", "optional", "variant",
|
||||
"shared_ptr", "unique_ptr", "weak_ptr", "function", "any",
|
||||
"thread", "mutex", "lock_guard", "unique_lock", "condition_variable",
|
||||
"future", "promise", "async", "atomic",
|
||||
"ifstream", "ofstream", "fstream", "stringstream", "ostringstream",
|
||||
"iostream", "cin", "cout", "cerr", "endl",
|
||||
"size_t", "ptrdiff_t", "nullptr_t", "byte",
|
||||
"int8_t", "int16_t", "int32_t", "int64_t",
|
||||
"uint8_t", "uint16_t", "uint32_t", "uint64_t",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
check_compile: bool = False,
|
||||
compiler: str = "clang++",
|
||||
strict: bool = False,
|
||||
):
|
||||
"""Initialize C++ validator.
|
||||
|
||||
Args:
|
||||
check_compile: If True, attempt to compile the code
|
||||
compiler: Compiler to use for compile checks
|
||||
strict: If True, apply stricter validation
|
||||
"""
|
||||
super().__init__("CppValidator", "cpp")
|
||||
self.check_compile = check_compile
|
||||
self.compiler = compiler
|
||||
self.strict = strict
|
||||
|
||||
# Check if compiler is available
|
||||
self._compiler_available = shutil.which(compiler) is not None
|
||||
|
||||
async def validate(self, sample: TrainingSample) -> ValidationResult:
|
||||
"""Validate C++ code in the sample output."""
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
details: dict = {
|
||||
"syntax_issues": [],
|
||||
"keywords_found": [],
|
||||
"std_types_found": [],
|
||||
"bracket_balance": True,
|
||||
"compile_checked": False,
|
||||
"compile_result": None,
|
||||
}
|
||||
|
||||
code = sample.output
|
||||
|
||||
# Basic syntax checks
|
||||
syntax_result = self._check_syntax(code)
|
||||
details["syntax_issues"] = syntax_result["issues"]
|
||||
details["bracket_balance"] = syntax_result["balanced"]
|
||||
|
||||
if not syntax_result["balanced"]:
|
||||
errors.append("Unbalanced brackets/braces/parentheses")
|
||||
|
||||
for issue in syntax_result["issues"]:
|
||||
if self.strict:
|
||||
errors.append(issue)
|
||||
else:
|
||||
warnings.append(issue)
|
||||
|
||||
# Check for keywords and types
|
||||
details["keywords_found"] = self._find_keywords(code)
|
||||
details["std_types_found"] = self._find_std_types(code)
|
||||
|
||||
# Compile check if enabled and available
|
||||
if self.check_compile and self._compiler_available:
|
||||
compile_result = await self._check_compile(code)
|
||||
details["compile_checked"] = True
|
||||
details["compile_result"] = compile_result
|
||||
|
||||
if not compile_result["success"]:
|
||||
if self.strict:
|
||||
errors.append(f"Compile error: {compile_result['error'][:200]}")
|
||||
else:
|
||||
warnings.append(f"Compile warning: {compile_result['error'][:100]}")
|
||||
|
||||
# Calculate score
|
||||
score = 1.0
|
||||
|
||||
# Deduct for syntax issues
|
||||
score -= len(details["syntax_issues"]) * 0.1
|
||||
score = max(0.0, score)
|
||||
|
||||
# Deduct for bracket imbalance
|
||||
if not details["bracket_balance"]:
|
||||
score -= 0.3
|
||||
|
||||
# Bonus for using C++ features
|
||||
if details["keywords_found"]:
|
||||
score = min(1.0, score + 0.05)
|
||||
if details["std_types_found"]:
|
||||
score = min(1.0, score + 0.05)
|
||||
|
||||
# Deduct for compile failure
|
||||
if details["compile_checked"] and not details["compile_result"]["success"]:
|
||||
score -= 0.2
|
||||
|
||||
score = max(0.0, min(1.0, score))
|
||||
|
||||
return ValidationResult(
|
||||
valid=len(errors) == 0,
|
||||
score=score,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
details=details,
|
||||
)
|
||||
|
||||
def _check_syntax(self, code: str) -> dict:
|
||||
"""Check basic C++ syntax."""
|
||||
issues = []
|
||||
balanced = True
|
||||
|
||||
# Check bracket balance
|
||||
stack = []
|
||||
pairs = {"(": ")", "[": "]", "{": "}"}
|
||||
in_string = False
|
||||
in_char = False
|
||||
in_comment = False
|
||||
in_block_comment = False
|
||||
|
||||
i = 0
|
||||
while i < len(code):
|
||||
c = code[i]
|
||||
|
||||
# Handle comments
|
||||
if not in_string and not in_char:
|
||||
if i < len(code) - 1:
|
||||
two_char = code[i:i+2]
|
||||
if two_char == "//":
|
||||
# Skip to end of line
|
||||
while i < len(code) and code[i] != "\n":
|
||||
i += 1
|
||||
continue
|
||||
elif two_char == "/*":
|
||||
in_block_comment = True
|
||||
i += 2
|
||||
continue
|
||||
elif two_char == "*/" and in_block_comment:
|
||||
in_block_comment = False
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if in_block_comment:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Handle strings
|
||||
if c == '"' and not in_char and (i == 0 or code[i-1] != '\\'):
|
||||
in_string = not in_string
|
||||
elif c == "'" and not in_string and (i == 0 or code[i-1] != '\\'):
|
||||
in_char = not in_char
|
||||
|
||||
if not in_string and not in_char:
|
||||
if c in pairs:
|
||||
stack.append(c)
|
||||
elif c in pairs.values():
|
||||
if not stack:
|
||||
balanced = False
|
||||
issues.append(f"Unexpected closing bracket '{c}'")
|
||||
else:
|
||||
expected = pairs[stack.pop()]
|
||||
if c != expected:
|
||||
balanced = False
|
||||
issues.append(f"Mismatched brackets: expected '{expected}', got '{c}'")
|
||||
|
||||
i += 1
|
||||
|
||||
if stack:
|
||||
balanced = False
|
||||
issues.append(f"Unclosed brackets: {stack}")
|
||||
|
||||
# Check for common issues
|
||||
# Missing semicolons after statements (heuristic)
|
||||
lines = code.split("\n")
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
|
||||
# Skip empty lines, comments, preprocessor
|
||||
if not stripped or stripped.startswith("//") or stripped.startswith("#"):
|
||||
continue
|
||||
|
||||
# Skip lines that end with block characters
|
||||
if stripped.endswith("{") or stripped.endswith("}") or stripped.endswith(":"):
|
||||
continue
|
||||
|
||||
# Skip lines that are likely continuations
|
||||
if stripped.endswith(",") or stripped.endswith("\\"):
|
||||
continue
|
||||
|
||||
# Check for statements that should end with semicolon
|
||||
# This is a heuristic and may have false positives
|
||||
statement_patterns = [
|
||||
r"return\s+.+[^;]$", # return without semicolon
|
||||
r"break$", # break without semicolon
|
||||
r"continue$", # continue without semicolon
|
||||
]
|
||||
|
||||
for pattern in statement_patterns:
|
||||
if re.search(pattern, stripped):
|
||||
issues.append(f"Line {i+1}: Possibly missing semicolon")
|
||||
break
|
||||
|
||||
return {"issues": issues, "balanced": balanced}
|
||||
|
||||
def _find_keywords(self, code: str) -> list[str]:
|
||||
"""Find C++ keywords in code."""
|
||||
found = []
|
||||
# Use word boundaries to find keywords
|
||||
for keyword in self.KEYWORDS:
|
||||
if re.search(rf"\b{keyword}\b", code):
|
||||
found.append(keyword)
|
||||
return found
|
||||
|
||||
def _find_std_types(self, code: str) -> list[str]:
|
||||
"""Find standard library types in code."""
|
||||
found = []
|
||||
for type_name in self.STD_TYPES:
|
||||
# Check for std::type or just type in common contexts
|
||||
if re.search(rf"std::{type_name}\b", code) or re.search(rf"\b{type_name}<", code):
|
||||
found.append(type_name)
|
||||
return found
|
||||
|
||||
async def _check_compile(self, code: str) -> dict:
|
||||
"""Attempt to compile the code."""
|
||||
# Create temporary file
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".cpp", delete=False
|
||||
) as f:
|
||||
# Add minimal includes for standalone compilation
|
||||
wrapped_code = """
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
// Sample code below
|
||||
""" + code
|
||||
f.write(wrapped_code)
|
||||
temp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
# Run compiler with syntax-only check
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
self.compiler,
|
||||
"-fsyntax-only",
|
||||
"-std=c++17",
|
||||
"-Wall",
|
||||
str(temp_path),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(), timeout=10.0
|
||||
)
|
||||
|
||||
success = process.returncode == 0
|
||||
error = stderr.decode("utf-8", errors="replace") if stderr else ""
|
||||
|
||||
return {
|
||||
"success": success,
|
||||
"error": error,
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Compilation timed out",
|
||||
"returncode": -1,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"returncode": -1,
|
||||
}
|
||||
finally:
|
||||
# Clean up temp file
|
||||
try:
|
||||
temp_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
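When the configured compiler is on PATH, the optional compile gate can be turned on; it only runs clang++ -fsyntax-only, so nothing is linked or executed. A sketch:

    import asyncio

    from afs_scawful.training import TrainingSample
    from afs_scawful.validators import CppValidator

    validator = CppValidator(check_compile=True, strict=True)
    sample = TrainingSample(
        instruction="",
        input="",
        output="int add(int a, int b) { return a + b; }",
        domain="cpp",
    )
    result = asyncio.run(validator.validate(sample))
    print(result.details["compile_checked"], result.valid)  # compile_checked stays False if clang++ is absent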
src/afs_scawful/validators/kg_validator.py (new file, 349 lines)
@@ -0,0 +1,349 @@
"""Knowledge Graph Validator for training samples.
|
||||
|
||||
Validates:
|
||||
- Entity presence in knowledge graph
|
||||
- Relationship consistency
|
||||
- Cross-reference validity
|
||||
- Domain alignment
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..training import TrainingSample
|
||||
from .base import ValidationResult, Validator
|
||||
|
||||
|
||||
def _default_graph_path() -> Path:
|
||||
candidate = Path.home() / "src" / "context" / "memory" / "knowledge_graph.json"
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return Path.home() / ".context" / "memory" / "knowledge_graph.json"
|
||||
|
||||
|
||||
class KGValidator(Validator):
|
||||
"""Validator for knowledge graph consistency in training samples."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_path: Optional[Path] = None,
|
||||
strict: bool = False,
|
||||
min_entity_coverage: float = 0.3,
|
||||
):
|
||||
"""Initialize KG validator.
|
||||
|
||||
Args:
|
||||
graph_path: Path to knowledge graph JSON. Defaults to ~/src/context/memory/knowledge_graph.json
|
||||
(fallback: ~/.context/memory/knowledge_graph.json).
|
||||
strict: If True, apply stricter validation (missing entities are errors)
|
||||
min_entity_coverage: Minimum fraction of mentioned entities that must be in KG
|
||||
"""
|
||||
super().__init__("KGValidator", "all") # Applies to all domains
|
||||
self.graph_path = graph_path or _default_graph_path()
|
||||
self.strict = strict
|
||||
self.min_entity_coverage = min_entity_coverage
|
||||
|
||||
# Lazy load graph
|
||||
self._graph: Optional[dict] = None
|
||||
self._nodes: dict[str, Any] = {}
|
||||
self._edges: list[dict[str, Any]] = []
|
||||
self._node_names: set[str] = set()
|
||||
self._routines: set[str] = set()
|
||||
self._symbols: set[str] = set()
|
||||
|
||||
def _load_graph(self) -> None:
|
||||
"""Load knowledge graph from disk."""
|
||||
if self._graph is not None:
|
||||
return
|
||||
|
||||
if not self.graph_path.exists():
|
||||
self._graph = {"nodes": {}, "edges": []}
|
||||
return
|
||||
|
||||
try:
|
||||
data = json.loads(self.graph_path.read_text())
|
||||
self._graph = data
|
||||
self._nodes = data.get("nodes", {})
|
||||
self._edges = data.get("edges", [])
|
||||
|
||||
# Build lookup sets
|
||||
for node_id, node_data in self._nodes.items():
|
||||
self._node_names.add(node_id.lower())
|
||||
|
||||
# Extract name from node data
|
||||
if isinstance(node_data, dict):
|
||||
name = node_data.get("name", "")
|
||||
if name:
|
||||
self._node_names.add(name.lower())
|
||||
|
||||
# Track routines and symbols specifically
|
||||
node_type = node_data.get("type", "")
|
||||
if node_type == "routine":
|
||||
self._routines.add(name.lower())
|
||||
elif node_type == "symbol":
|
||||
self._symbols.add(name.lower())
|
||||
|
||||
except Exception:
|
||||
self._graph = {"nodes": {}, "edges": []}
|
||||
|
||||
def can_validate(self, sample: TrainingSample) -> bool:
|
||||
"""KG validator can validate any sample with kg_entities."""
|
||||
return True # Applies to all domains
|
||||
|
||||
async def validate(self, sample: TrainingSample) -> ValidationResult:
|
||||
"""Validate knowledge graph consistency in the sample."""
|
||||
self._load_graph()
|
||||
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
details: dict = {
|
||||
"entities_mentioned": [],
|
||||
"entities_found": [],
|
||||
"entities_missing": [],
|
||||
"routines_mentioned": [],
|
||||
"symbols_mentioned": [],
|
||||
"relationships_valid": True,
|
||||
"coverage": 0.0,
|
||||
}
|
||||
|
||||
# Extract entities from sample
|
||||
text = f"{sample.instruction} {sample.input} {sample.output}"
|
||||
mentioned = self._extract_entities(text, sample.domain)
|
||||
|
||||
details["entities_mentioned"] = mentioned
|
||||
|
||||
# Check which entities exist in KG
|
||||
found = []
|
||||
missing = []
|
||||
|
||||
for entity in mentioned:
|
||||
entity_lower = entity.lower()
|
||||
if self._entity_exists(entity_lower):
|
||||
found.append(entity)
|
||||
else:
|
||||
missing.append(entity)
|
||||
|
||||
details["entities_found"] = found
|
||||
details["entities_missing"] = missing
|
||||
|
||||
# Calculate coverage
|
||||
if mentioned:
|
||||
coverage = len(found) / len(mentioned)
|
||||
else:
|
||||
coverage = 1.0 # No entities to validate
|
||||
|
||||
details["coverage"] = coverage
|
||||
|
||||
# Check for routine/symbol references in ASM samples
|
||||
if sample.domain.startswith("asm"):
|
||||
routines = self._extract_routine_references(sample.output)
|
||||
symbols = self._extract_symbol_references(sample.output)
|
||||
|
||||
details["routines_mentioned"] = routines
|
||||
details["symbols_mentioned"] = symbols
|
||||
|
||||
# Check routine validity
|
||||
for routine in routines:
|
||||
if routine.lower() not in self._routines and routine.lower() not in self._node_names:
|
||||
if self.strict:
|
||||
errors.append(f"Unknown routine: {routine}")
|
||||
else:
|
||||
warnings.append(f"Routine not in KG: {routine}")
|
||||
|
||||
# Check kg_entities from sample metadata
|
||||
if sample.kg_entities:
|
||||
for entity in sample.kg_entities:
|
||||
if not self._entity_exists(entity.lower()):
|
||||
if self.strict:
|
||||
errors.append(f"Tagged entity not in KG: {entity}")
|
||||
else:
|
||||
warnings.append(f"Tagged entity not in KG: {entity}")
|
||||
|
||||
# Validate coverage threshold
|
||||
if coverage < self.min_entity_coverage and mentioned:
|
||||
msg = f"Entity coverage {coverage:.1%} below threshold {self.min_entity_coverage:.1%}"
|
||||
if self.strict:
|
||||
errors.append(msg)
|
||||
else:
|
||||
warnings.append(msg)
|
||||
|
||||
# Calculate score
|
||||
score = 1.0
|
||||
|
||||
# Base score on coverage
|
||||
score = min(1.0, coverage + 0.3) # Coverage contributes up to 0.7
|
||||
|
||||
# Bonus for having KG entities tagged
|
||||
if sample.kg_entities and sample.kg_validated:
|
||||
score = min(1.0, score + 0.1)
|
||||
|
||||
# Penalty for missing entities
|
||||
if missing:
|
||||
penalty = len(missing) * 0.05
|
||||
score = max(0.3, score - penalty)
|
||||
|
||||
return ValidationResult(
|
||||
valid=len(errors) == 0,
|
||||
score=score,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
details=details,
|
||||
)
|
||||
|
||||
def _entity_exists(self, entity: str) -> bool:
|
||||
"""Check if an entity exists in the knowledge graph."""
|
||||
entity_lower = entity.lower()
|
||||
|
||||
# Direct match
|
||||
if entity_lower in self._node_names:
|
||||
return True
|
||||
|
||||
# Check with common prefixes
|
||||
prefixes = ["alttp:", "oracle-of-secrets:", "project:", "routine:", "symbol:"]
|
||||
for prefix in prefixes:
|
||||
if f"{prefix}{entity_lower}" in self._node_names:
|
||||
return True
|
||||
# Also check node IDs directly
|
||||
for node_id in self._nodes:
|
||||
if node_id.lower().endswith(f":{entity_lower}"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _extract_entities(self, text: str, domain: str) -> list[str]:
|
||||
"""Extract potential entity references from text."""
|
||||
entities = []
|
||||
|
||||
# Common patterns for entity references
|
||||
patterns = [
|
||||
# Code references like `EntityName` or `RoutineName`
|
||||
r'`([A-Z][a-zA-Z0-9_]+)`',
|
||||
# Capitalized terms that look like identifiers
|
||||
r'\b([A-Z][a-z]+(?:[A-Z][a-z]+)+)\b', # CamelCase
|
||||
# Routine names (common in ASM)
|
||||
r'\b(Link_[A-Za-z0-9_]+)\b',
|
||||
r'\b(Player_[A-Za-z0-9_]+)\b',
|
||||
r'\b(Sprite_[A-Za-z0-9_]+)\b',
|
||||
r'\b(Module_[A-Za-z0-9_]+)\b',
|
||||
# Memory addresses with labels
|
||||
r'\b([A-Z][A-Za-z0-9]+_[A-Z][A-Za-z0-9]+)\b',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
matches = re.findall(pattern, text)
|
||||
entities.extend(matches)
|
||||
|
||||
# Domain-specific extraction
|
||||
if domain.startswith("asm"):
|
||||
# Extract ASM-specific references
|
||||
asm_patterns = [
|
||||
r'\b([A-Z][a-z]+_[A-Z][a-z_0-9]+)\b', # Link_HandleSword
|
||||
r'@([A-Za-z_][A-Za-z0-9_]+)', # @Labels
|
||||
]
|
||||
for pattern in asm_patterns:
|
||||
matches = re.findall(pattern, text)
|
||||
entities.extend(matches)
|
||||
|
||||
elif domain == "cpp":
|
||||
# Extract C++ class/function names
|
||||
cpp_patterns = [
|
||||
r'\bclass\s+([A-Z][a-zA-Z0-9_]+)\b',
|
||||
r'\b([A-Z][a-z]+(?:[A-Z][a-z]+)+)::\w+', # ClassName::method
|
||||
]
|
||||
for pattern in cpp_patterns:
|
||||
matches = re.findall(pattern, text)
|
||||
entities.extend(matches)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen = set()
|
||||
unique = []
|
||||
for e in entities:
|
||||
if e.lower() not in seen:
|
||||
seen.add(e.lower())
|
||||
unique.append(e)
|
||||
|
||||
return unique
|
||||
|
||||
def _extract_routine_references(self, code: str) -> list[str]:
|
||||
"""Extract routine/label references from ASM code."""
|
||||
routines = []
|
||||
|
||||
# JSR/JSL targets
|
||||
jsr_pattern = r'\b(?:JSR|JSL|JMP|JML)\s+([A-Za-z_][A-Za-z0-9_]+)\b'
|
||||
matches = re.findall(jsr_pattern, code, re.IGNORECASE)
|
||||
routines.extend(matches)
|
||||
|
||||
# BRA/BRL targets
|
||||
branch_pattern = r'\b(?:BRA|BRL|BEQ|BNE|BCC|BCS|BMI|BPL)\s+([A-Za-z_][A-Za-z0-9_]+)\b'
|
||||
matches = re.findall(branch_pattern, code, re.IGNORECASE)
|
||||
routines.extend(matches)
|
||||
|
||||
return list(set(routines))
|
||||
|
||||
def _extract_symbol_references(self, code: str) -> list[str]:
|
||||
"""Extract symbol/variable references from ASM code."""
|
||||
symbols = []
|
||||
|
||||
# LDA/STA with labels
|
||||
load_store_pattern = r'\b(?:LDA|LDX|LDY|STA|STX|STY)\s+([A-Za-z_][A-Za-z0-9_]+)\b'
|
||||
matches = re.findall(load_store_pattern, code, re.IGNORECASE)
|
||||
symbols.extend(matches)
|
||||
|
||||
# Filter out common non-symbol patterns
|
||||
filtered = []
|
||||
for sym in symbols:
|
||||
# Skip if it looks like a routine name
|
||||
if sym.lower() in self._routines:
|
||||
continue
|
||||
# Skip common mnemonics that might be captured
|
||||
if sym.upper() in {'A', 'X', 'Y', 'S'}:
|
||||
continue
|
||||
filtered.append(sym)
|
||||
|
||||
return list(set(filtered))
|
||||
|
||||
def get_related_entities(self, entity: str) -> list[dict[str, Any]]:
|
||||
"""Get entities related to a given entity in the KG."""
|
||||
self._load_graph()
|
||||
|
||||
related = []
|
||||
entity_lower = entity.lower()
|
||||
|
||||
for edge in self._edges:
|
||||
source = str(edge.get("source", "")).lower()
|
||||
target = str(edge.get("target", "")).lower()
|
||||
relation = edge.get("relation", "")
|
||||
|
||||
if entity_lower in source:
|
||||
related.append({
|
||||
"entity": edge.get("target"),
|
||||
"relation": relation,
|
||||
"direction": "outgoing",
|
||||
})
|
||||
elif entity_lower in target:
|
||||
related.append({
|
||||
"entity": edge.get("source"),
|
||||
"relation": relation,
|
||||
"direction": "incoming",
|
||||
})
|
||||
|
||||
return related
|
||||
|
||||
def suggest_entities(self, partial: str, limit: int = 10) -> list[str]:
|
||||
"""Suggest entity names matching a partial string."""
|
||||
self._load_graph()
|
||||
|
||||
partial_lower = partial.lower()
|
||||
matches = []
|
||||
|
||||
for node_id in self._nodes:
|
||||
if partial_lower in node_id.lower():
|
||||
matches.append(node_id)
|
||||
if len(matches) >= limit:
|
||||
break
|
||||
|
||||
return matches
|
||||
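For testing, the graph path can be injected directly; the JSON shape _load_graph expects is a "nodes" mapping plus an "edges" list. A sketch using a hypothetical scratch file:

    import asyncio
    import json
    from pathlib import Path

    from afs_scawful.training import TrainingSample
    from afs_scawful.validators import KGValidator

    graph = {
        "nodes": {"routine:link_handlesword": {"name": "Link_HandleSword", "type": "routine"}},
        "edges": [],
    }
    graph_path = Path("kg.json")  # hypothetical scratch file
    graph_path.write_text(json.dumps(graph))

    validator = KGValidator(graph_path=graph_path)
    sample = TrainingSample(instruction="", input="", output="JSL Link_HandleSword", domain="asm")
    result = asyncio.run(validator.validate(sample))
    print(result.details["routines_mentioned"])  # ["Link_HandleSword"]
    print(result.details["coverage"])            # 1.0: the one mentioned entity is in the graph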
tests/test_validators.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from __future__ import annotations

import asyncio

from afs_scawful.training import TrainingSample
from afs_scawful.validators import AsmValidator, CppValidator


def test_asm_validator_basic() -> None:
    sample = TrainingSample(
        instruction="",
        input="",
        output="LDA #$01\nSTA $7E0000\n",
        domain="asm",
        source="test",
    )
    result = asyncio.run(AsmValidator().validate(sample))
    assert result.valid
    assert result.score > 0.0


def test_cpp_validator_basic() -> None:
    sample = TrainingSample(
        instruction="",
        input="",
        output="int main() { return 0; }\n",
        domain="cpp",
        source="test",
    )
    result = asyncio.run(CppValidator(check_compile=False).validate(sample))
    assert result.valid
    assert result.score > 0.0