# Source code for pyreason.pyreason

# This is the file that will be imported when "import pyreason" is called. All content will be run automatically
# ruff: noqa: F401 (Ignore Pyreason import * for public api)
import importlib
import json
import networkx as nx
import numba
import time
import sys
import pandas as pd
import memory_profiler as mp
import warnings
from typing import List, Type, Callable, Tuple, Optional

from pyreason.scripts.utils.output import Output
from pyreason.scripts.utils.filter import Filter
from pyreason.scripts.program.program import Program
from pyreason.scripts.utils.graphml_parser import GraphmlParser
import pyreason.scripts.utils.yaml_parser as yaml_parser
import pyreason.scripts.utils.rule_parser as rule_parser
import pyreason.scripts.utils.filter_ruleset as ruleset_filter
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.rules.rule import Rule
from pyreason.scripts.threshold.threshold import Threshold
from pyreason.scripts.query.query import Query
import pyreason.scripts.numba_wrapper.numba_types.fact_node_type as fact_node
import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
from pyreason.scripts.interpretation.interpretation_parallel import Interpretation
from pyreason.scripts.utils.reorder_clauses import reorder_clauses
# torch is an optional dependency: the learning helpers are only imported when it is installed.
# The submodule is imported explicitly because a bare `import importlib` does not guarantee
# that `importlib.util` is available as an attribute of the package.
import importlib.util

if importlib.util.find_spec("torch") is not None:
    from pyreason.scripts.learning.classification.classifier import LogicIntegratedClassifier
    from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions
else:
    LogicIntegratedClassifier = None
    ModelInterfaceOptions = None
    print('torch is not installed, model integration is disabled')


