#!/usr/bin/env python3
"""Enumerate all states and transitions of a synthesized AIGER controller.

Generic AIGER parser — reads any .aag file and produces:
  1. A full transition table
  2. State-to-mode mapping
  3. A Graphviz DOT file for the state machine with guard condition labels

Usage: python3 scripts/trace_aiger.py [circuit.aag] [output_dir]
"""

import sys
from itertools import product
from pathlib import Path
from collections import defaultdict


def parse_aag(path):
    """Parse an ASCII AIGER file and return the circuit definition.

    Args:
        path: Location of the ``.aag`` file.

    Returns:
        A dict with the keys ``input_lits``, ``input_names``, ``latch_lits``,
        ``latch_next_lits``, ``output_lits``, ``output_names`` and ``ands``
        (a list of ``(lhs, rhs0, rhs1)`` literal triples).  Inputs/outputs
        without a symbol-table entry are named ``i<N>`` / ``o<N>``.

    Raises:
        ValueError: If the header is not ``aag M I L O A`` or the file has
            fewer body lines than the header announces.
    """
    lines = Path(path).read_text().strip().splitlines()
    header = lines[0].split() if lines else []
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if len(header) < 6 or header[0] != 'aag':
        raise ValueError(
            f"{path}: not an ASCII AIGER file (expected 'aag M I L O A' header)"
        )
    n_inputs, n_latches, n_outputs, n_ands = (int(tok) for tok in header[2:6])

    body_end = 1 + n_inputs + n_latches + n_outputs + n_ands
    if len(lines) < body_end:
        raise ValueError(
            f"{path}: truncated AIGER file "
            f"({len(lines)} lines, header announces {body_end})"
        )

    # The body sections follow the header back to back, in this fixed order.
    cursor = 1
    input_lits = [int(ln) for ln in lines[cursor:cursor + n_inputs]]
    cursor += n_inputs

    latch_pairs = []  # (current_lit, next_lit)
    for ln in lines[cursor:cursor + n_latches]:
        fields = ln.split()
        # Only the first two fields are used; anything after them is ignored.
        latch_pairs.append((int(fields[0]), int(fields[1])))
    cursor += n_latches

    output_lits = [int(ln) for ln in lines[cursor:cursor + n_outputs]]
    cursor += n_outputs

    ands = []
    for ln in lines[cursor:cursor + n_ands]:
        fields = ln.split()
        ands.append((int(fields[0]), int(fields[1]), int(fields[2])))
    cursor += n_ands

    # Optional symbol table: "i<pos> <name>" / "o<pos> <name>".  A line
    # starting with 'c' opens the comment section and ends parsing.
    symbol_names = {'i': {}, 'o': {}}
    for ln in lines[cursor:]:
        kind = ln[:1]
        if kind == 'c':
            break
        if kind not in symbol_names:
            continue
        position, _, name = ln[1:].partition(' ')
        if name:  # a symbol line without a name keeps the default label
            symbol_names[kind][int(position)] = name

    return {
        'input_lits': input_lits,
        'input_names': [symbol_names['i'].get(i, f'i{i}') for i in range(n_inputs)],
        'latch_lits': [cur for cur, _ in latch_pairs],
        'latch_next_lits': [nxt for _, nxt in latch_pairs],
        'output_lits': output_lits,
        'output_names': [symbol_names['o'].get(i, f'o{i}') for i in range(n_outputs)],
        'ands': ands,
    }


