# Source code for verf.XSAT

from __future__ import annotations

import copy
import csv
import os
import time
from typing import List, Dict, Any, Tuple

import pexpect
from networkx import MultiGraph
from z3 import Bool

from utils import config
from utils.VerilogFix import VerilogFix
from core.BooleanFunctionCollection import BooleanFunctionCollection
from core.BooleanFunctionInterface import BooleanFunctionInterface
from core.benchmarks.Benchmark import VerilogBenchmark
from core.hardware.Crossbar import Crossbar, MemristorCrossbar
from core.expressions.BooleanExpression import LITERAL
from verf.EquivalenceChecker import EquivalenceChecker


class SubProblem:
    """A (contracted) crossbar graph together with the literals fixed to obtain it."""

    def __init__(self, subgraph: MultiGraph, literals: List[LITERAL]):
        # Graph that still has to be verified for this branch.
        self.subgraph: MultiGraph = subgraph
        # Literals that were fixed, in order, to arrive at ``subgraph``.
        self.literals: List[LITERAL] = literals
class XSAT(EquivalenceChecker):
    """Equivalence checker that hands the crossbar and its specification to ABC's ``cec``.

    The crossbar graph is first contracted (edges with literal ``False`` are
    removed, edges with literal ``True`` are merged away). While the graph has
    more nodes than ``node_threshold`` and fewer than ``depth_threshold``
    variables have been fixed, the most frequently occurring variable is fixed
    to both polarities, producing two sub-problems. Every final sub-problem is
    written to Verilog and checked with the ABC tool.
    """

    def __init__(self, crossbar: Crossbar, specification: BooleanFunctionInterface,
                 node_threshold: int = 250, depth_threshold: int = 5):
        """Set up file locations, thresholds and logging state.

        :param crossbar: crossbar to verify.
        :param specification: specification the crossbar is compared against.
        :param node_threshold: sub-problems with at most this many nodes are not
            split further; ``None`` disables splitting altogether.
        :param depth_threshold: maximum number of variables fixed per
            sub-problem; ``None`` means no limit other than the number of
            variables.
        """
        super().__init__(crossbar, specification)
        config.log.add("Verification method: XSAT\n")
        self.node_threshold = node_threshold
        self.depth_threshold = depth_threshold

        self.xbar_filepath = config.abc_path.joinpath("xbar.v")
        self.spec_filepath = config.abc_path.joinpath("spec.v")
        self.sat_info_filepath = config.abc_path.joinpath("sat.csv")
        self.model_filepath = config.abc_path.joinpath("model.txt")

        # Remove left-overs of a previous run so that stale data is never read.
        if self.sat_info_filepath.exists():
            os.remove(self.sat_info_filepath)
        if self.model_filepath.exists():
            os.remove(self.model_filepath)

        self.divide_and_conquer_time = None
        self.start_time = None
        self.end_time = None

        self._crossbar_log = dict()
        self._iterations_log = []
        self._subproblems_log = []
        self.equivalent = None

    @staticmethod
    def _get_edge_properties(graph: MultiGraph) -> Dict[str, int]:
        """Count the edges of ``graph`` whose literal is True ("on"), False ("off") or other ("lit")."""
        memristor_stats = {"on": 0, "off": 0, "lit": 0}
        false_literal = LITERAL("False", False)
        true_literal = LITERAL("True", True)
        for _, _, _, edge_data in graph.edges(keys=True, data=True):
            literal = LITERAL(edge_data.get("atom"), edge_data.get("positive"))
            if literal == false_literal:
                memristor_stats["off"] += 1
            elif literal == true_literal:
                memristor_stats["on"] += 1
            else:
                memristor_stats["lit"] += 1
        return memristor_stats

    def _get_variable_count(self, xbar: MemristorCrossbar) -> Dict[str, int]:
        """Return the number of memristors per variable, most frequent variable first.

        Memristors programmed with the constant literals True/False are ignored.
        The result is also stored in the crossbar log.
        """
        false_literal = LITERAL("False", False)
        true_literal = LITERAL("True", True)
        variable_count = dict()
        for r in range(xbar.rows):
            for c in range(xbar.columns):
                literal = xbar.get_memristor(r, c).literal
                if literal == false_literal or literal == true_literal:
                    continue
                atom = literal.atom
                variable_count[atom] = variable_count.get(atom, 0) + 1

        # We sort the variables in descending order in terms of their values (occurrences)
        variable_count = dict(sorted(variable_count.items(), key=lambda item: item[1], reverse=True))
        self._log_variable_count(variable_count)
        return variable_count

    @staticmethod
    def _graph_with_fixed_variable(graph: MultiGraph, literal: LITERAL) -> MultiGraph:
        """Return a deep copy of ``graph`` in which ``literal`` is assumed to hold.

        Edges labelled with ``literal`` become True edges, edges labelled with
        its negation become False edges; all other edges are left unchanged.
        """
        graph = MultiGraph(copy.deepcopy(graph))
        for _, _, _, edge_data in graph.edges(keys=True, data=True):
            if edge_data.get("atom") != literal.atom:
                continue
            if edge_data.get("positive") == literal.positive:
                edge_data["atom"] = "True"
                edge_data["positive"] = True
            else:
                edge_data["atom"] = "False"
                edge_data["positive"] = False
        return graph

    @staticmethod
    def _find_true_edge(graph: MultiGraph) -> Tuple[Any, Any, Any] | None:
        """Return ``(u, v, key)`` of some edge whose literal is True, or ``None`` if there is none."""
        true_literal = LITERAL("True", True)
        for u, v, k, edge_data in graph.edges(keys=True, data=True):
            if LITERAL(edge_data.get("atom"), edge_data.get("positive")) == true_literal:
                return u, v, k
        return None

    @staticmethod
    def _free_edge_key(graph: MultiGraph, u: Any, v: Any) -> int:
        """Return the smallest non-negative integer that is not yet a key of an edge between ``u`` and ``v``."""
        key = 0
        while (u, v, key) in graph.edges:
            key += 1
        return key

    def _contract_graph(self, graph: MultiGraph) -> MultiGraph:
        """Return a copy of ``graph`` without False edges and with all True edges contracted.

        Contracting a True edge ``(u, v)`` removes ``v`` and re-attaches every
        other edge of ``v`` to ``u``. The ``input_function`` and
        ``output_functions`` node attributes of ``v`` are carried over to ``u``.
        The graph passed in is not modified.
        """
        graph = MultiGraph(graph.copy(as_view=False))

        # First, we remove all edges with FALSE as literal
        false_literal = LITERAL("False", False)
        false_edges = [
            (u, v, k)
            for u, v, k, edge_data in graph.edges(keys=True, data=True)
            if LITERAL(edge_data.get("atom"), edge_data.get("positive")) == false_literal
        ]
        graph.remove_edges_from(false_edges)

        # Second, we contract all nodes that have TRUE as literal of an incident edge
        true_edge = self._find_true_edge(graph)
        while true_edge is not None:
            u, v, _ = true_edge
            victim = v  # This node will be removed

            # Iterate over a snapshot: edges are added to the graph inside the loop.
            for x, y, k in list(graph.edges(victim, keys=True)):
                # Edges between u and v disappear together with the victim node.
                if (u == x and v == y) or (u == y and v == x):
                    continue

                # This edge is between the victim and another node.
                # Either x is the victim v, or y is the victim v.
                edge_data = graph.get_edge_data(x, y, k)
                other = y if v == x else x
                # Always use an unused key so that an existing parallel edge
                # between u and the other node is never overwritten.
                graph.add_edge(u, other, self._free_edge_key(graph, u, other), **edge_data)

            # We update the node data
            node_data = graph.nodes[victim]
            survivor_data = graph.nodes[u]
            if "input_function" in node_data:
                survivor_data["input_function"] = node_data.get("input_function")
            if "output_functions" in node_data:
                # Graph.copy() shares attribute containers with the source graph,
                # so merge into a copy instead of updating in place.
                if "output_functions" in survivor_data:
                    merged = copy.copy(survivor_data["output_functions"])
                    merged.update(node_data.get("output_functions"))
                else:
                    merged = copy.copy(node_data.get("output_functions"))
                survivor_data["output_functions"] = merged

            # We remove the victim
            graph.remove_node(victim)

            true_edge = self._find_true_edge(graph)

        return graph

    def _log_iteration(self, step: str, graph: MultiGraph, total_time: float | None = None):
        """Append the size of ``graph`` (and optionally the time ``step`` took) to the iteration log."""
        iteration_log = dict()
        iteration_log["step"] = step
        iteration_log["graph"] = dict()
        iteration_log["graph"]["nodes"] = len(graph.nodes)
        iteration_log["graph"]["edges"] = self._get_edge_properties(graph)
        if total_time is not None:
            iteration_log["total_time"] = total_time
        self._iterations_log.append(iteration_log)

    def _timed_step(self, step: str, transform, *args) -> MultiGraph:
        """Run ``transform(*args)``, log the resulting graph and the elapsed time under ``step``."""
        start_time = time.time()
        result = transform(*args)
        end_time = time.time()
        self._log_iteration(step, result, end_time - start_time)
        return result

    def _divide_and_conquer(self, xbar: MemristorCrossbar) -> List[SubProblem]:
        """Split the crossbar graph into sub-problems by fixing variables.

        Sub-problems are processed first-in first-out. A sub-problem is final
        when its graph has at most ``node_threshold`` nodes, when
        ``depth_threshold`` literals have been fixed, or when no unfixed
        variable is left. Otherwise the next most frequent variable is fixed to
        both polarities and both resulting graphs are contracted.

        :param xbar: crossbar to split.
        :return: the final sub-problems, in the order they were finalised.
        """
        variable_count = self._get_variable_count(xbar)
        variable_order = list(variable_count)

        graph = MultiGraph(xbar.graph())
        self._log_iteration("init", graph)

        print("\tStarted graph contraction")
        graph = self._timed_step("contract", self._contract_graph, graph)
        print("\tStopped graph contraction")

        final_subproblems = []
        subproblems = [SubProblem(graph, [])]
        while subproblems:
            subproblem = subproblems.pop(0)
            subgraph = subproblem.subgraph
            literals = subproblem.literals

            # First, we check whether the graph has become smaller than the node threshold
            if self.node_threshold is None or len(subgraph.nodes) <= self.node_threshold:
                final_subproblems.append(subproblem)
                continue

            # Second, we stop when the depth threshold is reached or when every
            # variable has already been fixed (nothing is left to split on).
            depth = len(literals)
            depth_reached = self.depth_threshold is not None and depth >= self.depth_threshold
            if depth_reached or depth >= len(variable_order):
                final_subproblems.append(subproblem)
                continue

            variable = variable_order[depth]
            positive_literal = LITERAL(variable, True)
            negative_literal = LITERAL(variable, False)

            print("\tStarted fixating variable")
            positive_subgraph = self._timed_step(
                "fix", self._graph_with_fixed_variable, subgraph, positive_literal)
            negative_subgraph = self._timed_step(
                "fix", self._graph_with_fixed_variable, subgraph, negative_literal)
            print("\tStopped fixating variable")

            print("\tStarted graph contraction")
            positive_subgraph = self._timed_step("contract", self._contract_graph, positive_subgraph)
            negative_subgraph = self._timed_step("contract", self._contract_graph, negative_subgraph)
            print("\tStopped graph contraction")

            subproblems.append(SubProblem(positive_subgraph, literals + [positive_literal]))
            subproblems.append(SubProblem(negative_subgraph, literals + [negative_literal]))

        return final_subproblems

    def _write_graph(self, subproblem: SubProblem):
        """Write the graph of ``subproblem`` as a Verilog module to ``self.xbar_filepath``.

        Every node ``v`` gets one wire ``v_t`` per time step ``t`` in
        ``range(T)``, where ``T`` is the number of nodes. Nodes carrying an
        ``input_function`` attribute are driven high at every time step. Any
        other node is assigned, for ``t > 0``, the disjunction over its
        neighbours ``u`` of ``(literals on the edges u-v) & u_{t-1}``. Each
        output variable is the disjunction over all time steps of the nodes
        listing it in their ``output_functions`` attribute.
        """
        input_variables_str = ", ".join(self.specification.get_input_variables())
        output_variables_str = ", ".join(self.specification.get_output_variables())

        subgraph = subproblem.subgraph
        T = len(subgraph.nodes)

        lines = [
            "module test (\n",
            "\t{} );\n".format(input_variables_str),
            "\tinput {};\n".format(input_variables_str),
            "\toutput {};\n".format(output_variables_str),
        ]

        wires = ["{}_{}".format(v, t) for v in subgraph.nodes for t in range(0, T)]
        lines.append("\twire {};\n".format(", ".join(wires)))

        # Input nanowires
        input_wire_nodes = set()
        for v, node_data in subgraph.nodes(data=True):
            if "input_function" in node_data:
                input_wire_nodes.add(v)
                for t in range(T):
                    lines.append("\tassign {}_{} = 1'b1;\n".format(v, t))

        # Time step t > 0
        for v in subgraph.nodes:
            if v in input_wire_nodes:
                continue

            # The literal disjunction per neighbour does not depend on the time
            # step, so it is computed once per node.
            neighbour_terms = []
            for u in subgraph.neighbors(v):
                subdisjunction = [
                    str(LITERAL(edge_data.get("atom"), edge_data.get("positive")))
                    for x, y, _, edge_data in subgraph.edges(v, keys=True, data=True)
                    if x == v and y == u
                ]
                neighbour_terms.append((" | ".join(subdisjunction), u))

            if len(neighbour_terms) == 0:
                continue

            for t in range(1, T):
                conjunctions = [
                    "({}) & {}_{}".format(disjunction, u, t - 1)
                    for disjunction, u in neighbour_terms
                ]
                lines.append("\tassign {}_{} = {};\n".format(v, t, " | ".join(conjunctions)))

        # Output nanowires
        q = dict()
        for v, node_data in subgraph.nodes(data=True):
            if "output_functions" in node_data:
                over_time = "(({}))".format(" | ".join("{}_{}".format(v, t) for t in range(0, T)))
                for output_variable in node_data.get("output_functions"):
                    q.setdefault(output_variable, []).append(over_time)

        for output_variable, terms in q.items():
            lines.append("\tassign {} = {};\n".format(output_variable, " | ".join(terms)))
        lines.append("endmodule")

        with open(self.xbar_filepath, 'w') as f:
            f.write("".join(lines))

    def _log_crossbar(self, xbar: MemristorCrossbar):
        """Record the dimensions of ``xbar`` and the derived ``time_steps`` value in the crossbar log."""
        R = xbar.rows
        C = xbar.columns
        self._crossbar_log["rows"] = R
        self._crossbar_log["columns"] = C
        self._crossbar_log["time_steps"] = 2 * min(R - 1, C) + 1

    def _log_variable_count(self, variable_count: Dict[str, int]):
        """Store the per-variable occurrence counts in the crossbar log."""
        self._crossbar_log["variable_count"] = variable_count

    def _log_subproblem(self, subproblem_log: Dict[str, Any], subproblem: SubProblem):
        """Fill ``subproblem_log`` with graph size, fixed literals and time steps of ``subproblem``."""
        node_count = len(subproblem.subgraph.nodes)
        subproblem_log["graph"] = dict()
        subproblem_log["graph"]["nodes"] = node_count
        subproblem_log["graph"]["edges"] = self._get_edge_properties(subproblem.subgraph)
        subproblem_log["fixed_literals"] = [str(lit) for lit in subproblem.literals]
        subproblem_log["time_steps"] = node_count

    def _log_sat(self, subproblem_log: Dict[str, Any]):
        """Add clause and literal counts from the first row of the SAT info file to ``subproblem_log``.

        ``subproblem_log["sat"]`` is left empty when the file does not exist or
        contains no rows.
        """
        subproblem_log["sat"] = dict()
        # NOTE(review): presumably gives the external tool time to finish
        # writing the statistics file -- confirm before shortening.
        time.sleep(3)
        if not self.sat_info_filepath.exists():
            return
        with open(self.sat_info_filepath, 'r', newline='') as f:
            first_row = next(csv.reader(f, delimiter='\t'), None)
        if first_row is None:
            return
        subproblem_log["sat"]["clauses"] = int(first_row[1])
        subproblem_log["sat"]["literals"] = int(first_row[2])

    def _write_specification(self, specification: BooleanFunctionCollection, subproblem: SubProblem):
        """Write the specification with the literals of ``subproblem`` fixed to ``self.spec_filepath``.

        Every boolean function of ``specification`` is converted to Verilog and
        processed by ``VerilogFix`` targeting the same file.
        """
        literals = subproblem.literals
        for boolean_function in specification.boolean_functions:
            verilog_benchmark = boolean_function.to_verilog()
            verilog_fix = VerilogFix(verilog_benchmark, self.spec_filepath)
            verilog_fix.fix(literals)
            # TODO: Fixing the literals directly on the boolean expressions
            #  (expression.fix(literal.atom, literal.positive) for every output
            #  function, then writing the resulting VerilogBenchmark to
            #  self.spec_filepath) contains a bug somewhere. Hence, we use
            #  VerilogFix. However, in the future, we would want to use the
            #  method fix().

    def is_equivalent(self, sampling_size: int = 0) -> bool:
        """Check whether the crossbar is equivalent to the specification.

        Every sub-problem is written to file and checked with ABC's ``cec``
        command. Checking stops at the first sub-problem that is not
        equivalent. The collected log is added to ``config.log`` before
        returning.

        :param sampling_size: not used by this checker.
        :return: ``True`` if all sub-problems are equivalent, ``False`` otherwise.
        """
        print("XSAT started")
        self.start_time = time.time()

        assert isinstance(self.specification, BooleanFunctionCollection)
        for boolean_function in self.specification.boolean_functions:
            assert isinstance(boolean_function, VerilogBenchmark)
        assert isinstance(self.boolean_function, MemristorCrossbar)

        xbar = self.boolean_function
        self._log_crossbar(xbar)

        start_divide_and_conquer_time = time.time()
        subproblems = self._divide_and_conquer(xbar)
        end_divide_and_conquer_time = time.time()
        self.divide_and_conquer_time = end_divide_and_conquer_time - start_divide_and_conquer_time

        for subproblem in subproblems:
            print("\tSubproblem with fixed literals: {}".format(subproblem.literals))

            # Make sure the statistics read later belong to this sub-problem.
            if self.sat_info_filepath.exists():
                os.remove(self.sat_info_filepath)

            subproblem_log = dict()
            self._log_subproblem(subproblem_log, subproblem)

            sub_specification = copy.deepcopy(self.specification)

            start_write_time = time.time()
            self._write_specification(sub_specification, subproblem)
            self._write_graph(subproblem)
            end_write_time = time.time()
            subproblem_log["write_time"] = end_write_time - start_write_time

            # Based on https://stackoverflow.com/questions/20061176/python-wait-and-check-if-file-is-created-completely-by-external-program
            while True:
                try:
                    with open(self.xbar_filepath, 'rb') as _:
                        pass
                    with open(self.spec_filepath, 'rb') as _:
                        break
                except IOError:
                    time.sleep(3)
            print("\tCrossbar and specification written to file.")

            start_cec_time = time.time()
            # Start a process for the ABC tool
            process = pexpect.spawn(config.abc_cmd, cwd=str(config.abc_path), timeout=14400)
            try:
                process.sendline("cec {} {};".format(self.spec_filepath.name, self.xbar_filepath.name))
                index = process.expect(['.*Networks are equivalent.*', '.*Networks are NOT EQUIVALENT.*'])
                end_cec_time = time.time()
                # Read the statistics while the process is still alive.
                self._log_sat(subproblem_log)
            finally:
                # Release the child process and its pty explicitly instead of
                # relying on garbage collection.
                process.close(force=True)

            subproblem_log["cec_time"] = end_cec_time - start_cec_time

            self.equivalent = index == 0
            self._subproblems_log.append(subproblem_log)

            if not self.equivalent:
                self.end_time = time.time()
                config.log.add_json(self.get_log())
                print("NOT equivalent")
                return self.equivalent

        self.end_time = time.time()
        config.log.add_json(self.get_log())

        print("Equivalent")
        print("XSAT stopped")
        return self.equivalent

    def to_formula(self, output_variable: str) -> Bool:
        """Not supported by this checker; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def get_log(self) -> Dict[str, Any]:
        """Return the collected log as a dictionary.

        ``total_time`` is ``None`` as long as no run has both started and finished.
        """
        if self.start_time is None or self.end_time is None:
            total_time = None
        else:
            total_time = self.end_time - self.start_time
        return {
            "type": self.__class__.__name__,
            "depth_threshold": self.depth_threshold,
            "node_threshold": self.node_threshold,
            "crossbar": self._crossbar_log,
            "total_divide_and_conquer_time": self.divide_and_conquer_time,
            "iterations": self._iterations_log,
            "subproblems": self._subproblems_log,
            "equivalent": self.equivalent,
            "total_time": total_time
        }