# USER VARIABLES
class _Settings:
    """Container for the user-configurable options of pyreason.

    Every option is exposed as a property. Each setter checks the type of the
    assigned value and raises ``TypeError`` on a mismatch. :meth:`reset`
    restores every option to its default value.
    """

    def __init__(self):
        # reset() assigns every private attribute, so it doubles as the initialiser
        self.reset()

    def reset(self):
        """Restore every setting to its default value."""
        self.__verbose = True
        self.__output_to_file = False
        self.__output_file_name = 'pyreason_output'
        self.__graph_attribute_parsing = True
        self.__abort_on_inconsistency = False
        self.__memory_profile = False
        self.__reverse_digraph = False
        self.__atom_trace = False
        self.__save_graph_attributes_to_trace = False
        # Kept for backwards compatibility; the `canonical` property reads/writes `__persistent`
        self.__canonical = False
        self.__persistent = False
        self.__inconsistency_check = True
        self.__static_graph_facts = True
        self.__store_interpretation_changes = True
        self.__parallel_computing = False
        self.__update_mode = 'intersection'
        self.__allow_ground_rules = False
        self.__fp_version = False

    @staticmethod
    def _require(value, expected_type, message: str) -> None:
        """Raise ``TypeError(message)`` unless *value* is an instance of *expected_type*.

        :param value: Value handed to a setter
        :param expected_type: Type the value must be an instance of
        :param message: Text of the ``TypeError`` raised on mismatch
        :raises TypeError: If the value has the wrong type
        """
        if not isinstance(value, expected_type):
            raise TypeError(message)

    # ------------------------------------------------------------------ verbose
    @property
    def verbose(self) -> bool:
        """Returns whether verbose mode is on or not. Default is True

        :return: bool
        """
        return self.__verbose

    @verbose.setter
    def verbose(self, value: bool) -> None:
        """Set verbose mode. Default is True

        :param value: verbose or not
        :raises TypeError: If not boolean, raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__verbose = value

    # ----------------------------------------------------------- output_to_file
    @property
    def output_to_file(self) -> bool:
        """Returns whether output is going to be printed to file or not. Default is False

        :return: bool
        """
        return self.__output_to_file

    @output_to_file.setter
    def output_to_file(self, value: bool) -> None:
        """Set whether to put all output into a file. Default file name is `pyreason_output`
        and can be changed with `output_file_name`. Default is False

        :param value: whether to save to file or not
        :raises TypeError: If not boolean, raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__output_to_file = value

    # --------------------------------------------------------- output_file_name
    @property
    def output_file_name(self) -> str:
        """Returns the name of the file the output will be saved in. Only applicable if
        `output_to_file` is true. Default is pyreason_output

        :return: str
        """
        return self.__output_file_name

    @output_file_name.setter
    def output_file_name(self, file_name: str) -> None:
        """Set output file name if `output_to_file` is true. Default is `pyreason_output`

        :param file_name: File name
        :raises TypeError: If not string raise error
        """
        self._require(file_name, str, 'file_name has to be a string')
        self.__output_file_name = file_name

    # -------------------------------------------------- graph_attribute_parsing
    @property
    def graph_attribute_parsing(self) -> bool:
        """Returns whether graph will be parsed for attributes or not. Default is True

        :return: bool
        """
        return self.__graph_attribute_parsing

    @graph_attribute_parsing.setter
    def graph_attribute_parsing(self, value: bool) -> None:
        """Whether to parse graphml file for attributes. Default is True

        :param value: Whether to parse graphml or not
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__graph_attribute_parsing = value

    # --------------------------------------------------- abort_on_inconsistency
    @property
    def abort_on_inconsistency(self) -> bool:
        """Returns whether program will abort when it encounters an inconsistency in the
        interpretation or not. Default is False

        :return: bool
        """
        return self.__abort_on_inconsistency

    @abort_on_inconsistency.setter
    def abort_on_inconsistency(self, value: bool) -> None:
        """Whether to abort program if inconsistency is found. Default is False

        :param value: Whether to abort on inconsistency or not
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__abort_on_inconsistency = value

    # ----------------------------------------------------------- memory_profile
    @property
    def memory_profile(self) -> bool:
        """Returns whether program will profile maximum memory usage or not. Default is False

        :return: bool
        """
        return self.__memory_profile

    @memory_profile.setter
    def memory_profile(self, value: bool) -> None:
        """Whether to profile the program's memory usage. Default is False

        :param value: Whether to profile program's memory usage or not
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__memory_profile = value

    # ---------------------------------------------------------- reverse_digraph
    @property
    def reverse_digraph(self) -> bool:
        """Returns whether graph will be reversed or not.
        If graph is reversed, an edge a->b will become b->a. Default is False

        :return: bool
        """
        return self.__reverse_digraph

    @reverse_digraph.setter
    def reverse_digraph(self, value: bool) -> None:
        """Whether to reverse the digraph. If the graphml contains an edge a->b,
        setting reverse as true will make the edge b->a. Default is False

        :param value: Whether to reverse graphml edges or not
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__reverse_digraph = value

    # --------------------------------------------------------------- atom_trace
    @property
    def atom_trace(self) -> bool:
        """Returns whether to keep track of all atoms that are responsible for the firing
        of rules or not. NOTE: Turning this on may increase memory usage. Default is False

        :return: bool
        """
        return self.__atom_trace

    @atom_trace.setter
    def atom_trace(self, value: bool) -> None:
        """Whether to save all atoms that were responsible for the firing of a rule.
        NOTE: this can be very memory heavy. Default is False

        :param value: Whether to save all atoms or not
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__atom_trace = value

    # ------------------------------------------- save_graph_attributes_to_trace
    @property
    def save_graph_attributes_to_trace(self) -> bool:
        """Returns whether to save the graph attribute facts to the rule trace. Graphs are
        large and turning this on can result in more memory usage.
        NOTE: Turning this on may increase memory usage. Default is False

        :return: bool
        """
        return self.__save_graph_attributes_to_trace

    @save_graph_attributes_to_trace.setter
    def save_graph_attributes_to_trace(self, value: bool) -> None:
        """Whether to save all graph attribute facts. Graphs are large so turning this on
        can be memory heavy. NOTE: this can be very memory heavy. Default is False

        :param value: Whether to save all graph attribute facts in the trace or not
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__save_graph_attributes_to_trace = value

    # ---------------------------------------------------------------- canonical
    @property
    def canonical(self) -> bool:
        """DEPRECATED, use persistent instead.
        Returns the same value as `persistent`. Default is False

        :return: bool
        """
        return self.__persistent

    @canonical.setter
    def canonical(self, value: bool) -> None:
        """DEPRECATED, use persistent instead. Writes the same underlying value as `persistent`.

        :param value: True to keep bounds across timesteps (canonical), False to reset
            all bounds at each timestep (non-canonical)
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__persistent = value

    # --------------------------------------------------------------- persistent
    @property
    def persistent(self) -> bool:
        """Returns whether the interpretation is persistent (does not reset bounds at each
        timestep). Default is False

        :return: bool
        """
        return self.__persistent

    @persistent.setter
    def persistent(self, value: bool) -> None:
        """Whether the interpretation should be persistent, i.e. bounds are not reset at
        each timestep.

        :param value: True to keep bounds across timesteps (persistent), False to reset
            all bounds at each timestep (non-persistent)
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__persistent = value

    # ------------------------------------------------------ inconsistency_check
    @property
    def inconsistency_check(self) -> bool:
        """Returns whether to check for inconsistencies in the interpretation or not.
        Default is True

        :return: bool
        """
        return self.__inconsistency_check

    @inconsistency_check.setter
    def inconsistency_check(self, value: bool) -> None:
        """Whether to check for inconsistencies in the interpretation or not

        :param value: Whether to check for inconsistencies or not
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__inconsistency_check = value

    # ------------------------------------------------------- static_graph_facts
    @property
    def static_graph_facts(self) -> bool:
        """Returns whether to make graph facts static or not. Default is True

        :return: bool
        """
        return self.__static_graph_facts

    @static_graph_facts.setter
    def static_graph_facts(self, value: bool) -> None:
        """Whether to make graphml attribute facts static or not

        :param value: Whether to make graphml facts static or not
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__static_graph_facts = value

    # --------------------------------------------- store_interpretation_changes
    @property
    def store_interpretation_changes(self) -> bool:
        """Returns whether to keep track of changes that occur in the interpretation.
        When this is turned off you will not be able to view interpretation results
        after reasoning. Default is True

        :return: bool
        """
        return self.__store_interpretation_changes

    @store_interpretation_changes.setter
    def store_interpretation_changes(self, value: bool) -> None:
        """Whether to keep track of changes that occur to the interpretation.
        When this is turned off you will not be able to view interpretation results
        after reasoning.

        :param value: Whether to store interpretation changes or not
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__store_interpretation_changes = value

    # ------------------------------------------------------- parallel_computing
    @property
    def parallel_computing(self) -> bool:
        """Returns whether to use multiple CPU cores for inference. This will disable
        caching and pyreason will have to be re-compiled at each run - but after
        compilation it will be faster. Default is False

        :return: bool
        """
        return self.__parallel_computing

    @parallel_computing.setter
    def parallel_computing(self, value: bool) -> None:
        """Whether to use multiple CPU cores for inference. This will disable caching and
        pyreason will have to be re-compiled at each run - but after compilation it
        will be faster. Default is False

        :param value: Whether to make inference run on parallel hardware (multiple CPU cores)
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__parallel_computing = value

    # -------------------------------------------------------------- update_mode
    @property
    def update_mode(self) -> str:
        """Returns the way interpretations are going to be updated.
        This could be "intersection" or "override"

        :return: str
        """
        return self.__update_mode

    @update_mode.setter
    def update_mode(self, value: str) -> None:
        """The way interpretations are going to be updated. This could be "intersection"
        or "override". Default is 'intersection'

        Only the type is checked here; the string itself is not validated.

        :param value: "intersection" or "override"
        :raises TypeError: If not str raise error
        """
        self._require(value, str, 'value has to be a str')
        self.__update_mode = value

    # ------------------------------------------------------- allow_ground_rules
    @property
    def allow_ground_rules(self) -> bool:
        """Returns whether rules can have ground atoms or not. Default is False

        :return: bool
        """
        return self.__allow_ground_rules

    @allow_ground_rules.setter
    def allow_ground_rules(self, value: bool) -> None:
        """Allow ground atoms to be used in rules when possible. Default is False

        :param value: Whether to allow ground atoms or not
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__allow_ground_rules = value

    # --------------------------------------------------------------- fp_version
    @property
    def fp_version(self) -> bool:
        """Returns whether we are using the fixed point version or the optimized version.
        Default is False

        :return: bool
        """
        return self.__fp_version

    @fp_version.setter
    def fp_version(self, value: bool) -> None:
        """Set the fixed point or optimized version. Default is False

        :param value: Whether to use the fixed point version or the optimized version
        :raises TypeError: If not bool raise error
        """
        self._require(value, bool, 'value has to be a bool')
        self.__fp_version = value


# VARIABLES
__graph: Optional[nx.DiGraph] = None
__rules: Optional[numba.typed.List] = None
__clause_maps: Optional[dict] = None
__node_facts: Optional[numba.typed.List] = None
__edge_facts: Optional[numba.typed.List] = None
__facts_name_set = set()  # We want to warn the user if they add multiple facts with the same name
__rules_name_set = set()  # We want to warn the user if they add multiple rules with the same name
__ipl: Optional[numba.typed.List] = None
__specific_node_labels: Optional[numba.typed.List] = None
__specific_edge_labels: Optional[numba.typed.List] = None
__closed_world_predicates = set()

__non_fluent_graph_facts_node: Optional[numba.typed.List] = None
__non_fluent_graph_facts_edge: Optional[numba.typed.List] = None
__specific_graph_node_labels: Optional[numba.typed.List] = None
__specific_graph_edge_labels: Optional[numba.typed.List] = None

__annotation_functions = []
__head_functions = []

__timestamp = ''
__program: Optional[Program] = None

__graphml_parser = GraphmlParser()

# Module-level settings instance read by the functions in this file (e.g. settings.verbose)
settings = _Settings()
def reset():
    """Clear facts, graph and rules so that pr.reason() can be run several times
    in one program without memory blowing up.

    The module-level fact and graph references are dropped, the bookkeeping sets are
    emptied, and the existing logic program (if any) is asked to reset its own
    facts and graph. Rules are cleared through :func:`reset_rules`.
    """
    global __node_facts, __edge_facts, __graph, __facts_name_set, __closed_world_predicates

    program = __program

    # Facts
    __node_facts = __edge_facts = None
    __facts_name_set.clear()
    __closed_world_predicates = set()
    if program is not None:
        program.reset_facts()

    # Graph
    __graph = None
    if program is not None:
        program.reset_graph()

    # Rules
    reset_rules()
def get_rules():
    """Return the rules added so far.

    :return: The module-level rule collection, or None if no rule has been added
    """
    current_rules = __rules
    return current_rules
def reset_rules():
    """Forget every rule together with the registered annotation and head functions.

    The rule collection is set back to None, the set of known rule names is emptied,
    and the existing logic program (if any) is asked to reset its rules as well.
    """
    global __rules, __annotation_functions, __head_functions

    __rules = None
    __rules_name_set.clear()
    __annotation_functions, __head_functions = [], []

    program = __program
    if program is not None:
        program.reset_rules()
def get_logic_program() -> Optional[Program]:
    """Return the logic program object held by this module.

    :return: Logic program object, or None if none has been created yet
    """
    program = __program
    return program
def get_interpretation() -> Optional[Interpretation]:
    """Return the interpretation attached to the current logic program.

    :return: Current interpretation
    :raises Exception: If no logic program exists yet
    """
    if __program is not None:
        return __program.interp
    raise Exception('No interpretation found. Please run `pr.reason()` first')
def get_time() -> int:
    """Get the current time.

    :return: 0 when no interpretation is available, otherwise the interpretation's
        ``time`` attribute plus one
    """
    try:
        interpretation = get_interpretation()
    except Exception:
        # get_interpretation() raises a plain Exception when no program exists yet
        return 0

    # get_interpretation() is annotated Optional: a program without an interpretation
    # is treated like "no interpretation" instead of failing on `None.time`
    if interpretation is None:
        return 0
    return interpretation.time + 1
def reset_settings():
    """Put every option of the module-level ``settings`` object back to its default."""
    settings.reset()
# FUNCTIONS
def load_graphml(path: str) -> None:
    """Read a graph from a GraphML file and store it as the program's graph.

    When ``settings.graph_attribute_parsing`` is on, the graph attributes are parsed
    into facts and labels; otherwise empty typed containers are stored instead.

    :param path: Path for the GraphMl file
    """
    global __graph, __non_fluent_graph_facts_node, __non_fluent_graph_facts_edge, __specific_graph_node_labels, __specific_graph_edge_labels

    # Parse graph
    __graph = __graphml_parser.parse_graph(path, settings.reverse_digraph)

    # No attribute parsing requested: install empty, correctly typed containers and stop
    if not settings.graph_attribute_parsing:
        string_type = numba.types.string
        edge_type = numba.types.Tuple((string_type, string_type))
        __non_fluent_graph_facts_node = numba.typed.List.empty_list(fact_node.fact_type)
        __non_fluent_graph_facts_edge = numba.typed.List.empty_list(fact_edge.fact_type)
        __specific_graph_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(string_type))
        __specific_graph_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type))
        return

    # Graph attribute parsing
    parsed = __graphml_parser.parse_graph_attributes(settings.static_graph_facts)
    (
        __non_fluent_graph_facts_node,
        __non_fluent_graph_facts_edge,
        __specific_graph_node_labels,
        __specific_graph_edge_labels,
    ) = parsed
def load_graph(graph: nx.DiGraph) -> None:
    """Store an in-memory networkx DiGraph as the program's graph.

    When ``settings.graph_attribute_parsing`` is on, the graph attributes are parsed
    into facts and labels; otherwise empty typed containers are stored instead.

    :param graph: Networkx DiGraph object to load into pyreason
    :type graph: nx.DiGraph
    :return: None
    """
    global __graph, __non_fluent_graph_facts_node, __non_fluent_graph_facts_edge, __specific_graph_node_labels, __specific_graph_edge_labels

    # Load graph
    __graph = __graphml_parser.load_graph(graph)

    # No attribute parsing requested: install empty, correctly typed containers and stop
    if not settings.graph_attribute_parsing:
        string_type = numba.types.string
        edge_type = numba.types.Tuple((string_type, string_type))
        __non_fluent_graph_facts_node = numba.typed.List.empty_list(fact_node.fact_type)
        __non_fluent_graph_facts_edge = numba.typed.List.empty_list(fact_edge.fact_type)
        __specific_graph_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(string_type))
        __specific_graph_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type))
        return

    # Graph attribute parsing
    parsed = __graphml_parser.parse_graph_attributes(settings.static_graph_facts)
    (
        __non_fluent_graph_facts_node,
        __non_fluent_graph_facts_edge,
        __specific_graph_node_labels,
        __specific_graph_edge_labels,
    ) = parsed
def load_inconsistent_predicate_list(path: str) -> None:
    """Replace the inconsistent predicate list (IPL) with the one parsed from a YAML file.

    :param path: Path for the YAML IPL file
    """
    global __ipl
    parsed_ipl = yaml_parser.parse_ipl(path)
    __ipl = parsed_ipl
def add_inconsistent_predicate(pred1: str, pred2: str) -> None:
    """Append one inconsistent predicate pair to the IPL, creating the IPL on first use.

    :param pred1: First predicate in the inconsistent pair
    :param pred2: Second predicate in the inconsistent pair
    """
    global __ipl

    # Lazily create the typed list of (label, label) tuples
    if __ipl is None:
        pair_type = numba.types.Tuple((label.label_type, label.label_type))
        __ipl = numba.typed.List.empty_list(pair_type)

    pair = (label.Label(pred1), label.Label(pred2))
    __ipl.append(pair)
def add_rule(pr_rule: Rule) -> None:
    """Register a rule object with pyreason.

    A rule without a name gets the generated name ``rule_<n>``, where ``n`` is the
    number of rules registered so far. Re-using a name only triggers a warning; the
    rule is still added.

    :param pr_rule: Rule object wrapping the internal rule to register
    """
    global __rules

    # Add to collection of rules
    if __rules is None:
        __rules = numba.typed.List.empty_list(rule.rule_type)

    inner_rule = pr_rule.rule

    # Generate name for rule if not set
    if inner_rule.get_rule_name() is None:
        inner_rule.set_rule_name(f'rule_{len(__rules)}')

    # Read the name back after the optional rename above
    rule_name = inner_rule.get_rule_name()
    if rule_name in __rules_name_set:
        warnings.warn(f"Rule {rule_name} has already been added. Duplicate rule names will lead to an ambiguous rule trace.")

    __rules_name_set.add(rule_name)
    __rules.append(inner_rule)
def add_rules_from_file(file_path: str, infer_edges: bool = False, raise_errors: bool = False) -> None:
    """Add a set of rules from a text file.

    Each non-empty, non-comment line is treated as a rule in text format.
    Lines starting with ``#`` are treated as comments and skipped.
    The ``infer_edges`` parameter is applied uniformly to all rules loaded from the file.

    Rules are named ``rule_<k>``, where ``k`` counts the rule lines of the file starting
    from the number of rules that were already registered. Error and warning messages
    report the 1-based line number of the offending line in the file.

    :param file_path: Path to the text file containing rules
    :type file_path: str
    :param infer_edges: Whether to infer edges on these rules if an edge doesn't exist between head variables and the body of the rule is satisfied
    :type infer_edges: bool
    :param raise_errors: If True, raise on invalid rules. If False, warn and skip them.
    :type raise_errors: bool
    :return: None
    :raises FileNotFoundError: If the text file doesn't exist
    :raises ValueError: If rule parsing fails and raise_errors is True
    """
    # Keep the real file line number next to every rule: blank lines and comments are
    # dropped, so the position in the filtered list is not the position in the file.
    numbered_rules = []
    with open(file_path, 'r') as file:
        for line_no, line in enumerate(file, start=1):
            text = line.rstrip()
            if text != '' and text[0] != '#':
                numbered_rules.append((line_no, text))

    loaded_count = 0
    error_count = 0
    rule_offset = 0 if __rules is None else len(__rules)

    for i, (line_no, r) in enumerate(numbered_rules):
        try:
            add_rule(Rule(r, f'rule_{i+rule_offset}', infer_edges))
            loaded_count += 1
        except Exception as e:
            if raise_errors:
                raise ValueError(f"Line {line_no}: Failed to parse rule '{r}' - {e}") from e
            error_count += 1
            warnings.warn(f"Line {line_no}: Failed to parse rule '{r}' - {e}")

    if settings.verbose:
        print(f"Loaded {loaded_count} rules from {file_path}")
        if error_count > 0:
            print(f"Failed to load {error_count} rules due to errors")
def _parse_bool_param(raw_value, param_name, idx, raise_errors, item_label="Item", default=False): """Private helper to parse a raw value as a boolean. :param raw_value: Raw value to parse (can be None, str, bool, int, float) :param param_name: Name of the parameter (for error messages) :param idx: Index of the item being parsed (for error messages) :param raise_errors: Whether to raise errors or just warn :param item_label: Label for error messages (e.g., "Item", "Row") :param default: Default value if raw_value is None or empty :return: Parsed boolean value :raises ValueError: If validation fails and raise_errors is True """ if raw_value is None: return default if isinstance(raw_value, bool): return raw_value if isinstance(raw_value, str): val_str = raw_value.strip().lower() if val_str in ('true', '1', 'yes', 't', 'y'): return True elif val_str in ('false', '0', 'no', 'f', 'n', ''): return default if val_str == '' else False else: if raise_errors: raise ValueError(f"{item_label} {idx}: Invalid {param_name} value '{raw_value}'") warnings.warn(f"{item_label} {idx}: Invalid {param_name} value '{raw_value}', using default value") return default if isinstance(raw_value, (int, float)): return bool(raw_value) if raise_errors: raise ValueError(f"{item_label} {idx}: Invalid {param_name} value type '{type(raw_value).__name__}'") warnings.warn(f"{item_label} {idx}: Invalid {param_name} value type '{type(raw_value).__name__}', using default value") return default def _parse_and_validate_rule_params(idx, name_raw, infer_edges_raw, set_static_raw, raise_errors, item_label="Item"): """Private helper to parse and validate rule parameters. 
:param idx: Index of the item being parsed (for error messages) :param name_raw: Raw name value (can be None, str, or other types) :param infer_edges_raw: Raw infer_edges value :param set_static_raw: Raw set_static value :param raise_errors: Whether to raise errors or just warn :param item_label: Label for error messages (e.g., "Item", "Row") :return: Tuple of (name, infer_edges, set_static) :raises ValueError: If validation fails and raise_errors is True """ # Parse name name = None if name_raw is not None: name = str(name_raw).strip() if str(name_raw).strip() else None # Parse infer_edges infer_edges = _parse_bool_param(infer_edges_raw, 'infer_edges', idx, raise_errors, item_label, default=False) # Parse set_static set_static = _parse_bool_param(set_static_raw, 'set_static', idx, raise_errors, item_label, default=False) return name, infer_edges, set_static
def add_rule_from_csv(csv_path: str, raise_errors: bool = True) -> None:
    """Load multiple rules from a CSV file.

    Each row should have up to 4 comma-separated values in this order:
    ``rule_text, name, infer_edges, set_static``

    - **rule_text** (required): The rule in text format, e.g., ``friend(A, B) <- knows(A, B)``
      or ``"ally(A, B) <- friend(A, B), common_interest(A, B)"`` for rules with commas.
    - **name** (optional): A unique name for the rule (can be empty).
    - **infer_edges** (optional): Whether to infer new edges after edge rule fires (default: False).
      Accepts: True/False, 1/0, yes/no (case-insensitive).
    - **set_static** (optional): Whether to set the atom in the head as static if the rule fires
      (default: False). Accepts: True/False, 1/0, yes/no (case-insensitive).

    A header row is optional. If included, it must be exactly::

        rule_text,name,infer_edges,set_static

    Any other header format will be treated as a data row and will likely raise a parsing error.

    Example CSV::

        rule_text,name,infer_edges,set_static
        friend(A, B) <- knows(A, B),friendship-rule,False,False
        "ally(A, B) <- friend(A, B), common_interest(A, B)",ally-rule,False,False
        connected(A, B) <- link(A, B),connected-rule,True,False

    :param csv_path: Path to the CSV file containing rules
    :type csv_path: str
    :param raise_errors: If True, raise on invalid rows. If False, warn and skip them.
    :type raise_errors: bool
    :return: None
    :raises FileNotFoundError: If the CSV file doesn't exist
    :raises ValueError: If rule parsing fails or CSV format is invalid
    """
    # Everything is read as text; empty cells stay empty strings instead of NaN
    try:
        frame = pd.read_csv(csv_path, header=None, dtype=str, keep_default_na=False)
    except FileNotFoundError:
        raise FileNotFoundError(f"CSV file not found: {csv_path}")
    except pd.errors.EmptyDataError:
        warnings.warn(f"CSV file {csv_path} is empty, no rules loaded")
        return
    except Exception as e:
        raise ValueError(f"Error reading CSV file {csv_path}: {e}")

    if frame.empty:
        warnings.warn(f"CSV file {csv_path} is empty, no rules loaded")
        return

    # Skip first row if it exactly matches the expected header
    header_cells = [str(cell).strip() for cell in frame.iloc[0]]
    first_data_row = 1 if header_cells == ['rule_text', 'name', 'infer_edges', 'set_static'] else 0

    # Track loaded rules for reporting
    loaded_count = 0
    error_count = 0
    seen_names = set()

    for idx, row in frame.iloc[first_data_row:].iterrows():
        row_no = idx + 1  # 1-based row number used in every message
        n_cols = len(row)
        try:
            # rule_text is required and lives in column 0
            if n_cols < 1 or not str(row[0]).strip():
                if raise_errors:
                    raise ValueError(f"Row {row_no}: Missing required 'rule_text'")
                warnings.warn(f"Row {row_no}: Missing required 'rule_text', skipping row")
                error_count += 1
                continue
            rule_text = str(row[0]).strip()

            # Columns 1..3 are optional; absent columns are passed on as None
            optional_cells = [row[col] if n_cols > col else None for col in (1, 2, 3)]
            name, infer_edges, set_static = _parse_and_validate_rule_params(
                row_no, *optional_cells, raise_errors, "Row"
            )

            # Names must be unique within the file
            if name:
                if name in seen_names:
                    if raise_errors:
                        raise ValueError(f"Row {row_no}: Loaded name '{name}' is a duplicate - all rule names must be unique.")
                    warnings.warn(f"Row {row_no}: Loaded name '{name}' is a duplicate - all rule names must be unique.")
                    error_count += 1
                    continue
                seen_names.add(name)

            # Create and add the rule
            add_rule(Rule(rule_text=rule_text, name=name, infer_edges=infer_edges, set_static=set_static))
            loaded_count += 1

        except ValueError as e:
            if raise_errors:
                raise ValueError(f"Row {row_no}: Failed to parse rule - {e}") from e
            error_count += 1
            warnings.warn(f"Row {row_no}: Failed to parse rule - {e}")
        except Exception as e:
            if raise_errors:
                raise Exception(f"Row {row_no}: Unexpected error - {e}") from e
            error_count += 1
            warnings.warn(f"Row {row_no}: Unexpected error - {e}")

    if settings.verbose:
        print(f"Loaded {loaded_count} rules from {csv_path}")
        if error_count > 0:
            print(f"Failed to load {error_count} rules due to errors")
def add_rule_from_json(json_path: str, raise_errors: bool = True) -> None:
    """Load multiple rules from a JSON file.

    The JSON should be an array of objects, where each object represents a Rule with these fields:

    - **rule_text** (required): The rule in text format, e.g., ``"friend(A, B) <- knows(A, B)"``
    - **name** (optional): The name of the rule. This will appear in the rule trace.
    - **infer_edges** (optional): Whether to infer new edges after edge rule fires (default: false).
    - **set_static** (optional): Whether to set the atom in the head as static if the rule fires
      (default: false).
    - **custom_thresholds** (optional): A list of threshold objects (one per clause), or a dict
      mapping clause index to threshold object (unspecified clauses get defaults). Each threshold
      object has ``quantifier``, ``quantifier_type``, and ``thresh`` fields.
    - **weights** (optional): A list of weights for the rule clauses. This is passed to an
      annotation function.

    Example JSON format::

        [
            {
                "rule_text": "friend(A, B) <- knows(A, B)",
                "name": "friendship-rule",
                "infer_edges": false,
                "set_static": false
            },
            {
                "rule_text": "ally(A, B) <- friend(A, B), common_interest(A, B)",
                "name": "ally-rule-list",
                "custom_thresholds": [
                    {"quantifier": "greater_equal", "quantifier_type": ["number", "total"], "thresh": 1},
                    {"quantifier": "greater_equal", "quantifier_type": ["percent", "total"], "thresh": 100}
                ],
                "weights": [1.0, 2.0]
            },
            {
                "rule_text": "ally(A, B) <- friend(A, B), common_interest(A, B)",
                "name": "ally-rule-dict",
                "custom_thresholds": {
                    "0": {"quantifier": "greater_equal", "quantifier_type": ["percent", "total"], "thresh": 50}
                }
            }
        ]

    :param json_path: Path to the JSON file containing rules
    :type json_path: str
    :param raise_errors: If True, raise on invalid items. If False, warn and skip them.
    :type raise_errors: bool
    :return: None
    :raises FileNotFoundError: If the JSON file doesn't exist
    :raises ValueError: If rule parsing fails or JSON format is invalid
    """
    def make_threshold(spec):
        # Fields are read in the order quantifier, quantifier_type, thresh; a missing key
        # or a non-iterable quantifier_type surfaces as KeyError/TypeError for the caller
        return Threshold(spec['quantifier'], tuple(spec['quantifier_type']), spec['thresh'])

    try:
        with open(json_path, 'r') as f:
            data = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"JSON file not found: {json_path}")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format in file {json_path}: {e}")
    except Exception as e:
        raise ValueError(f"Error reading JSON file {json_path}: {e}")

    if not isinstance(data, list):
        raise ValueError(f"JSON file must contain an array of rule objects, got {type(data).__name__}")
    if not data:
        warnings.warn(f"JSON file {json_path} contains an empty array, no rules loaded")
        return

    # Track loaded rules for reporting
    loaded_count = 0
    error_count = 0
    seen_names = set()

    for idx, rule_obj in enumerate(data):
        try:
            # Every item has to be a JSON object
            if not isinstance(rule_obj, dict):
                if raise_errors:
                    raise ValueError(f"Item {idx}: Expected object, got {type(rule_obj).__name__}")
                warnings.warn(f"Item {idx}: Expected object, got {type(rule_obj).__name__}, skipping item")
                error_count += 1
                continue

            # rule_text is required
            rule_text = rule_obj.get('rule_text')
            if not rule_text or not str(rule_text).strip():
                if raise_errors:
                    raise ValueError(f"Item {idx}: Missing required 'rule_text'")
                warnings.warn(f"Item {idx}: Missing required 'rule_text', skipping item")
                error_count += 1
                continue
            rule_text = str(rule_text).strip()

            # Parameters shared with the CSV loader
            name, infer_edges, set_static = _parse_and_validate_rule_params(
                idx,
                rule_obj.get('name'),
                rule_obj.get('infer_edges', False),
                rule_obj.get('set_static', False),
                raise_errors,
                "Item"
            )

            # custom_thresholds (JSON-only): a list (one per clause) or a dict keyed by clause index
            raw_thresholds = rule_obj.get('custom_thresholds')
            custom_thresholds = None
            threshold_failed = False
            if isinstance(raw_thresholds, list):
                custom_thresholds = []
                for t_idx, t_obj in enumerate(raw_thresholds):
                    if not isinstance(t_obj, dict):
                        if raise_errors:
                            raise ValueError(f"Item {idx}, threshold {t_idx}: Expected object, got {type(t_obj).__name__}")
                        warnings.warn(f"Item {idx}, threshold {t_idx}: Expected object, got {type(t_obj).__name__}, skipping rule")
                        threshold_failed = True
                        break
                    try:
                        custom_thresholds.append(make_threshold(t_obj))
                    except (KeyError, ValueError, TypeError) as te:
                        if raise_errors:
                            raise ValueError(f"Item {idx}, threshold {t_idx}: Invalid threshold - {te}")
                        warnings.warn(f"Item {idx}, threshold {t_idx}: Invalid threshold - {te}, skipping rule")
                        threshold_failed = True
                        break
            elif isinstance(raw_thresholds, dict):
                custom_thresholds = {}
                for key_str, t_obj in raw_thresholds.items():
                    # JSON object keys are strings; they must name an integer clause index
                    try:
                        clause_idx = int(key_str)
                    except (ValueError, TypeError):
                        if raise_errors:
                            raise ValueError(f"Item {idx}: custom_thresholds dict key '{key_str}' must be an integer clause index")
                        warnings.warn(f"Item {idx}: custom_thresholds dict key '{key_str}' must be an integer clause index, skipping rule")
                        threshold_failed = True
                        break
                    if not isinstance(t_obj, dict):
                        if raise_errors:
                            raise ValueError(f"Item {idx}, threshold key '{key_str}': Expected object, got {type(t_obj).__name__}")
                        warnings.warn(f"Item {idx}, threshold key '{key_str}': Expected object, got {type(t_obj).__name__}, skipping rule")
                        threshold_failed = True
                        break
                    try:
                        custom_thresholds[clause_idx] = make_threshold(t_obj)
                    except (KeyError, ValueError, TypeError) as te:
                        if raise_errors:
                            raise ValueError(f"Item {idx}, threshold key '{key_str}': Invalid threshold - {te}")
                        warnings.warn(f"Item {idx}, threshold key '{key_str}': Invalid threshold - {te}, skipping rule")
                        threshold_failed = True
                        break
            elif raw_thresholds is not None:
                if raise_errors:
                    raise ValueError(f"Item {idx}: 'custom_thresholds' must be a list or dict of threshold objects")
                warnings.warn(f"Item {idx}: 'custom_thresholds' must be a list or dict of threshold objects, skipping rule")
                threshold_failed = True

            if threshold_failed:
                error_count += 1
                continue

            # weights (JSON-only): only the container type is checked here
            weights = rule_obj.get('weights')
            if weights is not None and not isinstance(weights, list):
                if raise_errors:
                    raise ValueError(f"Item {idx}: 'weights' must be a list of numeric values")
                warnings.warn(f"Item {idx}: 'weights' must be a list of numeric values, skipping rule")
                error_count += 1
                continue

            # Names must be unique within the file
            if name:
                if name in seen_names:
                    if raise_errors:
                        raise ValueError(f"Item {idx}: Loaded name '{name}' is a duplicate - all rule names must be unique.")
                    warnings.warn(f"Item {idx}: Loaded name '{name}' is a duplicate - all rule names must be unique.")
                    error_count += 1
                    continue
                seen_names.add(name)

            # Create and add the rule
            add_rule(Rule(rule_text=rule_text, name=name, infer_edges=infer_edges, set_static=set_static,
                          custom_thresholds=custom_thresholds, weights=weights))
            loaded_count += 1

        except ValueError as e:
            if raise_errors:
                raise ValueError(f"Item {idx}: Failed to parse rule - {e}") from e
            error_count += 1
            warnings.warn(f"Item {idx}: Failed to parse rule - {e}")
        except Exception as e:
            if raise_errors:
                raise Exception(f"Item {idx}: Unexpected error - {e}") from e
            error_count += 1
            warnings.warn(f"Item {idx}: Unexpected error - {e}")

    if settings.verbose:
        print(f"Loaded {loaded_count} rules from {json_path}")
        if error_count > 0:
            print(f"Failed to load {error_count} rules due to errors")
def _parse_and_validate_fact_params(idx, name_raw, start_time_raw, end_time_raw, static_raw, raise_errors, item_label="Item"):
    """Normalize the optional parameters of a fact read from a file.

    Shared by the JSON and CSV fact loaders so both apply the same defaults and
    emit the same error messages.

    :param idx: Index of the item being parsed (only used in messages)
    :param name_raw: Raw name value (None, str, or anything convertible with ``str``)
    :param start_time_raw: Raw start_time value
    :param end_time_raw: Raw end_time value
    :param static_raw: Raw static value, handed to ``_parse_bool_param``
    :param raise_errors: If True an invalid time raises, otherwise a warning is
        issued and a fallback value is used
    :param item_label: Label used as message prefix (e.g., "Item", "Row")
    :return: Tuple of (name, start_time, end_time, static)
    :raises ValueError: If a time value is invalid and raise_errors is True
    """
    # Name: blank / whitespace-only values collapse to None
    name = None
    if name_raw is not None:
        stripped_name = str(name_raw).strip()
        name = stripped_name or None

    # start_time: missing or blank means 0
    try:
        if start_time_raw is None or not str(start_time_raw).strip():
            start_time = 0
        else:
            start_time = int(start_time_raw)
    except (ValueError, TypeError, AttributeError):
        if raise_errors:
            raise ValueError(f"{item_label} {idx}: Invalid start_time '{start_time_raw}'")
        warnings.warn(f"{item_label} {idx}: Invalid start_time '{start_time_raw}', using default value")
        start_time = 0

    # end_time: missing or blank means 0, but an unparsable value falls back to start_time
    try:
        if end_time_raw is None or not str(end_time_raw).strip():
            end_time = 0
        else:
            end_time = int(end_time_raw)
    except (ValueError, TypeError, AttributeError):
        if raise_errors:
            raise ValueError(f"{item_label} {idx}: Invalid end_time '{end_time_raw}'")
        warnings.warn(f"{item_label} {idx}: Invalid end_time '{end_time_raw}', using default value")
        end_time = start_time

    # static: boolean parsing is delegated to the shared helper
    static = _parse_bool_param(static_raw, 'static', idx, raise_errors, item_label, default=False)

    return name, start_time, end_time, static
def add_closed_world_predicate(predicate_name: str) -> None:
    """Mark a predicate as closed_world (circumscription).

    Wherever a closed_world predicate has the unknown bounds [0,1] on a node or
    edge, rule satisfaction checks treat it as [0,0] (false).

    :param predicate_name: The name of the predicate to minimize
    :return: None
    """
    closed_world_registry = __closed_world_predicates
    closed_world_registry.add(predicate_name)
def add_fact(pyreason_fact: Fact) -> None:
    """Register a PyReason fact with the program.

    Unnamed facts are given the name ``fact_<n>`` where ``n`` is the number of
    facts currently pending. Re-using a name only triggers a warning; the fact
    is still added.

    :param pyreason_fact: PyReason fact created using pr.Fact(...)
    :return: None
    """
    global __node_facts, __edge_facts

    # Lazily create both pending-fact containers
    if __node_facts is None:
        __node_facts = numba.typed.List.empty_list(fact_node.fact_type)
    if __edge_facts is None:
        __edge_facts = numba.typed.List.empty_list(fact_edge.fact_type)

    # Anything that is not a node fact is handled as an edge fact
    if pyreason_fact.type == 'node':
        fact_module, pending_facts = fact_node, __node_facts
    else:
        fact_module, pending_facts = fact_edge, __edge_facts

    if pyreason_fact.name is None:
        pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}'
    if pyreason_fact.name in __facts_name_set:
        warnings.warn(f"Fact {pyreason_fact.name} has already been added. Duplicate fact names will lead to an ambiguous node and atom traces.", stacklevel=2)

    numba_fact = fact_module.Fact(
        pyreason_fact.name,
        pyreason_fact.component,
        pyreason_fact.pred,
        pyreason_fact.bound,
        pyreason_fact.start_time,
        pyreason_fact.end_time,
        pyreason_fact.static,
    )
    __facts_name_set.add(pyreason_fact.name)
    pending_facts.append(numba_fact)
def add_fact_from_json(json_path: str, raise_errors = True) -> None:
    """Load multiple facts from a JSON file.

    The JSON should be an array of objects, where each object represents a Fact with these fields:

    - fact_text (required): The fact in text format, e.g., 'pred(x,y) : [0.2, 1]' or 'pred(x) : True'
    - name (optional): The name of the fact
    - start_time (optional): The timestep at which this fact becomes active (default: 0)
    - end_time (optional): The last timestep this fact is active (default: 0)
    - static (optional): Whether the fact is static for the entire program (default: false)

    Example JSON format:

    ```json
    [
        {
            "fact_text": "Viewed(Zach)",
            "name": "seen-fact-zach",
            "start_time": 0,
            "end_time": 3,
            "static": false
        },
        {
            "fact_text": "Viewed(Justin)",
            "name": "seen-fact-justin",
            "start_time": 0,
            "end_time": 3,
            "static": false
        },
        {
            "fact_text": "Viewed(Michelle)",
            "start_time": 1,
            "end_time": 3
        }
    ]
    ```

    A summary of how many facts were loaded is printed when ``settings.verbose`` is on.

    :param json_path: Path to the JSON file containing facts
    :type json_path: str
    :param raise_errors: If True, raise on the first invalid item. If False, warn and skip it.
    :type raise_errors: bool
    :return: None
    :raises FileNotFoundError: If the JSON file doesn't exist
    :raises ValueError: If fact parsing fails or JSON format is invalid
    """
    try:
        # JSON text is UTF-8 by specification, so do not depend on the platform's locale encoding
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"JSON file not found: {json_path}") from e
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format in file {json_path}: {e}") from e
    except Exception as e:
        raise ValueError(f"Error reading JSON file {json_path}: {e}") from e

    if not isinstance(data, list):
        raise ValueError(f"JSON file must contain an array of fact objects, got {type(data).__name__}")

    if len(data) == 0:
        warnings.warn(f"JSON file {json_path} contains an empty array, no facts loaded")
        return

    # Track loaded facts for reporting
    loaded_count = 0
    error_count = 0
    loaded_name_set = set()

    # Process each fact object
    for idx, fact_obj in enumerate(data):
        try:
            if not isinstance(fact_obj, dict):
                if raise_errors:
                    raise ValueError(f"Item {idx}: Expected object, got {type(fact_obj).__name__}")
                warnings.warn(f"Item {idx}: Expected object, got {type(fact_obj).__name__}, skipping item")
                error_count += 1
                continue

            # Extract fact_text (required)
            fact_text = fact_obj.get('fact_text')
            if not fact_text or not str(fact_text).strip():
                if raise_errors:
                    raise ValueError(f"Item {idx}: Missing required 'fact_text'")
                warnings.warn(f"Item {idx}: Missing required 'fact_text', skipping item")
                error_count += 1
                continue
            fact_text = str(fact_text).strip()

            # Parse and validate parameters using shared helper
            name, start_time, end_time, static = _parse_and_validate_fact_params(
                idx,
                fact_obj.get('name'),
                fact_obj.get('start_time', 0),
                fact_obj.get('end_time', 0),
                fact_obj.get('static', False),
                raise_errors,
                "Item"
            )

            # Check for duplicate names (only among the names loaded from this file)
            if name and name in loaded_name_set:
                if raise_errors:
                    raise ValueError(f"Item {idx}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
                warnings.warn(f"Item {idx}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
                error_count += 1
                continue
            if name:
                loaded_name_set.add(name)

            # Create and add the fact
            fact = Fact(fact_text=fact_text, name=name, start_time=start_time, end_time=end_time, static=static)
            add_fact(fact)
            loaded_count += 1

        except ValueError as e:
            # Also receives the ValueErrors raised above when raise_errors is True
            if raise_errors:
                raise ValueError(f"Item {idx}: Failed to parse fact - {e}") from e
            error_count += 1
            warnings.warn(f"Item {idx}: Failed to parse fact - {e}")
        except Exception as e:
            if raise_errors:
                raise Exception(f"Item {idx}: Unexpected error - {e}") from e
            error_count += 1
            warnings.warn(f"Item {idx}: Unexpected error - {e}")

    # Report results, honouring the verbose setting like the CSV fact loader does
    if settings.verbose:
        print(f"Loaded {loaded_count} facts from {json_path}")
        if error_count > 0:
            print(f"Failed to load {error_count} facts due to errors")
def add_fact_from_csv(csv_path: str, raise_errors = True) -> None:
    """Load multiple facts from a CSV file.

    Each row should have up to 5 comma-separated values in this order:

    ``fact_text, name, start_time, end_time, static``

    - **fact_text** (required): The fact in text format, e.g., ``Viewed(Zach)`` or
      ``"HaveAccess(Zach,TextMessage)"`` or ``"Processed(Node1):[0.5,0.8]"`` for interval bounds.
    - **name** (optional): A unique name for the fact (can be empty).
    - **start_time** (optional): The timestep at which this fact becomes active (default: 0).
    - **end_time** (optional): The last timestep this fact is active (default: 0).
    - **static** (optional): Whether the fact is static for the entire program (default: False).
      Accepts: True/False, 1/0, yes/no (case-insensitive).

    A header row is optional. If included, it must be exactly::

        fact_text,name,start_time,end_time,static

    Any other header format will be treated as a data row and will likely raise a parsing error.

    Example CSV::

        fact_text,name,start_time,end_time,static
        Viewed(Zach),seen-fact-zach,0,3,False
        Viewed(Justin),seen-fact-justin,0,3,true
        "HaveAccess(Zach,TextMessage)",access-zach,0,5,True
        "Processed(Node1):[0.5,0.8]",interval-node,0,10,False
        Viewed(Eve),,,,

    :param csv_path: Path to the CSV file containing facts
    :type csv_path: str
    :param raise_errors: If True, raise on invalid rows. If False, warn and skip them.
    :type raise_errors: bool
    :return: None
    :raises FileNotFoundError: If the CSV file doesn't exist
    :raises ValueError: If fact parsing fails or CSV format is invalid
    """
    try:
        # No header is assumed and every cell is kept as a raw string
        frame = pd.read_csv(csv_path, header=None, dtype=str, keep_default_na=False)
    except FileNotFoundError:
        raise FileNotFoundError(f"CSV file not found: {csv_path}")
    except pd.errors.EmptyDataError:
        # A file with no content at all
        warnings.warn(f"CSV file {csv_path} is empty, no facts loaded")
        return
    except Exception as e:
        raise ValueError(f"Error reading CSV file {csv_path}: {e}")

    if frame.empty:
        warnings.warn(f"CSV file {csv_path} is empty, no facts loaded")
        return

    # The first row is skipped only when it is exactly the expected header
    expected_header = ['fact_text', 'name', 'start_time', 'end_time', 'static']
    leading_cells = [str(cell).strip() for cell in frame.iloc[0]]
    first_data_pos = int(leading_cells == expected_header)

    # Counters for the final report
    loaded_count = 0
    error_count = 0
    loaded_name_set = set()

    for idx, row in frame.iloc[first_data_pos:].iterrows():
        row_no = idx + 1  # 1-based row number used in every message
        try:
            width = len(row)

            # Column 0 holds the required fact_text
            if width < 1 or not str(row[0]).strip():
                if raise_errors:
                    raise ValueError(f"Row {row_no}: Missing required 'fact_text'")
                warnings.warn(f"Row {row_no}: Missing required 'fact_text', skipping row")
                error_count += 1
                continue
            fact_text = str(row[0]).strip()

            # Columns 1-4 are optional; absent columns are passed on as None
            optional_cells = [row[pos] if width > pos else None for pos in range(1, 5)]
            name, start_time, end_time, static = _parse_and_validate_fact_params(
                row_no, *optional_cells, raise_errors, "Row"
            )

            # Names must be unique within the file
            if name and name in loaded_name_set:
                if raise_errors:
                    raise ValueError(f"Row {row_no}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
                warnings.warn(f"Row {row_no}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
                error_count += 1
                continue
            if name:
                loaded_name_set.add(name)

            # Build the fact and register it
            add_fact(Fact(fact_text=fact_text, name=name, start_time=start_time, end_time=end_time, static=static))
            loaded_count += 1

        except ValueError as e:
            if raise_errors:
                raise ValueError(f"Row {row_no}: Failed to parse fact - {e}") from e
            error_count += 1
            warnings.warn(f"Row {row_no}: Failed to parse fact - {e}")
        except Exception as e:
            if raise_errors:
                raise Exception(f"Row {row_no}: Unexpected error - {e}") from e
            error_count += 1
            warnings.warn(f"Row {row_no}: Unexpected error - {e}")

    # Report results
    if settings.verbose:
        print(f"Loaded {loaded_count} facts from {csv_path}")
        if error_count > 0:
            print(f"Failed to load {error_count} facts due to errors")
def add_annotation_function(function: Callable) -> None:
    """Function to add annotation functions to PyReason. The added functions can be used in rules.

    The function must be ``@numba.njit``-decorated and must accept exactly one of
    the two supported signatures:

    - 2 args (legacy)::

        def fn(annotations, weights) -> Tuple[float, float]

      ``annotations`` is a per-clause list of bounds for the atoms that satisfied
      that clause. The relational join across body variables is performed inside
      the engine and projected away before this hook fires, so 2-arg functions
      cannot recover which grounding produced which bound.

    - 6 args (extended)::

        def fn(annotations, weights, qualified_nodes, qualified_edges,
               clause_labels, clause_variables) -> Tuple[float, float]

      Same as the legacy signature, plus four extra args that expose the
      per-clause structure the engine already builds:

      ========================  =====================================================
      ``annotations[i]``        bounds of atoms that satisfied clause ``i``
      ``qualified_nodes[i]``    nodes that satisfied clause ``i`` (empty for edge clauses)
      ``qualified_edges[i]``    edges that satisfied clause ``i`` (empty for node clauses)
      ``clause_labels[i]``      predicate label of clause ``i``
      ``clause_variables[i]``   variable names of clause ``i`` (length 1 for node
                                clauses, length 2 for edge clauses)
      ========================  =====================================================

      Comparison clauses are not surfaced, so the list lengths may be less than
      ``len(rule.get_clauses())``. Match clauses by predicate name and variable
      role, not by position in the rule body — the parser's ``reorder_clauses``
      optimization may rewrite the order.

    Arity validation runs at registration time, so anything other than 2 or 6
    raises ``TypeError`` here rather than producing a confusing failure inside
    the reasoning loop.

    :param function: Function to be added. Must be ``@numba.njit``-decorated and
        must match one of the two supported signatures above.
    :type function: Callable
    :return: None
    :raises TypeError: if ``function`` does not have exactly 2 or 6 positional
        arguments, or if its arity cannot be inspected at all.
    """
    # Make sure that the functions are jitted so that they can be passed around in other jitted functions
    # TODO: Remove if necessary
    # assert hasattr(function, 'nopython_signatures'), 'The function to be added has to be under a `numba.njit` decorator'

    # Arity gate: only 2-arg and 6-arg signatures are supported by `annotate`.
    # Validating here keeps the error close to the user's call site and avoids
    # `raise` inside numba.objmode (which would fail with_lifting).
    py_func = getattr(function, 'py_func', function)
    func_name = getattr(py_func, '__name__', repr(py_func))

    # Callables without a code object (callable instances, builtins, partials)
    # cannot be arity-checked; report them with the documented TypeError instead
    # of leaking an AttributeError.
    code = getattr(py_func, '__code__', None)
    if code is None:
        raise TypeError(
            f"Annotation function {func_name!r} must be a plain Python function "
            f"(optionally wrapped by `numba.njit`) so its arity can be checked; "
            f"got an object of type {type(py_func).__name__!r}."
        )

    nargs = code.co_argcount
    if nargs not in (2, 6):
        raise TypeError(
            f"Annotation function {func_name!r} must accept exactly 2 positional "
            f"args (annotations, weights) or exactly 6 positional args (annotations, "
            f"weights, qualified_nodes, qualified_edges, clause_labels, clause_variables); "
            f"got {nargs}."
        )

    __annotation_functions.append(function)
def add_head_function(function: Callable) -> None:
    """Register a head function with PyReason so that it can be used in rules.

    :param function: Function to be added. This has to be under a numba `njit` decorator.
        function has signature: one parameter as input -- annotations
    :type function: Callable
    :return: None
    """
    # The function is expected to be jitted so that it can be passed around in other
    # jitted functions; that requirement is not enforced here.
    # TODO: Remove if necessary
    # assert hasattr(function, 'nopython_signatures'), 'The function to be added has to be under a `numba.njit` decorator'
    registered_head_functions = __head_functions
    registered_head_functions.append(function)
def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False, restart: bool = True):
    """Entry point of the reasoning process. Graph and rules must already be loaded.

    :param timesteps: Max number of timesteps to run. -1 specifies run till convergence. If reasoning again, this is the number of timesteps to reason for extra (no zero timestep), defaults to -1
    :param convergence_threshold: Maximum number of interpretations that have changed between timesteps or fixed point operations until considered convergent. Program will end at convergence. -1 => no changes, perfect convergence, defaults to -1
    :param convergence_bound_threshold: Maximum change in any interpretation (bounds) between timesteps or fixed point operations until considered convergent, defaults to -1
    :param queries: A list of PyReason query objects that can be used to filter the ruleset based on the query. Only forwarded on a fresh run. Default is None
    :param again: Whether to reason again on an existing interpretation, defaults to False
    :param restart: Whether to restart the program time from 0 when reasoning again, defaults to True
    :return: The final interpretation after reasoning.
    """
    global __timestamp

    # Timestamp for saving files
    __timestamp = time.strftime('%Y%m%d-%H%M%S')

    if settings.output_to_file:
        sys.stdout = open(f"./{settings.output_file_name}_{__timestamp}.txt", "a")

    # A fresh run is needed unless `again` was requested and a program already exists
    if not again or __program is None:
        runner = _reason
        runner_args = [timesteps, convergence_threshold, convergence_bound_threshold, queries]
    else:
        runner = _reason_again
        runner_args = [timesteps, restart, convergence_threshold, convergence_bound_threshold]

    if not settings.memory_profile:
        return runner(*runner_args)

    # Memory profiling: measure a baseline first, then the peak while reasoning
    start_mem = mp.memory_usage(max_usage=True)
    mem_usage, interp = mp.memory_usage((runner, runner_args), max_usage=True, retval=True)
    print(f"\nProgram used {mem_usage-start_mem} MB of memory")
    return interp
def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queries):
    """Build a fresh ``Program`` from the loaded graph, rules and facts and run it.

    :param timesteps: Forwarded to ``Program.reason``
    :param convergence_threshold: Forwarded to ``Program.reason``
    :param convergence_bound_threshold: Forwarded to ``Program.reason``
    :param queries: Optional list of queries used to filter the ruleset, or None
    :return: The interpretation returned by ``Program.reason``
    :raises Exception: If no rules have been added
    """
    # Globals
    global __rules, __clause_maps, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels
    global __program

    # Redirect stdout to the output file if requested
    # NOTE(review): `reason` performs the same redirect before calling this function - confirm whether this second open is still needed
    if settings.output_to_file:
        sys.stdout = open(f"./{settings.output_file_name}_{__timestamp}.txt", "a")

    # Check variables that HAVE to be set. Exceptions
    if __graph is None:
        load_graph(nx.DiGraph())
        if settings.verbose:
            warnings.warn('Graph not loaded. Use `load_graph` to load the graphml file. Using empty graph')
    if __rules is None:
        raise Exception('There are no rules, use `add_rule` or `add_rules_from_file`')

    # Optional inputs default to empty typed containers
    if __node_facts is None:
        __node_facts = numba.typed.List.empty_list(fact_node.fact_type)
    if __edge_facts is None:
        __edge_facts = numba.typed.List.empty_list(fact_edge.fact_type)
    if __ipl is None:
        __ipl = numba.typed.List.empty_list(numba.types.Tuple((label.label_type, label.label_type)))

    # Add results of graph parse to existing specific labels and facts
    __specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(numba.types.string))
    __specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(numba.types.Tuple((numba.types.string, numba.types.string))))
    for label_name, nodes in __specific_graph_node_labels.items():
        if label_name in __specific_node_labels:
            __specific_node_labels[label_name].extend(nodes)
        else:
            __specific_node_labels[label_name] = nodes

    for label_name, edges in __specific_graph_edge_labels.items():
        if label_name in __specific_edge_labels:
            __specific_edge_labels[label_name].extend(edges)
        else:
            __specific_edge_labels[label_name] = edges

    # User facts first, then the non-fluent facts that came from the graph
    all_node_facts = numba.typed.List.empty_list(fact_node.fact_type)
    all_edge_facts = numba.typed.List.empty_list(fact_edge.fact_type)
    all_node_facts.extend(numba.typed.List(__node_facts))
    all_edge_facts.extend(numba.typed.List(__edge_facts))
    all_node_facts.extend(__non_fluent_graph_facts_node)
    all_edge_facts.extend(__non_fluent_graph_facts_edge)

    # Atom trace cannot be true when store interpretations is false
    if not settings.store_interpretation_changes:
        settings.atom_trace = False

    # Convert list of annotation functions into tuple to be numba compatible
    annotation_functions = tuple(__annotation_functions)
    head_functions = tuple(__head_functions)

    # Filter rules based on queries (the message is only printed when filtering actually happens)
    if queries is not None:
        if settings.verbose:
            print('Filtering rules based on queries')
        __rules = ruleset_filter.filter_ruleset(queries, __rules)

    # Optimize rules by moving clauses around, only if there are more edges than nodes in the graph
    # The clause map starts out as the identity mapping for every rule
    __clause_maps = {r.get_rule_name(): {i: i for i in range(len(r.get_clauses()))} for r in __rules}
    if len(__graph.edges) > len(__graph.nodes):
        if settings.verbose:
            print('Optimizing rules by moving node clauses ahead of edge clauses')
        __rules_copy = __rules.copy()
        __rules = numba.typed.List.empty_list(rule.rule_type)
        for r in __rules_copy:
            # `r` is rebound first, so the clause map is keyed by the reordered rule's name
            r, __clause_maps[r.get_rule_name()] = reorder_clauses(r)
            __rules.append(r)

    # Setup logical program
    __program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, head_functions, settings.reverse_digraph, settings.atom_trace, settings.save_graph_attributes_to_trace, settings.persistent, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode, settings.allow_ground_rules, settings.fp_version)
    __program.specific_node_labels = __specific_node_labels
    __program.specific_edge_labels = __specific_edge_labels

    # Convert closed_world predicates to numba-compatible list of label types
    closed_world_preds_numba = numba.typed.List.empty_list(label.label_type)
    for pred_name in __closed_world_predicates:
        closed_world_preds_numba.append(label.Label(pred_name))
    __program.closed_world_predicates = closed_world_preds_numba

    # Run Program and get final interpretation
    interpretation = __program.reason(timesteps, convergence_threshold, convergence_bound_threshold, settings.verbose)

    # Clear facts after reasoning, so that reasoning again is possible with any added facts
    __node_facts = None
    __edge_facts = None

    return interpretation


def _reason_again(timesteps, restart, convergence_threshold, convergence_bound_threshold):
    """Continue reasoning on the existing program, passing along any facts added since.

    :param timesteps: Forwarded to ``Program.reason_again``
    :param restart: Forwarded to ``Program.reason_again``
    :param convergence_threshold: Forwarded to ``Program.reason_again``
    :param convergence_bound_threshold: Forwarded to ``Program.reason_again``
    :return: The interpretation returned by ``Program.reason_again``
    """
    # Globals
    assert __program is not None, 'To run `reason_again` you need to have reasoned once before'

    # Extend facts
    all_node_facts = numba.typed.List.empty_list(fact_node.fact_type)
    all_edge_facts = numba.typed.List.empty_list(fact_edge.fact_type)
    # `_reason` resets the pending fact lists to None and `add_fact` only recreates them
    # when a fact is added, so they are still None if nothing was added in between.
    if __node_facts is not None:
        all_node_facts.extend(numba.typed.List(__node_facts))
    if __edge_facts is not None:
        all_edge_facts.extend(numba.typed.List(__edge_facts))
    # NOTE(review): unlike `_reason`, the pending facts are not cleared here, so they are
    # passed again on the next `reason(again=True)` call - confirm whether that is intended.

    # Run Program and get final interpretation
    interpretation = __program.reason_again(timesteps, restart, convergence_threshold, convergence_bound_threshold, all_node_facts, all_edge_facts, settings.verbose)

    return interpretation
def save_rule_trace(interpretation, folder: str='./'):
    """Saves the trace of the program. This includes every change that has occurred to the interpretation.
    If `atom_trace` was set to true this gives us full explainability of why interpretations changed

    :param interpretation: the output of `pyreason.reason()`, the final interpretation
    :param folder: the folder in which to save the result, defaults to './'
    :raises AssertionError: if the `store_interpretation_changes` setting is off
    """
    # Raised explicitly (instead of `assert`) so the check is not stripped under `python -O`
    if not settings.store_interpretation_changes:
        raise AssertionError('store interpretation changes setting is off, turn on to save rule trace')

    output = Output(__timestamp, __clause_maps)
    output.save_rule_trace(interpretation, folder)
def get_rule_trace(interpretation) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Returns the trace of the program as 2 pandas dataframes (one for nodes, one for edges).
    This includes every change that has occurred to the interpretation.
    If `atom_trace` was set to true this gives us full explainability of why interpretations changed

    :param interpretation: the output of `pyreason.reason()`, the final interpretation
    :returns: two pandas dataframes (nodes, edges) representing the changes that occurred during reasoning
    :raises AssertionError: if the `store_interpretation_changes` setting is off
    """
    # Raised explicitly (instead of `assert`) so the check is not stripped under `python -O`
    if not settings.store_interpretation_changes:
        raise AssertionError('store interpretation changes setting is off, turn on to save rule trace')

    output = Output(__timestamp, __clause_maps)
    return output.get_rule_trace(interpretation)