def eval_circuit(circ, latch_vals, input_vals):
    """Evaluate the circuit for given latch and input values.

    Args:
        circ: Circuit dict as produced by ``parse_aag``.
        latch_vals: 0/1 values, paired positionally with ``circ['latch_lits']``.
        input_vals: 0/1 values, paired positionally with ``circ['input_lits']``.

    Returns:
        ``(outputs, next_latches)`` where ``outputs`` maps output name to 0/1
        and ``next_latches`` is a tuple of the next-state latch values.

    Raises:
        ValueError: If some AND gate can never be evaluated because it depends
            on an undefined literal or on itself (combinational cycle).
    """
    val = {0: 0, 1: 1}  # literal 0 is constant false, literal 1 its negation
    for lit, v in zip(circ['input_lits'], input_vals):
        val[lit] = v
        val[lit ^ 1] = 1 - v  # odd/even twin literal carries the complement
    for lit, v in zip(circ['latch_lits'], latch_vals):
        val[lit] = v
        val[lit ^ 1] = 1 - v

    # The gate list is not assumed to be in dependency order: gates whose
    # operands are not known yet are retried on the next sweep.  A list that
    # is already ordered is resolved in a single sweep.
    pending = circ['ands']
    while pending:
        deferred = []
        for gate in pending:
            lhs, rhs0, rhs1 = gate
            if rhs0 in val and rhs1 in val:
                val[lhs] = val[rhs0] & val[rhs1]
                val[lhs ^ 1] = 1 - val[lhs]
            else:
                deferred.append(gate)
        if len(deferred) == len(pending):
            # A full sweep made no progress, so it never will.
            raise ValueError(
                f"cannot evaluate AND gates {[g[0] for g in deferred]}: "
                "undefined operand or combinational cycle"
            )
        pending = deferred

    outputs = {
        name: val[lit]
        for name, lit in zip(circ['output_names'], circ['output_lits'])
    }
    next_latches = tuple(val[nl] for nl in circ['latch_next_lits'])
    return outputs, next_latches


def mode_label(outputs):
    """Join the names of all truthy outputs with '+', or return "NONE"."""
    active = [name for name, val in outputs.items() if val]
    return "+".join(active) if active else "NONE"


def output_display_name(name):
    """Derive a human-readable state label from an output variable name.

    Strips one leading 'in_', 'is_' or 'at_' prefix (first match wins) and
    uppercases the result, e.g. 'in_shutdown' -> 'SHUTDOWN',
    'is_active' -> 'ACTIVE'.  Names without such a prefix are just uppercased.
    """
    for prefix in ('in_', 'is_', 'at_'):
        if name.startswith(prefix):
            return name[len(prefix):].upper()
    return name.upper()


def _cube_members(fixed, n_inputs):
    """Return every input combo that agrees with the partial assignment *fixed*."""
    axes = [(fixed[i],) if i in fixed else (0, 1) for i in range(n_inputs)]
    return set(product(*axes))


def _format_cube(fixed, input_names):
    """Render a partial assignment as a conjunction such as "a & !b"."""
    terms = [
        input_names[i] if fixed[i] == 1 else f"!{input_names[i]}"
        for i in sorted(fixed)
    ]
    return " & ".join(terms) if terms else "always"


def extract_guard(input_combos, input_names):
    """Build a compact boolean guard expression from a set of input combos.

    The combos are covered greedily with cubes (conjunctions of fixed input
    values): starting from the smallest uncovered combo, each variable is
    dropped from the conjunction whenever the widened cube still lies
    entirely inside the combo set.  Cubes are OR'd together; multi-term cubes
    are parenthesised.  The cover is exact but not guaranteed to be minimal.

    Args:
        input_combos: Iterable of 0/1 tuples, one entry per input variable.
            Duplicates are ignored.
        input_names: Input variable names, indexed like the combo tuples.

    Returns:
        A human-readable string like "inv1_holds & !p_above_crit",
        "always" (every combo is present) or "never" (no combo at all).
    """
    n_inputs = len(input_names)
    # De-duplicate first so that the size comparisons below count distinct
    # combos rather than list entries.
    combo_set = {tuple(c) for c in input_combos}
    if not combo_set:
        return "never"
    if len(combo_set) == 2 ** n_inputs:
        return "always"

    cubes = []
    remaining = set(combo_set)
    while remaining:
        # Smallest uncovered combo as seed: keeps the output independent of
        # set iteration order.
        seed = min(remaining)
        cube_fixed = dict(enumerate(seed))
        cube = {seed}
        # One pass over the variables is enough: a relaxation that fails now
        # would also fail after further widening, because the later trial
        # cube is a superset of the one that already left the combo set.
        for i in range(n_inputs):
            trial_fixed = {k: v for k, v in cube_fixed.items() if k != i}
            if 2 ** (n_inputs - len(trial_fixed)) > len(combo_set):
                continue  # too large to fit inside the combo set
            trial = _cube_members(trial_fixed, n_inputs)
            # Only relax if the cube stays entirely within the edge's combos.
            if trial <= combo_set:
                cube_fixed, cube = trial_fixed, trial
        cubes.append(_format_cube(cube_fixed, input_names))
        remaining -= cube

    if len(cubes) == 1:
        return cubes[0]
    # Wrap each multi-term cube in parens when OR-ing.
    return " | ".join(f"({c})" if " & " in c else c for c in cubes)