def filter_and_sort_nodes(interpretation, labels: List[str], bound: interval.Interval=interval.closed(0,1), sort_by: str='lower', descending: bool=True):
    """Filters and sorts the node changes in the interpretation and returns as a list of Pandas dataframes that are easy to access

    :param interpretation: the output of `pyreason.reason()`, the final interpretation
    :param labels: A list of strings, labels that are in the interpretation that are to be filtered
    :param bound: The bound that will filter any interpretation that is not in it. the default does not filter anything, defaults to interval.closed(0,1)
    :param sort_by: String that is either 'lower' or 'upper', sorts by the lower/upper bound, defaults to 'lower'
    :param descending: A bool that sorts by descending/ascending order, defaults to True
    :return: A list of Pandas dataframes that contain the filtered and sorted interpretations that are easy to access
    :raises AssertionError: if the `store_interpretation_changes` setting is off
    """
    # Raised explicitly (instead of `assert`) so the check is not stripped under `python -O`
    if not settings.store_interpretation_changes:
        raise AssertionError('store interpretation changes setting is off, turn on to filter and sort nodes')

    filterer = Filter(interpretation.time)
    filtered_df = filterer.filter_and_sort_nodes(interpretation, labels, bound, sort_by, descending)
    return filtered_df
def filter_and_sort_edges(interpretation, labels: List[str], bound: interval.Interval=interval.closed(0,1), sort_by: str='lower', descending: bool=True):
    """Filters and sorts the edge changes in the interpretation and returns as a list of Pandas dataframes that are easy to access

    :param interpretation: the output of `pyreason.reason()`, the final interpretation
    :param labels: A list of strings, labels that are in the interpretation that are to be filtered
    :param bound: The bound that will filter any interpretation that is not in it. the default does not filter anything, defaults to interval.closed(0,1)
    :param sort_by: String that is either 'lower' or 'upper', sorts by the lower/upper bound, defaults to 'lower'
    :param descending: A bool that sorts by descending/ascending order, defaults to True
    :return: A list of Pandas dataframes that contain the filtered and sorted interpretations that are easy to access
    :raises AssertionError: if the `store_interpretation_changes` setting is off
    """
    # Raised explicitly (instead of `assert`) so the check is not stripped under `python -O`
    if not settings.store_interpretation_changes:
        raise AssertionError('store interpretation changes setting is off, turn on to filter and sort edges')

    filterer = Filter(interpretation.time)
    filtered_df = filterer.filter_and_sort_edges(interpretation, labels, bound, sort_by, descending)
    return filtered_df