def state_color(state, init_latches, reachable, mode_names):
    """Assign a DOT fill color based on state semantics.

    Checks, in this order (first match wins); keywords are matched as
    substrings of the lowercased, space-joined ``mode_names``:
      - state not in ``reachable``: gray
      - state equals ``init_latches``: light blue
      - 'scram' or 'emergency': red
      - 'shutdown' or 'trip': orange
      - 'operation', 'run', 'power' or 'normal': green
      - otherwise (transitory / heatup / startup / other): light yellow

    Hex colors are returned already wrapped in double quotes so they can be
    pasted into a DOT attribute as-is.
    """
    if state not in reachable:
        return "gray90"
    modes_lower = " ".join(mode_names).lower()
    if state == init_latches:
        return '"#A8D8EA"'  # light blue
    if any(kw in modes_lower for kw in ('scram', 'emergency')):
        return '"#FF6B6B"'  # red
    if any(kw in modes_lower for kw in ('shutdown', 'trip')):
        return '"#FFB347"'  # orange
    if any(kw in modes_lower for kw in ('operation', 'run', 'power', 'normal')):
        return '"#77DD77"'  # green
    return '"#FFFACD"'  # light yellow


def edge_color(src, dst, dst_modes):
    """Return ``(color, style)`` for a transition edge.

    Self-loops are bold blue, edges whose destination modes mention 'scram'
    or 'emergency' are bold red, everything else is green with no style.
    """
    modes_lower = " ".join(dst_modes).lower()
    if src == dst:
        return '"#4A90D9"', "bold"  # blue for self-loops
    if any(kw in modes_lower for kw in ('scram', 'emergency')):
        return '"#CC0000"', "bold"  # red for scram transitions
    return '"#228B22"', ""  # green for normal transitions


def _state_id(state):
    """Return the DOT node id of a latch-state tuple, e.g. (0, 1) -> 's01'."""
    return "s" + "".join(str(bit) for bit in state)


def _dot_escape(text):
    """Escape backslashes and double quotes for a double-quoted DOT string."""
    return text.replace('\\', '\\\\').replace('"', '\\"')


def _reachable_states(transitions, start):
    """Return the set of states reachable from *start* via *transitions*."""
    seen = set()
    stack = [start]
    while stack:
        state = stack.pop()
        if state in seen:
            continue
        seen.add(state)
        stack.extend(nxt for nxt in transitions.get(state, ()) if nxt not in seen)
    return seen


def main():
    """Trace the controller named on the command line and write its DOT graph.

    ``sys.argv[1]`` is the .aag path and ``sys.argv[2]`` the output directory;
    both fall back to built-in defaults.  Prints the full transition table,
    reachability and a guard summary, then writes ``<stem>_states.dot``.
    """
    aag_path = sys.argv[1] if len(sys.argv) > 1 else "circuits/PWR_Hybrid_DRC.aag"
    out_dir = Path(sys.argv[2]) if len(sys.argv) > 2 else Path("diagrams")

    circ = parse_aag(aag_path)
    n_latches = len(circ['latch_lits'])
    n_inputs = len(circ['input_lits'])
    n_total_inputs = 2 ** n_inputs
    basename = Path(aag_path).stem

    print("=" * 100)
    print(f"Synthesized Controller Trace: {basename}")
    print(f" Inputs ({n_inputs}): {circ['input_names']}")
    print(f" Latches: {n_latches}")
    print(f" Outputs ({len(circ['output_names'])}): {circ['output_names']}")
    print("=" * 100)

    # The trace starts from the all-zero latch state.
    init_latches = tuple(0 for _ in range(n_latches))
    init_inputs = tuple(0 for _ in range(n_inputs))
    init_out, _ = eval_circuit(circ, init_latches, init_inputs)
    print(f"\nInitial state: latches={init_latches}")
    print(f"Initial outputs (all inputs 0): {mode_label(init_out)}")
    for k, v in init_out.items():
        print(f" {k}={v}", end="")
    print("\n")

    # Build header
    inp_hdr = " ".join(f"{n[:5]:>5}" for n in circ['input_names'])
    out_hdr = " ".join(f"{n[:6]:>6}" for n in circ['output_names'])
    l_hdr = " ".join(f"L{i}" for i in range(n_latches))
    nl_hdr = " ".join(f"nL{i}" for i in range(n_latches))
    header = f"{l_hdr} | {inp_hdr} | {out_hdr} | {nl_hdr} | mode"
    print(header)
    print("-" * len(header))

    # transitions[src][dst] = set of input combo tuples
    transitions = defaultdict(lambda: defaultdict(set))
    state_modes = defaultdict(set)
    all_latch_states = list(product([0, 1], repeat=n_latches))
    all_input_combos = list(product([0, 1], repeat=n_inputs))

    for ls in all_latch_states:
        for iv in all_input_combos:
            outputs, ns = eval_circuit(circ, ls, iv)
            ml = mode_label(outputs)
            state_modes[ls].add(ml)
            transitions[ls][ns].add(iv)
            l_str = " ".join(f"{v:>2}" for v in ls)
            i_str = " ".join(f"{v:>5}" for v in iv)
            o_str = " ".join(f"{v:>6}" for v in outputs.values())
            n_str = " ".join(f"{v:>3}" for v in ns)
            print(f"{l_str} | {i_str} | {o_str} | {n_str} | {ml}")

    # Reachability
    print("\n" + "=" * 80)
    print(f"REACHABILITY (from initial state {init_latches})")
    print("=" * 80)
    reachable = _reachable_states(transitions, init_latches)
    for s in sorted(all_latch_states):
        modes = sorted(state_modes[s])
        tag = "REACHABLE" if s in reachable else "UNREACHABLE"
        print(f" State {s}: modes={modes} [{tag}]")

    # Guards are needed for both the summary and the DOT edges: compute once.
    guards = {
        (s, ns): extract_guard(combos, circ['input_names'])
        for s in reachable
        for ns, combos in transitions[s].items()
    }

    # Transition summary with guard labels
    print("\n" + "=" * 80)
    print("STATE TRANSITION SUMMARY")
    print("=" * 80)
    for s in sorted(reachable):
        print(f"\nFrom state {s} — modes: {sorted(state_modes[s])}")
        for ns in sorted(transitions[s]):
            n = len(transitions[s][ns])
            print(f" -> {ns} [{n}/{n_total_inputs} combos] guard: {guards[s, ns]}")

    # Build output display name mapping
    out_display = {name: output_display_name(name) for name in circ['output_names']}

    # Generate DOT
    dot = [
        'digraph Controller {',
        ' rankdir=LR;',
        ' bgcolor="white";',
        ' node [shape=Mrecord, style="filled,rounded", fontname="Helvetica-Bold", fontsize=12, penwidth=1.5];',
        ' edge [fontname="Helvetica", fontsize=9];',
        '',
    ]

    for s in sorted(all_latch_states):
        modes = sorted(state_modes[s])
        # Build display label: show human-readable output names
        display_modes = []
        for m in modes:
            if m == "NONE":
                display_modes.append("NONE")
            else:
                display_modes.append(
                    " + ".join(
                        _dot_escape(out_display.get(p, p.upper()))
                        for p in m.split("+")
                    )
                )
        label = " | ".join(display_modes)
        color = state_color(s, init_latches, reachable, modes)
        dot.append(f' {_state_id(s)} [label="{label}", fillcolor={color}];')

    dot += [
        '',
        ' init [shape=point, width=0.25, color="black"];',
        f' init -> {_state_id(init_latches)} [penwidth=2.0];',
        '',
    ]

    for s in sorted(reachable):
        sid = _state_id(s)
        for ns in sorted(transitions[s]):
            nsid = _state_id(ns)
            # Determine destination modes for coloring
            dst_modes = sorted(state_modes[ns])
            color, style = edge_color(s, ns, dst_modes)
            n = len(transitions[s][ns])
            # Edges taken by more than half of the input combos are drawn thicker.
            pw = "2.0" if n > n_total_inputs // 2 else "1.2"
            dot_guard = _dot_escape(guards[s, ns])
            style_attr = f', style="{style}"' if style else ''
            dot.append(
                f' {sid} -> {nsid} [label="{dot_guard}", penwidth={pw}, '
                f'color={color}, fontcolor={color}{style_attr}];'
            )

    dot.append('}')

    # Create the output directory only once there is something to write.
    out_dir.mkdir(parents=True, exist_ok=True)
    dot_path = out_dir / f"{basename}_states.dot"
    dot_path.write_text('\n'.join(dot))
    print(f"\nDOT written to: {dot_path}")


if __name__ == "__main__":
    main()