from typing import Union, Tuple
import pyreason.scripts.numba_wrapper.numba_types.world_type as world
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
from pyreason.scripts.interpretation.interpretation_dict import InterpretationDict
import numba
from numba import objmode, prange
# Types for the dictionaries: nodes are keyed by name, edges by (source, target)
node_type = numba.types.string
edge_type = numba.types.UniTuple(numba.types.string, 2)

# Type for storing list of qualified nodes/edges
list_of_nodes = numba.types.ListType(node_type)
list_of_edges = numba.types.ListType(edge_type)

# Type for storing clause data
clause_data = numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string)))

# Type for storing refine clause data
refine_data = numba.types.Tuple((numba.types.string, numba.types.string, numba.types.int8))

# Type for facts to be applied:
# (timestep, component, label, bound, static, graph_attribute)
facts_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))
facts_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))

# Edges a rule wants to add: (source nodes, target nodes, edge label).
# Defined before the applicable-rule types so they can reuse it.
edges_to_be_added_type = numba.types.Tuple((list_of_nodes, list_of_nodes, label.label_type))

# Type for returning list of applicable rules for a certain rule:
# node/edge, annotations, qualified nodes, qualified edges, edges to be added,
# clause labels (per-clause predicate label), clause variables (per-clause variable names)
node_applicable_rule_type = numba.types.Tuple((
    node_type,
    numba.types.ListType(numba.types.ListType(interval.interval_type)),
    numba.types.ListType(list_of_nodes),
    numba.types.ListType(list_of_edges),
    edges_to_be_added_type,
    numba.types.ListType(label.label_type),
    numba.types.ListType(numba.types.ListType(numba.types.string)),
))
edge_applicable_rule_type = numba.types.Tuple((
    edge_type,
    numba.types.ListType(numba.types.ListType(interval.interval_type)),
    numba.types.ListType(list_of_nodes),
    numba.types.ListType(list_of_edges),
    edges_to_be_added_type,
    numba.types.ListType(label.label_type),
    numba.types.ListType(numba.types.ListType(numba.types.string)),
))

# Rule results scheduled for application:
# (timestep, component, label, bound, set_static)
rules_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean))
rules_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean))

# Trace entry kept alongside each scheduled rule when atom tracing is on:
# (qualified nodes, qualified edges, string)
rules_to_be_applied_trace_type = numba.types.Tuple((numba.types.ListType(list_of_nodes), numba.types.ListType(list_of_edges), numba.types.string))
class Interpretation:
    """Graph interpretation state plus the fixed-point reasoning driver.

    An instance holds, for every node and edge of the input graph, a world
    mapping labels to interval bounds, together with the queues of facts and
    rules waiting to be applied and the optional rule trace. ``start_fp``
    runs reasoning over that state.
    """

    # NOTE(review): these containers are *class* attributes, so every
    # Interpretation instance shares (and may mutate) the same objects.
    # __init__ reads the two label maps to pre-seed predicates; presumably the
    # caller fills them in before constructing an instance -- confirm.

    # Label -> nodes that start with that label at bound [0, 1].
    specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(node_type))
    # Label -> edges that start with that label at bound [0, 1].
    specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type))
    # Labels forwarded to reason() as ``closed_world_predicates``.
    closed_world_predicates = numba.typed.List.empty_list(label.label_type)
def __init__(self, graph, ipl, annotation_functions, head_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules):
    """Create the interpretation for ``graph`` and seed all reasoning state.

    Args:
        graph: Input graph. Only ``nodes()``, ``edges()`` and
            ``neighbors(n)`` are used here (looks networkx-like -- confirm).
        ipl: Iterable of ``(label, label)`` pairs; stored as ``self.ipl``
            and forwarded to ``reason``.
        annotation_functions: Annotation callables; ``_start_fp`` inspects
            their ``__name__`` and argument count.
        head_functions, reverse_graph, atom_trace,
        save_graph_attributes_to_rule_trace, persistent,
        inconsistency_check, store_interpretation_changes, update_mode,
        allow_ground_rules: Reasoning options, stored verbatim and forwarded
            to ``reason``.
    """
    # Both are read later (graph below, ipl in _start_fp), so they must be
    # stored before anything else touches them.
    self.graph = graph
    self.ipl = ipl

    self.annotation_functions = annotation_functions
    self.head_functions = head_functions
    self.reverse_graph = reverse_graph
    self.atom_trace = atom_trace
    self.save_graph_attributes_to_rule_trace = save_graph_attributes_to_rule_trace
    self.persistent = persistent
    self.inconsistency_check = inconsistency_check
    self.store_interpretation_changes = store_interpretation_changes
    self.update_mode = update_mode
    self.allow_ground_rules = allow_ground_rules

    # Counter for number of ground atoms for each timestep, start with zero for the zeroth timestep
    self.num_ga = numba.typed.List.empty_list(numba.types.int64)
    self.num_ga.append(0)

    # For reasoning and reasoning again: [next timestep to start from, fp operation count]
    self.prev_reasoning_data = numba.typed.List([0, 0])

    # Rules/facts waiting to be applied, plus the ground atoms that fired each rule.
    # When atom_trace is true, each *_trace list is index-aligned with its queue.
    self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type)
    self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type)
    self.facts_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.string)
    self.facts_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.string)
    self.rules_to_be_applied_node = numba.typed.List.empty_list(rules_to_be_applied_node_type)
    self.rules_to_be_applied_edge = numba.typed.List.empty_list(rules_to_be_applied_edge_type)
    self.facts_to_be_applied_node = numba.typed.List.empty_list(facts_to_be_applied_node_type)
    self.facts_to_be_applied_edge = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
    self.edges_to_be_added_node_rule = numba.typed.List.empty_list(edges_to_be_added_type)
    self.edges_to_be_added_edge_rule = numba.typed.List.empty_list(edges_to_be_added_type)

    # Keep track of all the rules that have affected each node/edge at each
    # timestep/fp operation, and all ground atoms that have affected the rules
    # as well. Keep track of previous bounds and name of the rule/fact here.
    trace_atoms_type = numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), interval.interval_type, numba.types.string))
    self.rule_trace_node_atoms = numba.typed.List.empty_list(trace_atoms_type)
    self.rule_trace_edge_atoms = numba.typed.List.empty_list(trace_atoms_type)
    # Entries start with (timestep, fp count, component, label, bound, ...)
    self.rule_trace_node = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.string, numba.types.string, numba.types.string)))
    self.rule_trace_edge = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.string, numba.types.string, numba.types.string)))

    # Nodes and edges of the graph
    self.nodes = numba.typed.List.empty_list(node_type)
    self.edges = numba.typed.List.empty_list(edge_type)
    self.nodes.extend(numba.typed.List(self.graph.nodes()))
    self.edges.extend(numba.typed.List(self.graph.edges()))

    self.interpretations_node, self.predicate_map_node = self._init_interpretations_node(self.nodes, self.specific_node_labels, self.num_ga)
    self.interpretations_edge, self.predicate_map_edge = self._init_interpretations_edge(self.edges, self.specific_edge_labels, self.num_ga)

    # Setup graph neighbors and reverse neighbors
    self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type))
    for n in self.graph.nodes():
        neighbor_list = numba.typed.List.empty_list(node_type)
        for neigh in self.graph.neighbors(n):
            neighbor_list.append(neigh)
        self.neighbors[n] = neighbor_list

    self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors)
@staticmethod
@numba.njit(cache=True)
def _init_reverse_neighbors(neighbors):
    """Invert an adjacency map.

    For every node ``n`` and every ``m`` in ``neighbors[n]``, ``n`` is
    recorded (at most once) in ``reverse_neighbors[m]``. Every key of
    ``neighbors`` also receives an entry, which is empty when nothing points
    to it.

    Args:
        neighbors: Typed dict mapping node -> typed list of neighbor nodes.

    Returns:
        Typed dict mapping node -> typed list of nodes that list it as a
        neighbor.
    """
    reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
    for n, neighbor_nodes in neighbors.items():
        for neighbor_node in neighbor_nodes:
            if neighbor_node in reverse_neighbors:
                # Only skip duplicates; never replace the list, or the
                # predecessors already collected for neighbor_node are lost.
                if n not in reverse_neighbors[neighbor_node]:
                    reverse_neighbors[neighbor_node].append(n)
            else:
                reverse_neighbors[neighbor_node] = numba.typed.List([n])
        # This makes sure each node has a value
        if n not in reverse_neighbors:
            reverse_neighbors[n] = numba.typed.List.empty_list(node_type)
    return reverse_neighbors
@staticmethod
@numba.njit(cache=True)
def _init_interpretations_node(nodes, specific_labels, num_ga):
    """Build the per-node worlds and the label -> nodes predicate map.

    Every node starts with an empty world. Then, for each ``(label, nodes)``
    entry of ``specific_labels``, the label is set to ``[0, 1]`` on each
    listed node, the nodes are recorded in the predicate map, and
    ``num_ga[0]`` is incremented once per seeded node.

    Returns:
        ``(interpretations, predicate_map)``.
    """
    predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_nodes)
    interpretations = numba.typed.Dict.empty(key_type=node_type, value_type=world.world_type)

    for node in nodes:
        interpretations[node] = world.World(numba.typed.List.empty_list(label.label_type))

    for lbl, labelled_nodes in specific_labels.items():
        predicate_map[lbl] = numba.typed.List(labelled_nodes)
        for node in labelled_nodes:
            interpretations[node].world[lbl] = interval.closed(0.0, 1.0)
            num_ga[0] += 1

    return interpretations, predicate_map
@staticmethod
@numba.njit(cache=True)
def _init_interpretations_edge(edges, specific_labels, num_ga):
    """Build the per-edge worlds and the label -> edges predicate map.

    Every edge starts with an empty world. Then, for each ``(label, edges)``
    entry of ``specific_labels``, the label is set to ``[0, 1]`` on each
    listed edge, the edges are recorded in the predicate map, and
    ``num_ga[0]`` is incremented once per seeded edge.

    Returns:
        ``(interpretations, predicate_map)``.
    """
    predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_edges)
    interpretations = numba.typed.Dict.empty(key_type=edge_type, value_type=world.world_type)

    for edge in edges:
        interpretations[edge] = world.World(numba.typed.List.empty_list(label.label_type))

    for lbl, labelled_edges in specific_labels.items():
        predicate_map[lbl] = numba.typed.List(labelled_edges)
        for edge in labelled_edges:
            interpretations[edge].world[lbl] = interval.closed(0.0, 1.0)
            num_ga[0] += 1

    return interpretations, predicate_map
@staticmethod
@numba.njit(cache=True)
def _init_convergence(convergence_bound_threshold, convergence_threshold):
    """Pick the convergence criterion from the two thresholds (-1 = unset).

    A set bound threshold wins and selects ``'delta_bound'``. Otherwise a
    set interpretation threshold selects ``'delta_interpretation'``. With
    both unset the mode is ``'perfect_convergence'`` with delta 0.

    Returns:
        ``(convergence_mode, convergence_delta)``.
    """
    mode = 'perfect_convergence'
    delta = 0
    if convergence_bound_threshold != -1:
        mode = 'delta_bound'
        delta = convergence_bound_threshold
    elif convergence_threshold != -1:
        mode = 'delta_interpretation'
        delta = convergence_threshold
    return mode, delta
def start_fp(self, tmax, facts_node, facts_edge, rules, verbose, convergence_threshold, convergence_bound_threshold, again=False, restart=True):
    """Queue the given facts and run fixed-point reasoning up to ``tmax``.

    Args:
        tmax: Last timestep to reason to; stored on ``self.tmax``.
        facts_node: Node facts to expand into the node fact queue.
        facts_edge: Edge facts to expand into the edge fact queue.
        rules: Rules passed through to ``_start_fp``.
        verbose: Whether to print progress.
        convergence_threshold: Interpretation-change threshold (-1 = unset).
        convergence_bound_threshold: Bound-change threshold (-1 = unset);
            takes precedence over ``convergence_threshold`` when set.
        again: Continue a previous reasoning run.
        restart: When continuing, restart time from 0.
    """
    self.tmax = tmax
    self._convergence_mode, self._convergence_delta = self._init_convergence(convergence_bound_threshold, convergence_threshold)
    max_facts_time = self._init_facts(facts_node, facts_edge, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.atom_trace)
    self._start_fp(rules, max_facts_time, verbose, again, restart)
@staticmethod
@numba.njit(cache=True)
def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, atom_trace):
    """Queue every node/edge fact once per timestep of its time interval.

    Each fact active over ``[time_lower, time_upper]`` is appended to the
    matching ``facts_to_be_applied_*`` list as
    ``(t, component, label, bound, static, graph_attribute)`` for every ``t``
    in that range. ``graph_attribute`` is true for facts named
    ``'graph-attribute-fact'``. When ``atom_trace`` is set, the fact's name
    is appended to the parallel ``*_trace`` list so indices stay aligned.

    Returns:
        The largest timestep queued (0 when no fact yields a timestep).
    """
    max_time = 0
    for fact in facts_node:
        # The name and the flag derived from it do not depend on t.
        name = fact.get_name()
        graph_attribute = name == 'graph-attribute-fact'
        for t in range(fact.get_time_lower(), fact.get_time_upper() + 1):
            max_time = max(max_time, t)
            facts_to_be_applied_node.append((numba.types.uint16(t), fact.get_component(), fact.get_label(), fact.get_bound(), fact.static, graph_attribute))
            if atom_trace:
                facts_to_be_applied_node_trace.append(name)

    for fact in facts_edge:
        # The name and the flag derived from it do not depend on t.
        name = fact.get_name()
        graph_attribute = name == 'graph-attribute-fact'
        for t in range(fact.get_time_lower(), fact.get_time_upper() + 1):
            max_time = max(max_time, t)
            facts_to_be_applied_edge.append((numba.types.uint16(t), fact.get_component(), fact.get_label(), fact.get_bound(), fact.static, graph_attribute))
            if atom_trace:
                facts_to_be_applied_edge_trace.append(name)

    return max_time
def _start_fp(self, rules, max_facts_time, verbose, again, restart):
    """Run the jitted ``reason`` loop and record where it stopped.

    Args:
        rules: Rules to ground; one extended-signature flag is built per rule.
        max_facts_time: Largest timestep any queued fact targets.
        verbose: Print the final fixed-point count when true.
        again: Continue a previous run; opens a new ``num_ga`` slot that
            starts from the last count.
        restart: When continuing (``again``), restart time from 0.
    """
    if again:
        self.num_ga.append(self.num_ga[-1])
        if restart:
            self.time = 0
            self.prev_reasoning_data[0] = 0

    # One flag per rule: True iff the rule names an annotation function that
    # was registered with the extended 6-argument signature. It gates the
    # per-grounding metadata build in _ground_rule so legacy 2-argument
    # annotation functions pay no extra cost. Numba dispatchers expose the
    # wrapped Python function as ``py_func``; plain functions are used as-is.
    arity_by_name = {}
    for ann_fn in self.annotation_functions:
        arity_by_name[ann_fn.__name__] = getattr(ann_fn, 'py_func', ann_fn).__code__.co_argcount

    extended_ann_fn_flags = numba.typed.List.empty_list(numba.types.boolean)
    for rule in rules:
        ann_fn_name = rule.get_annotation_function()
        extended_ann_fn_flags.append(ann_fn_name != '' and arity_by_name.get(ann_fn_name, 0) == 6)

    fp_cnt, t = self.reason(
        self.interpretations_node, self.interpretations_edge,
        self.predicate_map_node, self.predicate_map_edge,
        self.tmax, self.prev_reasoning_data, rules,
        self.nodes, self.edges, self.neighbors, self.reverse_neighbors,
        self.rules_to_be_applied_node, self.rules_to_be_applied_edge,
        self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule,
        self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace,
        self.facts_to_be_applied_node, self.facts_to_be_applied_edge,
        self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace,
        self.ipl,
        self.rule_trace_node, self.rule_trace_edge,
        self.rule_trace_node_atoms, self.rule_trace_edge_atoms,
        self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace,
        self.persistent, self.inconsistency_check, self.store_interpretation_changes,
        self.update_mode, self.allow_ground_rules, max_facts_time,
        self.annotation_functions, extended_ann_fn_flags, self.head_functions,
        self._convergence_mode, self._convergence_delta, self.num_ga,
        verbose, again, self.closed_world_predicates,
    )

    # ``t`` is the next timestep to start from if we reason again;
    # ``self.time`` is the last timestep reasoned over.
    self.time = t - 1
    self.prev_reasoning_data[0] = t
    self.prev_reasoning_data[1] = fp_cnt

    if verbose:
        print('Fixed Point iterations:', fp_cnt)
@staticmethod
@numba.njit(cache=True, parallel=False)
[docs]
def reason(interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, extended_ann_fn_flags, head_functions, convergence_mode, convergence_delta, num_ga, verbose, again, closed_world_predicates):
t = prev_reasoning_data[0]
fp_cnt = prev_reasoning_data[1]
max_rules_time = 0
timestep_loop = True
facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type)
facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
facts_to_be_applied_node_trace_new = numba.typed.List.empty_list(numba.types.string)
facts_to_be_applied_edge_trace_new = numba.typed.List.empty_list(numba.types.string)
rules_to_remove_idx = set()
rules_to_remove_idx.add(-1)
while timestep_loop:
if t==tmax:
timestep_loop = False
if verbose:
with objmode():
print('Timestep:', t, flush=True)
# Reset Interpretation at beginning of timestep if non-persistent
if t>0 and not persistent:
# Reset nodes (only if not static)
for n in nodes:
w = interpretations_node[n].world
for l in w:
if not w[l].is_static():
w[l].reset()
# Reset edges (only if not static)
for e in edges:
w = interpretations_edge[e].world
for l in w:
if not w[l].is_static():
w[l].reset()
# Convergence parameters
changes_cnt = 0
bound_delta = 0
update = False
# Start by applying facts
# Nodes
facts_to_be_applied_node_new.clear()
facts_to_be_applied_node_trace_new.clear()
nodes_set = set(nodes)
for i in range(len(facts_to_be_applied_node)):
if facts_to_be_applied_node[i][0] == t:
comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5]
# If the component is not in the graph, add it
if comp not in nodes_set:
_add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node)
nodes_set.add(comp)
# Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well
if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static():
# Check if we should even store any of the changes to the rule trace etc.
# Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute
if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes:
meta_name = facts_to_be_applied_node_trace[i] if atom_trace else ''
rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, bnd, True, 'Fact', meta_name, ''))
if atom_trace:
_update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_node_trace[i])
for p1, p2 in ipl:
if p1==l:
rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_node[comp].world[p2], True, 'IPL', f'IPL: {l.get_value()}', ''))
if atom_trace:
_update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p2], facts_to_be_applied_node_trace[i])
elif p2==l:
rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1], True, 'IPL', f'IPL: {l.get_value()}', ''))
if atom_trace:
_update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i])
else:
# Check for inconsistencies (multiple facts)
if check_consistent_node(interpretations_node, comp, (l, bnd)):
mode = 'graph-attribute-fact' if graph_attribute else 'fact'
override = True if update_mode == 'override' else False
u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override)
update = u or update
# Update convergence params
if convergence_mode=='delta_bound':
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
# Resolve inconsistency if necessary otherwise override bounds
else:
mode = 'graph-attribute-fact' if graph_attribute else 'fact'
if inconsistency_check:
resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode)
else:
u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True)
update = u or update
# Update convergence params
if convergence_mode=='delta_bound':
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
if static:
facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute))
if atom_trace:
facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])
# If time doesn't match, fact to be applied later
else:
facts_to_be_applied_node_new.append(facts_to_be_applied_node[i])
if atom_trace:
facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])
# Update list of facts with ones that have not been applied yet (delete applied facts)
facts_to_be_applied_node[:] = facts_to_be_applied_node_new.copy()
if atom_trace:
facts_to_be_applied_node_trace[:] = facts_to_be_applied_node_trace_new.copy()
facts_to_be_applied_node_new.clear()
facts_to_be_applied_node_trace_new.clear()
# Edges
facts_to_be_applied_edge_new.clear()
facts_to_be_applied_edge_trace_new.clear()
edges_set = set(edges)
for i in range(len(facts_to_be_applied_edge)):
if facts_to_be_applied_edge[i][0]==t:
comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5]
# If the component is not in the graph, add it
if comp not in edges_set:
_add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
edges_set.add(comp)
# Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well
if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static():
# Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute
if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes:
meta_name = facts_to_be_applied_edge_trace[i] if atom_trace else ''
rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, interpretations_edge[comp].world[l], True, 'Fact', meta_name, ''))
if atom_trace:
_update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_edge_trace[i])
for p1, p2 in ipl:
if p1==l:
rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_edge[comp].world[p2], True, 'IPL', f'IPL: {l.get_value()}', ''))
if atom_trace:
_update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p2], facts_to_be_applied_edge_trace[i])
elif p2==l:
rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_edge[comp].world[p1], True, 'IPL', f'IPL: {l.get_value()}', ''))
if atom_trace:
_update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p1], facts_to_be_applied_edge_trace[i])
else:
# Check for inconsistencies
if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
mode = 'graph-attribute-fact' if graph_attribute else 'fact'
override = True if update_mode == 'override' else False
u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override)
update = u or update
# Update convergence params
if convergence_mode=='delta_bound':
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
# Resolve inconsistency
else:
mode = 'graph-attribute-fact' if graph_attribute else 'fact'
if inconsistency_check:
resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode)
else:
u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True)
update = u or update
# Update convergence params
if convergence_mode=='delta_bound':
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
if static:
facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute))
if atom_trace:
facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])
# Time doesn't match, fact to be applied later
else:
facts_to_be_applied_edge_new.append(facts_to_be_applied_edge[i])
if atom_trace:
facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])
# Update list of facts with ones that have not been applied yet (delete applied facts)
facts_to_be_applied_edge[:] = facts_to_be_applied_edge_new.copy()
if atom_trace:
facts_to_be_applied_edge_trace[:] = facts_to_be_applied_edge_trace_new.copy()
facts_to_be_applied_edge_new.clear()
facts_to_be_applied_edge_trace_new.clear()
in_loop = True
while in_loop:
# This will become true only if delta_t = 0 for some rule, otherwise we go to the next timestep
in_loop = False
# Apply the rules that need to be applied at this timestep
# Nodes
rules_to_remove_idx.clear()
for idx, i in enumerate(rules_to_be_applied_node):
if i[0] == t:
comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
# Check for inconsistencies
if check_consistent_node(interpretations_node, comp, (l, bnd)):
override = True if update_mode == 'override' else False
u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
update = u or update
# Update convergence params
if convergence_mode=='delta_bound':
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
# Resolve inconsistency
else:
if inconsistency_check:
resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule')
else:
u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
update = u or update
# Update convergence params
if convergence_mode=='delta_bound':
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
# Delete rules that have been applied from list by adding index to list
rules_to_remove_idx.add(idx)
# Remove from rules to be applied and edges to be applied lists after coming out from loop
rules_to_be_applied_node[:] = numba.typed.List([rules_to_be_applied_node[i] for i in range(len(rules_to_be_applied_node)) if i not in rules_to_remove_idx])
edges_to_be_added_node_rule[:] = numba.typed.List([edges_to_be_added_node_rule[i] for i in range(len(edges_to_be_added_node_rule)) if i not in rules_to_remove_idx])
if atom_trace:
rules_to_be_applied_node_trace[:] = numba.typed.List([rules_to_be_applied_node_trace[i] for i in range(len(rules_to_be_applied_node_trace)) if i not in rules_to_remove_idx])
# Edges
rules_to_remove_idx.clear()
for idx, i in enumerate(rules_to_be_applied_edge):
if i[0] == t:
comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
sources, targets, edge_l = edges_to_be_added_edge_rule[idx]
edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
changes_cnt += changes
# Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
if edge_l.value != '':
for e in edges_added:
if interpretations_edge[e].world[edge_l].is_static():
continue
if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)):
override = True if update_mode == 'override' else False
u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
update = u or update
# Update convergence params
if convergence_mode=='delta_bound':
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
# Resolve inconsistency
else:
if inconsistency_check:
resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
else:
u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
update = u or update
# Update convergence params
if convergence_mode=='delta_bound':
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
else:
# Check for inconsistencies
if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
override = True if update_mode == 'override' else False
u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
update = u or update
# Update convergence params
if convergence_mode=='delta_bound':
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
# Resolve inconsistency
else:
if inconsistency_check:
resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
else:
u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
update = u or update
# Update convergence params
if convergence_mode=='delta_bound':
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
# Delete rules that have been applied from list by adding the index to list
rules_to_remove_idx.add(idx)
# Remove from rules to be applied and edges to be applied lists after coming out from loop
rules_to_be_applied_edge[:] = numba.typed.List([rules_to_be_applied_edge[i] for i in range(len(rules_to_be_applied_edge)) if i not in rules_to_remove_idx])
edges_to_be_added_edge_rule[:] = numba.typed.List([edges_to_be_added_edge_rule[i] for i in range(len(edges_to_be_added_edge_rule)) if i not in rules_to_remove_idx])
if atom_trace:
rules_to_be_applied_edge_trace[:] = numba.typed.List([rules_to_be_applied_edge_trace[i] for i in range(len(rules_to_be_applied_edge_trace)) if i not in rules_to_remove_idx])
# Fixed point
if update:
# Increase fp operator count
fp_cnt += 1
# Lists or threadsafe operations (when parallel is on)
rules_to_be_applied_node_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_node_type) for _ in range(len(rules))])
rules_to_be_applied_edge_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_edge_type) for _ in range(len(rules))])
if atom_trace:
rules_to_be_applied_node_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
rules_to_be_applied_edge_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
edges_to_be_added_edge_rule_threadsafe = numba.typed.List([numba.typed.List.empty_list(edges_to_be_added_type) for _ in range(len(rules))])
# Threadsafe flags for in_loop and update within prange; merge after loop
in_loop_threadsafe = numba.typed.List.empty_list(numba.types.boolean)
update_threadsafe = numba.typed.List.empty_list(numba.types.boolean)
for _ in range(len(rules)):
in_loop_threadsafe.append(False)
update_threadsafe.append(True)
for i in prange(len(rules)):
rule = rules[i]
# Only go through if the rule can be applied within the given timesteps, or we're running until convergence
delta_t = rule.get_delta()
if t + delta_t <= tmax or tmax == -1 or again:
applicable_node_rules, applicable_edge_rules = _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, extended_ann_fn_flags[i], allow_ground_rules, num_ga, t, head_functions, closed_world_predicates)
# Loop through applicable rules and add them to the rules to be applied for later or next fp operation
for applicable_rule in applicable_node_rules:
n, annotations, qualified_nodes, qualified_edges, _, clause_labels, clause_variables = applicable_rule
# If there is an edge to add or the predicate doesn't exist or the interpretation is not static
if rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static():
bnd = annotate(annotation_functions, rule, annotations, qualified_nodes, qualified_edges, clause_labels, clause_variables, rule.get_weights())
# If the rule head was negated AND an ann_fn produced the bound, invert it:
# ~[l,u] = [1-u, 1-l]. For non-ann_fn negation the parser already folded
# the inversion into target_bound, so we must NOT re-invert here.
if rule.get_annotation_function() != '' and rule.is_head_negated():
bnd = (1 - bnd[1], 1 - bnd[0])
# Bound annotations in between 0 and 1
bnd_l = min(max(bnd[0], 0), 1)
bnd_u = min(max(bnd[1], 0), 1)
bnd = interval.closed(bnd_l, bnd_u)
max_rules_time = max(max_rules_time, t + delta_t)
rules_to_be_applied_node_threadsafe[i].append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, rule.is_static_rule()))
if atom_trace:
rules_to_be_applied_node_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name()))
# If delta_t is zero we apply the rules and check if more are applicable
if delta_t == 0:
in_loop_threadsafe[i] = True
update_threadsafe[i] = False
for applicable_rule in applicable_edge_rules:
e, annotations, qualified_nodes, qualified_edges, edges_to_add, clause_labels, clause_variables = applicable_rule
# If there is an edge to add or the predicate doesn't exist or the interpretation is not static
if len(edges_to_add[0]) > 0 or rule.get_target() not in interpretations_edge[e].world or not interpretations_edge[e].world[rule.get_target()].is_static():
bnd = annotate(annotation_functions, rule, annotations, qualified_nodes, qualified_edges, clause_labels, clause_variables, rule.get_weights())
# If the rule head was negated AND an ann_fn produced the bound, invert it:
# ~[l,u] = [1-u, 1-l]. For non-ann_fn negation the parser already folded
# the inversion into target_bound, so we must NOT re-invert here.
if rule.get_annotation_function() != '' and rule.is_head_negated():
bnd = (1 - bnd[1], 1 - bnd[0])
# Bound annotations in between 0 and 1
bnd_l = min(max(bnd[0], 0), 1)
bnd_u = min(max(bnd[1], 0), 1)
bnd = interval.closed(bnd_l, bnd_u)
max_rules_time = max(max_rules_time, t+delta_t)
# edges_to_be_added_edge_rule.append(edges_to_add)
edges_to_be_added_edge_rule_threadsafe[i].append(edges_to_add)
rules_to_be_applied_edge_threadsafe[i].append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, rule.is_static_rule()))
if atom_trace:
# rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name()))
rules_to_be_applied_edge_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name()))
# If delta_t is zero we apply the rules and check if more are applicable
if delta_t == 0:
in_loop_threadsafe[i] = True
update_threadsafe[i] = False
# Update lists after parallel run
for i in range(len(rules)):
if len(rules_to_be_applied_node_threadsafe[i]) > 0:
rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i])
if len(rules_to_be_applied_edge_threadsafe[i]) > 0:
rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i])
if atom_trace:
if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0:
rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i])
if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0:
rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i])
if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0:
edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i])
# Merge threadsafe flags for in_loop and update
in_loop = in_loop
update = update
for i in range(len(rules)):
if in_loop_threadsafe[i]:
in_loop = True
if not update_threadsafe[i]:
update = False
# Check for convergence after each timestep (perfect convergence or convergence specified by user)
# Check number of changed interpretations or max bound change
# User specified convergence
if convergence_mode == 'delta_interpretation':
if changes_cnt <= convergence_delta:
if verbose:
print(f'\nConverged at time: {t} with {int(changes_cnt)} changes from the previous interpretation')
# Be consistent with time returned when we don't converge
t += 1
break
elif convergence_mode == 'delta_bound':
if bound_delta <= convergence_delta:
if verbose:
print(f'\nConverged at time: {t} with {float_to_str(bound_delta)} as the maximum bound change from the previous interpretation')
# Be consistent with time returned when we don't converge
t += 1
break
# Perfect convergence
# Make sure there are no rules to be applied, and no facts that will be applied in the future. We do this by checking the max time any rule/fact is applicable
# If no more rules/facts to be applied
elif convergence_mode == 'perfect_convergence':
if t>=max_facts_time and t >= max_rules_time:
if verbose:
print(f'\nConverged at time: {t}')
# Be consistent with time returned when we don't converge
t += 1
break
# Increment t, update number of ground atoms
t += 1
num_ga.append(num_ga[-1])
return fp_cnt, t
[docs]
def add_edge(self, edge, l):
    """
    Add an edge to the graph from outside the reasoning loop (useful for pyreason gym).

    :param edge: pair whose first item is the source node and second item the target node
    :param l: label passed through to _add_edge for the new edge
    """
    source, target = edge[0], edge[1]
    # Timestep -1 is passed because this call happens outside of any reasoning timestep
    _add_edge(
        source,
        target,
        self.neighbors,
        self.reverse_neighbors,
        self.nodes,
        self.edges,
        l,
        self.interpretations_node,
        self.interpretations_edge,
        self.predicate_map_edge,
        self.num_ga,
        -1,
    )
[docs]
def add_node(self, node, labels):
    """
    Add a node to the graph from outside the reasoning loop (useful for pyreason gym).

    Nothing happens if the node already exists. A newly added node gets every label in
    ``labels`` initialised to the interval [0, 1].

    :param node: name of the node to add
    :param labels: iterable of label names to initialise on the new node
    """
    if node in self.nodes:
        return
    _add_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node)
    for label_name in labels:
        self.interpretations_node[node].world[label.Label(label_name)] = interval.closed(0, 1)
[docs]
def delete_edge(self, edge):
    """
    Remove an edge from the graph from outside the reasoning loop (useful for pyreason gym).

    :param edge: the edge to remove, forwarded unchanged to _delete_edge
    """
    _delete_edge(
        edge,
        self.neighbors,
        self.reverse_neighbors,
        self.edges,
        self.interpretations_edge,
        self.predicate_map_edge,
        self.num_ga,
    )
[docs]
def delete_node(self, node):
    """
    Remove a node from the graph from outside the reasoning loop (useful for pyreason gym).

    :param node: the node to remove, forwarded unchanged to _delete_node
    """
    _delete_node(
        node,
        self.neighbors,
        self.reverse_neighbors,
        self.nodes,
        self.interpretations_node,
        self.predicate_map_node,
        self.num_ga,
    )
[docs]
def get_dict(self):
    """
    Build a plain-Python snapshot of the interpretation, indexed by timestep.

    Every node and every edge gets its own InterpretationDict at each timestep from 0 through
    ``self.time``. Only values recorded in the node and edge rule traces are filled in. When
    ``self.persistent`` is set, a traced value is also copied into every later timestep.

    :return: dict mapping timestep -> {node or edge -> InterpretationDict}, where each
        InterpretationDict maps a label's value to a (lower, upper) tuple
    """
    interpretations = {}
    for step in range(self.time + 1):
        per_step = {}
        for n in self.nodes:
            per_step[n] = InterpretationDict()
        for e in self.edges:
            per_step[e] = InterpretationDict()
        interpretations[step] = per_step

    # Node changes first, then edge changes, each in trace order
    for trace in (self.rule_trace_node, self.rule_trace_edge):
        for change in trace:
            when, _, component, lbl, bnd, _, _, _, _ = change
            key = lbl._value
            value = (bnd.lower, bnd.upper)
            interpretations[when][component][key] = value
            if self.persistent:
                # Persistent mode: the value carries forward to all later timesteps
                for step in range(when + 1, self.time + 1):
                    interpretations[step][component][key] = value
    return interpretations
[docs]
def get_final_num_ground_atoms(self):
    """
    Count the ground atoms present in the interpretation at the final timestep.

    A ground atom is one (component, label) entry in a node's or edge's ``world``.

    :return: int: Number of ground atoms in the interpretation after reasoning
    """
    node_atoms = sum(1 for n in self.nodes for _ in self.interpretations_node[n].world)
    edge_atoms = sum(1 for e in self.edges for _ in self.interpretations_edge[e].world)
    return node_atoms + edge_atoms
[docs]
def get_num_ground_atoms(self):
    """
    Return the number of ground atoms after the reasoning process, for each timestep.

    If the last recorded count is 0, it is removed from ``self.num_ga`` in place before
    returning. NOTE(review): presumably that trailing 0 is a placeholder for a timestep that
    was never reasoned over; confirm against the reasoning loop. Because the removal mutates
    ``self.num_ga``, repeated calls keep removing trailing zeros. An empty ``self.num_ga`` is
    returned unchanged (previously ``self.num_ga[-1]`` raised IndexError).

    :return: list: Number of ground atoms in the interpretation after reasoning for each timestep
    """
    # Guard the empty case: indexing [-1] on an empty list raises IndexError
    if len(self.num_ga) > 0 and self.num_ga[-1] == 0:
        self.num_ga.pop()
    return self.num_ga
[docs]
def query(self, query, return_bool=True) -> Union[bool, Tuple[float, float]]:
    """
    Query the graph after reasoning.

    :param query: A PyReason query object
    :param return_bool: If True, return whether the query holds; otherwise return bounds
    :return: bool, or bounds. In bounds mode the result is (0, 1) when the component or the
        predicate does not exist, the stored (lower, upper) when it lies within the query
        bounds, and (0, 0) otherwise.
    """
    comp_type = query.get_component_type()
    component = query.get_component()
    pred = query.get_predicate()
    bnd = query.get_bounds()

    # Any component type other than 'node' is treated as an edge query
    if comp_type == 'node':
        components, interpretations = self.nodes, self.interpretations_node
    else:
        components, interpretations = self.edges, self.interpretations_edge

    # Unknown component or predicate: the query cannot hold
    if component not in components:
        return False if return_bool else (0, 1)
    world = interpretations[component].world
    if pred not in world:
        return False if return_bool else (0, 1)

    current = world[pred]
    if current not in bnd:
        return False if return_bool else (0, 0)
    return True if return_bool else (current.lower, current.upper)
@numba.njit(cache=True)
def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, extended_ann_fn, allow_ground_rules, num_ga, t, head_functions, closed_world_predicates):
    """
    Ground one rule against the current interpretation and collect every head grounding it fires for.

    Body clauses are processed in order.
    - For each node/edge clause, the candidate groundings of its variable(s) are narrowed to those
      whose interpretation satisfies the clause bound, and the clause threshold is checked.
    - The groundings of connected variables are then refined through a variable dependency graph.
    - Processing stops at the first clause whose threshold is not met.

    If every clause is satisfied, the head groundings are enumerated. For each one that still
    satisfies all clauses, the data needed to apply the rule is collected:
    - annotations (when the rule has an annotation function);
    - qualified nodes/edges (when atom_trace or extended_ann_fn is set);
    - per-clause labels and variables (when extended_ann_fn is set);
    - edges to add (edge rules that infer edges).

    Side effects: for ground rules whose head names a node (or, for edge rules, a node pair) that is
    not in the graph yet, the node/edge is added via _add_node/_add_edge once the rule is known to fire.

    :param rule: the rule to ground
    :param atom_trace: when True, collect the qualified nodes/edges of each clause for the atom trace
    :param extended_ann_fn: when True, also collect qualified nodes/edges plus per-clause labels and
        variables for extended annotation functions
    :param allow_ground_rules: when True, clause/head variables that name existing nodes/edges are
        grounded to themselves
    :param num_ga: passed on to _add_edge
    :param t: current timestep, passed on to _add_edge
    :return: (applicable_rules_node, applicable_rules_edge). Each entry is (head grounding, annotations,
        qualified_nodes, qualified_edges, edges_to_be_added, clause_labels, clause_variables). Only the
        list matching the rule type is populated.
    """
    # Extract rule params
    rule_type = rule.get_type()
    head_variables = rule.get_head_variables()
    head_fns = rule.get_head_function()
    head_fns_vars = rule.get_head_function_vars()
    clauses = rule.get_clauses()
    thresholds = rule.get_thresholds()
    ann_fn = rule.get_annotation_function()
    rule_edges = rule.get_edges()

    if rule_type == 'node':
        head_var_1 = head_variables[0]
    else:
        head_var_1, head_var_2 = head_variables[0], head_variables[1]

    # We return a list of tuples which specify the target nodes/edges that have made the rule body true
    applicable_rules_node = numba.typed.List.empty_list(node_applicable_rule_type)
    applicable_rules_edge = numba.typed.List.empty_list(edge_applicable_rule_type)

    # Grounding procedure
    # 1. Go through each clause and check which variables have not been initialized in groundings
    # 2. Check satisfaction of variables based on the predicate in the clause

    # Grounding variable that maps variables in the body to a list of grounded nodes
    # Grounding edges that maps edge variables to a list of edges
    groundings = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
    groundings_edges = numba.typed.Dict.empty(key_type=edge_type, value_type=list_of_edges)

    # Dependency graph that keeps track of the connections between the variables in the body
    dependency_graph_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
    dependency_graph_reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)

    # O(1) membership sets. The graph is not modified before the head-grounding stage below,
    # so these stay valid for every membership test in this function.
    nodes_set = set(nodes)
    edges_set = set(edges)

    satisfaction = True
    for i, clause in enumerate(clauses):
        # Unpack clause variables
        clause_type = clause[0]
        clause_label = clause[1]
        clause_variables = clause[2]
        clause_bnd = clause[3]
        _clause_operator = clause[4]

        # This is a node clause
        if clause_type == 'node':
            clause_var_1 = clause_variables[0]

            # Get subset of nodes that can be used to ground the variable
            # If we allow ground atoms, we can use the nodes directly
            if allow_ground_rules and clause_var_1 in nodes_set:
                grounding = numba.typed.List([clause_var_1])
            else:
                grounding = get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map_node, clause_label, nodes)

            # Narrow subset based on predicate
            qualified_groundings = get_qualified_node_groundings(interpretations_node, grounding, clause_label, clause_bnd, closed_world_predicates)
            groundings[clause_var_1] = qualified_groundings
            qualified_groundings_set = set(qualified_groundings)
            # Drop edge groundings whose endpoint for this variable no longer qualifies
            for c1, c2 in groundings_edges:
                if c1 == clause_var_1:
                    groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[0] in qualified_groundings_set])
                if c2 == clause_var_1:
                    groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[1] in qualified_groundings_set])

            # Check satisfaction of those nodes wrt the threshold
            # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule
            # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges
            # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0:
            satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_groundings, clause_label, thresholds[i], closed_world_predicates) and satisfaction

        # This is an edge clause
        elif clause_type == 'edge':
            clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]

            # Get subset of edges that can be used to ground the variables
            # If we allow ground atoms, we can use the nodes directly
            if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set:
                grounding = numba.typed.List([(clause_var_1, clause_var_2)])
            else:
                # Pre-populate groundings for any variable that matches an existing node (partial grounding)
                if allow_ground_rules:
                    if clause_var_1 in nodes_set and clause_var_1 not in groundings:
                        groundings[clause_var_1] = numba.typed.List([clause_var_1])
                    if clause_var_2 in nodes_set and clause_var_2 not in groundings:
                        groundings[clause_var_2] = numba.typed.List([clause_var_2])
                grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label, edges)

            # Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster)
            qualified_groundings = get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, clause_bnd, closed_world_predicates)

            # Check satisfaction of those edges wrt the threshold
            # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule
            # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges
            # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0:
            satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_groundings, clause_label, thresholds[i], closed_world_predicates) and satisfaction

            # Update the groundings: rebuild both endpoint variables from the qualified edges (deduplicated)
            groundings[clause_var_1] = numba.typed.List.empty_list(node_type)
            groundings[clause_var_2] = numba.typed.List.empty_list(node_type)
            groundings_clause_1_set = set(groundings[clause_var_1])
            groundings_clause_2_set = set(groundings[clause_var_2])
            for e in qualified_groundings:
                if e[0] not in groundings_clause_1_set:
                    groundings[clause_var_1].append(e[0])
                    groundings_clause_1_set.add(e[0])
                if e[1] not in groundings_clause_2_set:
                    groundings[clause_var_2].append(e[1])
                    groundings_clause_2_set.add(e[1])

            # Update the edge groundings (to use later for grounding other clauses with the same variables)
            groundings_edges[(clause_var_1, clause_var_2)] = qualified_groundings

            # Update dependency graph
            # Add a connection between clause_var_1 -> clause_var_2 and vice versa
            if clause_var_1 not in dependency_graph_neighbors:
                dependency_graph_neighbors[clause_var_1] = numba.typed.List([clause_var_2])
            elif clause_var_2 not in dependency_graph_neighbors[clause_var_1]:
                dependency_graph_neighbors[clause_var_1].append(clause_var_2)
            if clause_var_2 not in dependency_graph_reverse_neighbors:
                dependency_graph_reverse_neighbors[clause_var_2] = numba.typed.List([clause_var_1])
            elif clause_var_1 not in dependency_graph_reverse_neighbors[clause_var_2]:
                dependency_graph_reverse_neighbors[clause_var_2].append(clause_var_1)

        # This is a comparison clause
        else:
            pass

        # Refine the subsets based on any updates
        if satisfaction:
            refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors)

        # If satisfaction is false, break
        if not satisfaction:
            break

    # If satisfaction is still true, one final refinement to check if each edge pair is valid in edge rules
    # Then continue to setup any edges to be added and annotations
    # Fill out the rules to be applied lists
    if satisfaction:
        # Create temp grounding containers to verify if the head groundings are valid (only for edge rules)
        # Setup edges to be added and fill rules to be applied
        # Setup traces and inputs for annotation function
        # Loop through the clause data and setup final annotations and trace variables
        # Three cases: 1.node rule, 2. edge rule with infer edges, 3. edge rule
        if rule_type == 'node':
            # Loop through all the head variable groundings and add it to the rules to be applied
            # Loop through the clauses and add appropriate trace data and annotations

            # Apply any function in the head to determine the head grounding
            head_var_groundings, is_func = _determine_node_head_vars(head_fns, head_fns_vars, groundings, head_functions)
            if is_func:
                groundings[head_var_1] = head_var_groundings

            # If there is no grounding for head_var_1, we treat it as a ground atom and add it to the graph
            head_var_1_in_nodes = head_var_1 in nodes_set
            add_head_var_node_to_graph = False
            if allow_ground_rules and head_var_1_in_nodes:
                groundings[head_var_1] = numba.typed.List([head_var_1])
            elif head_var_1 not in groundings:
                if not head_var_1_in_nodes:
                    add_head_var_node_to_graph = True
                groundings[head_var_1] = numba.typed.List([head_var_1])

            node_head_groundings = groundings[head_var_1]
            # Check for satisfaction one more time in case the refining process has changed the groundings.
            # This check depends only on `groundings`/`groundings_edges`, not on the individual head
            # grounding, so it is evaluated once instead of once per head grounding. The only mutation
            # inside the loop (_add_node) happens when the head variable was grounded to itself, i.e.
            # when there is exactly one head grounding, so it cannot affect a later iteration.
            if len(node_head_groundings) > 0:
                satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges, closed_world_predicates)
            else:
                satisfaction = False

            if satisfaction:
                # TODO(perf): when extended_ann_fn is True, `numba.typed.List(clause_variables)`
                # is allocated K*N times (K head groundings * N clauses) below, but the data
                # is rule-static. Hoist a precomputed `clause_variables_precomputed` list
                # here and append references in the inner loop instead of fresh copies.
                for head_grounding in node_head_groundings:
                    qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
                    qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
                    annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
                    clause_labels_out = numba.typed.List.empty_list(label.label_type)
                    clause_variables_out = numba.typed.List.empty_list(numba.typed.List.empty_list(numba.types.string))
                    edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])

                    for i, clause in enumerate(clauses):
                        clause_type = clause[0]
                        clause_label = clause[1]
                        clause_variables = clause[2]

                        if clause_type == 'node':
                            clause_var_1 = clause_variables[0]

                            # 1. Trace / extended annotation data
                            if atom_trace or extended_ann_fn:
                                if clause_var_1 == head_var_1:
                                    qualified_nodes.append(numba.typed.List([head_grounding]))
                                else:
                                    qualified_nodes.append(numba.typed.List(groundings[clause_var_1]))
                                qualified_edges.append(numba.typed.List.empty_list(edge_type))
                            if extended_ann_fn:
                                clause_labels_out.append(clause_label)
                                clause_variables_out.append(numba.typed.List(clause_variables))

                            # 2. Annotation function inputs
                            if ann_fn != '':
                                a = numba.typed.List.empty_list(interval.interval_type)
                                if clause_var_1 == head_var_1:
                                    a.append(interpretations_node[head_grounding].world[clause_label])
                                else:
                                    for qn in groundings[clause_var_1]:
                                        a.append(interpretations_node[qn].world[clause_label])
                                annotations.append(a)

                        elif clause_type == 'edge':
                            clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]

                            # 1. Trace / extended annotation data
                            if atom_trace or extended_ann_fn:
                                # Cases: Both equal, one equal, none equal
                                qualified_nodes.append(numba.typed.List.empty_list(node_type))
                                if clause_var_1 == head_var_1:
                                    es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_grounding])
                                    qualified_edges.append(es)
                                elif clause_var_2 == head_var_1:
                                    es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_grounding])
                                    qualified_edges.append(es)
                                else:
                                    qualified_edges.append(numba.typed.List(groundings_edges[(clause_var_1, clause_var_2)]))
                            if extended_ann_fn:
                                clause_labels_out.append(clause_label)
                                clause_variables_out.append(numba.typed.List(clause_variables))

                            # 2. Annotation function inputs
                            if ann_fn != '':
                                a = numba.typed.List.empty_list(interval.interval_type)
                                if clause_var_1 == head_var_1:
                                    for e in groundings_edges[(clause_var_1, clause_var_2)]:
                                        if e[0] == head_grounding:
                                            a.append(interpretations_edge[e].world[clause_label])
                                elif clause_var_2 == head_var_1:
                                    for e in groundings_edges[(clause_var_1, clause_var_2)]:
                                        if e[1] == head_grounding:
                                            a.append(interpretations_edge[e].world[clause_label])
                                else:
                                    for qe in groundings_edges[(clause_var_1, clause_var_2)]:
                                        a.append(interpretations_edge[qe].world[clause_label])
                                annotations.append(a)
                        else:
                            # Comparison clause (we do not handle for now)
                            pass

                    # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
                    if add_head_var_node_to_graph:
                        _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)

                    # For each grounding add a rule to be applied
                    applicable_rules_node.append((head_grounding, annotations, qualified_nodes, qualified_edges, edges_to_be_added, clause_labels_out, clause_variables_out))

        elif rule_type == 'edge':
            head_var_1 = head_variables[0]
            head_var_2 = head_variables[1]

            # Apply any function in the head to determine the head grounding
            head_var_groundings, is_func = _determine_edge_head_vars(head_fns, head_fns_vars, groundings, head_functions)
            if is_func[0]:
                groundings[head_var_1] = head_var_groundings[0]
            if is_func[1]:
                groundings[head_var_2] = head_var_groundings[1]

            # If there is no grounding for head_var_1 or head_var_2, we treat it as a ground atom and add it to the graph
            head_var_1_in_nodes = head_var_1 in nodes_set
            head_var_2_in_nodes = head_var_2 in nodes_set
            add_head_var_1_node_to_graph = False
            add_head_var_2_node_to_graph = False
            add_head_edge_to_graph = False
            if allow_ground_rules and head_var_1_in_nodes:
                groundings[head_var_1] = numba.typed.List([head_var_1])
            if allow_ground_rules and head_var_2_in_nodes:
                groundings[head_var_2] = numba.typed.List([head_var_2])

            if head_var_1 not in groundings:
                if not head_var_1_in_nodes:
                    add_head_var_1_node_to_graph = True
                groundings[head_var_1] = numba.typed.List([head_var_1])
            if head_var_2 not in groundings:
                if not head_var_2_in_nodes:
                    add_head_var_2_node_to_graph = True
                groundings[head_var_2] = numba.typed.List([head_var_2])

            # Artificially connect the head variables with an edge if both of them were not in the graph
            if not head_var_1_in_nodes and not head_var_2_in_nodes:
                add_head_edge_to_graph = True

            head_var_1_groundings = groundings[head_var_1]
            head_var_2_groundings = groundings[head_var_2]

            source, target, _ = rule_edges
            infer_edges = True if source != '' and target != '' else False

            # Prepare the edges that we will loop over.
            # For infer edges we loop over each combination pair
            # Else we loop over the valid edges in the graph
            valid_edge_groundings = numba.typed.List.empty_list(edge_type)
            for g1 in head_var_1_groundings:
                for g2 in head_var_2_groundings:
                    if infer_edges:
                        valid_edge_groundings.append((g1, g2))
                    else:
                        if (g1, g2) in edges_set:
                            valid_edge_groundings.append((g1, g2))

            # Loop through the head variable groundings
            # TODO(perf): when extended_ann_fn is True, `numba.typed.List(clause_variables)`
            # is allocated K*N times (K edge groundings * N clauses) below, but the data
            # is rule-static. Hoist a precomputed `clause_variables_precomputed` list
            # here and append references in the inner loop instead of fresh copies.
            for valid_e in valid_edge_groundings:
                head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1]
                qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
                qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
                annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
                clause_labels_out = numba.typed.List.empty_list(label.label_type)
                clause_variables_out = numba.typed.List.empty_list(numba.typed.List.empty_list(numba.types.string))
                edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])

                # Containers to keep track of groundings to make sure that the edge pair is valid
                # We do this because we cannot know beforehand the edge matches from source groundings to target groundings
                temp_groundings = groundings.copy()
                temp_groundings_edges = groundings_edges.copy()

                # Refine the temp groundings for the specific edge head grounding
                # We update the edge collection as well depending on if there's a match between the clause variables and head variables
                temp_groundings[head_var_1] = numba.typed.List([head_var_1_grounding])
                temp_groundings[head_var_2] = numba.typed.List([head_var_2_grounding])
                for c1, c2 in temp_groundings_edges.keys():
                    if c1 == head_var_1 and c2 == head_var_2:
                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_1_grounding, head_var_2_grounding)])
                    elif c1 == head_var_2 and c2 == head_var_1:
                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_2_grounding, head_var_1_grounding)])
                    elif c1 == head_var_1:
                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_1_grounding])
                    elif c2 == head_var_1:
                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_1_grounding])
                    elif c1 == head_var_2:
                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_2_grounding])
                    elif c2 == head_var_2:
                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_2_grounding])

                refine_groundings(head_variables, temp_groundings, temp_groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors)

                # Check if the thresholds are still satisfied
                # Check if all clauses are satisfied again in case the refining process changed anything
                # (per edge here, because the temp groundings depend on the head edge)
                satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, temp_groundings, temp_groundings_edges, closed_world_predicates)
                if not satisfaction:
                    continue

                if infer_edges:
                    # Prevent self loops while inferring edges if the clause variables are not the same
                    if source != target and head_var_1_grounding == head_var_2_grounding:
                        continue
                    edges_to_be_added[0].append(head_var_1_grounding)
                    edges_to_be_added[1].append(head_var_2_grounding)

                for i, clause in enumerate(clauses):
                    clause_type = clause[0]
                    clause_label = clause[1]
                    clause_variables = clause[2]

                    if clause_type == 'node':
                        clause_var_1 = clause_variables[0]
                        # 1. Trace / extended annotation data
                        if atom_trace or extended_ann_fn:
                            if clause_var_1 == head_var_1:
                                qualified_nodes.append(numba.typed.List([head_var_1_grounding]))
                            elif clause_var_1 == head_var_2:
                                qualified_nodes.append(numba.typed.List([head_var_2_grounding]))
                            else:
                                qualified_nodes.append(numba.typed.List(temp_groundings[clause_var_1]))
                            qualified_edges.append(numba.typed.List.empty_list(edge_type))
                        if extended_ann_fn:
                            clause_labels_out.append(clause_label)
                            clause_variables_out.append(numba.typed.List(clause_variables))
                        # 2. Annotation function inputs
                        if ann_fn != '':
                            a = numba.typed.List.empty_list(interval.interval_type)
                            if clause_var_1 == head_var_1:
                                a.append(interpretations_node[head_var_1_grounding].world[clause_label])
                            elif clause_var_1 == head_var_2:
                                a.append(interpretations_node[head_var_2_grounding].world[clause_label])
                            else:
                                for qn in temp_groundings[clause_var_1]:
                                    a.append(interpretations_node[qn].world[clause_label])
                            annotations.append(a)

                    elif clause_type == 'edge':
                        clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
                        # 1. Trace / extended annotation data
                        if atom_trace or extended_ann_fn:
                            # Cases:
                            # 1. Both equal (cv1 = hv1 and cv2 = hv2 or cv1 = hv2 and cv2 = hv1)
                            # 2. One equal (cv1 = hv1 or cv2 = hv1 or cv1 = hv2 or cv2 = hv2)
                            # 3. None equal
                            qualified_nodes.append(numba.typed.List.empty_list(node_type))
                            if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding])
                                qualified_edges.append(es)
                            elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding])
                                qualified_edges.append(es)
                            elif clause_var_1 == head_var_1:
                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding])
                                qualified_edges.append(es)
                            elif clause_var_1 == head_var_2:
                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding])
                                qualified_edges.append(es)
                            elif clause_var_2 == head_var_1:
                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_1_grounding])
                                qualified_edges.append(es)
                            elif clause_var_2 == head_var_2:
                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_2_grounding])
                                qualified_edges.append(es)
                            else:
                                qualified_edges.append(numba.typed.List(temp_groundings_edges[(clause_var_1, clause_var_2)]))
                        if extended_ann_fn:
                            clause_labels_out.append(clause_label)
                            clause_variables_out.append(numba.typed.List(clause_variables))

                        # 2. Annotation function inputs
                        if ann_fn != '':
                            a = numba.typed.List.empty_list(interval.interval_type)
                            if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
                                    if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding:
                                        a.append(interpretations_edge[e].world[clause_label])
                            elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
                                    if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding:
                                        a.append(interpretations_edge[e].world[clause_label])
                            elif clause_var_1 == head_var_1:
                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
                                    if e[0] == head_var_1_grounding:
                                        a.append(interpretations_edge[e].world[clause_label])
                            elif clause_var_1 == head_var_2:
                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
                                    if e[0] == head_var_2_grounding:
                                        a.append(interpretations_edge[e].world[clause_label])
                            elif clause_var_2 == head_var_1:
                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
                                    if e[1] == head_var_1_grounding:
                                        a.append(interpretations_edge[e].world[clause_label])
                            elif clause_var_2 == head_var_2:
                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
                                    if e[1] == head_var_2_grounding:
                                        a.append(interpretations_edge[e].world[clause_label])
                            else:
                                for qe in temp_groundings_edges[(clause_var_1, clause_var_2)]:
                                    a.append(interpretations_edge[qe].world[clause_label])
                            annotations.append(a)

                # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
                if add_head_var_1_node_to_graph and head_var_1_grounding == head_var_1:
                    _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
                if add_head_var_2_node_to_graph and head_var_2_grounding == head_var_2:
                    _add_node(head_var_2, neighbors, reverse_neighbors, nodes, interpretations_node)
                if add_head_edge_to_graph and (head_var_1, head_var_2) == (head_var_1_grounding, head_var_2_grounding):
                    _add_edge(head_var_1, head_var_2, neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)

                # For each grounding combination add a rule to be applied
                # Only if all the clauses have valid groundings
                e = (head_var_1_grounding, head_var_2_grounding)
                applicable_rules_edge.append((e, annotations, qualified_nodes, qualified_edges, edges_to_be_added, clause_labels_out, clause_variables_out))

    # Return the applicable rules
    return applicable_rules_node, applicable_rules_edge
@numba.njit(cache=True)
def check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges, closed_world_predicates):
    """Return True when every node/edge clause meets its threshold.

    Only the first four fields of each clause tuple are read:
    ``(type, label, variables, bound)``. Clauses whose type is neither
    ``'node'`` nor ``'edge'`` are ignored. Every clause is evaluated (there is
    deliberately no early exit) and the per-clause verdicts are AND-ed.
    ``thresholds[i]`` belongs to ``clauses[i]``.
    """
    all_satisfied = True
    for position, current in enumerate(clauses):
        kind = current[0]
        lbl = current[1]
        variables = current[2]
        bnd = current[3]
        if kind == 'node':
            node_candidates = groundings[variables[0]]
            passing_nodes = get_qualified_node_groundings(interpretations_node, node_candidates, lbl, bnd, closed_world_predicates)
            clause_ok = check_node_grounding_threshold_satisfaction(interpretations_node, node_candidates, passing_nodes, lbl, thresholds[position], closed_world_predicates)
            all_satisfied = clause_ok and all_satisfied
        elif kind == 'edge':
            edge_candidates = groundings_edges[(variables[0], variables[1])]
            passing_edges = get_qualified_edge_groundings(interpretations_edge, edge_candidates, lbl, bnd, closed_world_predicates)
            clause_ok = check_edge_grounding_threshold_satisfaction(interpretations_edge, edge_candidates, passing_edges, lbl, thresholds[position], closed_world_predicates)
            all_satisfied = clause_ok and all_satisfied
    return all_satisfied
@numba.njit(cache=True)
def refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors):
    """Propagate grounding restrictions through the variable dependency graph.

    Processing starts from ``clause_variables``. For each refined variable V:

    - for every N in ``dependency_graph_neighbors[V]``, the edge groundings of
      ``(V, N)`` are filtered to edges whose source is in ``groundings[V]``,
      and ``groundings[N]`` is rebuilt as the de-duplicated targets of the
      surviving edges;
    - for every R in ``dependency_graph_reverse_neighbors[V]``, the edge
      groundings of ``(R, V)`` are filtered to edges whose target is in
      ``groundings[V]``, and ``groundings[R]`` is rebuilt from their sources.

    Variables touched this way are processed in the next round, unless they
    were in the initial set or already processed in an earlier round.
    ``groundings`` and ``groundings_edges`` are mutated in place; nothing is
    returned.
    """
    all_variables_refined = numba.typed.List(clause_variables)
    variables_just_refined = numba.typed.List(clause_variables)
    new_variables_refined = numba.typed.List.empty_list(numba.types.string)
    while len(variables_just_refined) > 0:
        for refined_variable in variables_just_refined:
            # Refine all the (forward) neighbors of the refined variable
            if refined_variable in dependency_graph_neighbors:
                for neighbor in dependency_graph_neighbors[refined_variable]:
                    old_edge_groundings = groundings_edges[(refined_variable, neighbor)]
                    # Snapshot as a set: O(1) membership in the filter below
                    # instead of a linear scan of a typed List for every edge.
                    # Taken before the delete so a self-loop (V == N) still
                    # filters against V's pre-refinement groundings.
                    new_node_groundings = set(groundings[refined_variable])
                    # Delete old groundings for the variable being refined
                    del groundings[neighbor]
                    groundings[neighbor] = numba.typed.List.empty_list(node_type)
                    # Update the edge groundings and node groundings
                    qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[0] in new_node_groundings])
                    groundings_neighbor_set = set(groundings[neighbor])
                    for e in qualified_groundings:
                        if e[1] not in groundings_neighbor_set:
                            groundings[neighbor].append(e[1])
                            groundings_neighbor_set.add(e[1])
                    groundings_edges[(refined_variable, neighbor)] = qualified_groundings
                    # Queue the neighbor so its own neighbors get refined next round
                    if neighbor not in all_variables_refined:
                        new_variables_refined.append(neighbor)

            # Refine all the reverse neighbors of the refined variable
            if refined_variable in dependency_graph_reverse_neighbors:
                for reverse_neighbor in dependency_graph_reverse_neighbors[refined_variable]:
                    old_edge_groundings = groundings_edges[(reverse_neighbor, refined_variable)]
                    # Set snapshot for O(1) membership (see forward case above)
                    new_node_groundings = set(groundings[refined_variable])
                    # Delete old groundings for the variable being refined
                    del groundings[reverse_neighbor]
                    groundings[reverse_neighbor] = numba.typed.List.empty_list(node_type)
                    # Update the edge groundings and node groundings
                    qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[1] in new_node_groundings])
                    groundings_reverse_neighbor_set = set(groundings[reverse_neighbor])
                    for e in qualified_groundings:
                        if e[0] not in groundings_reverse_neighbor_set:
                            groundings[reverse_neighbor].append(e[0])
                            groundings_reverse_neighbor_set.add(e[0])
                    groundings_edges[(reverse_neighbor, refined_variable)] = qualified_groundings
                    # Queue the reverse neighbor so its neighbors get refined next round
                    if reverse_neighbor not in all_variables_refined:
                        new_variables_refined.append(reverse_neighbor)

        variables_just_refined = numba.typed.List(new_variables_refined)
        all_variables_refined.extend(new_variables_refined)
        new_variables_refined.clear()
@numba.njit(cache=True)
def check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_grounding, clause_label, threshold, closed_world_predicates):
    """Decide whether a node clause's qualified groundings satisfy ``threshold``.

    ``threshold[1][1]`` selects the denominator:

    - ``'total'``: every candidate in ``grounding``;
    - ``'available'``: candidates that pass ``is_satisfied_node`` for
      ``clause_label`` with the bound [0, 1].

    The count of ``qualified_grounding`` is then judged by
    ``_satisfies_threshold``.
    """
    quantifier_type = threshold[1][1]
    if quantifier_type == 'total':
        denominator = len(grounding)
    elif quantifier_type == 'available':
        available = get_qualified_node_groundings(interpretations_node, grounding, clause_label, interval.closed(0, 1), closed_world_predicates)
        denominator = len(available)
    return _satisfies_threshold(denominator, len(qualified_grounding), threshold)
@numba.njit(cache=True)
def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_grounding, clause_label, threshold, closed_world_predicates):
    """Decide whether an edge clause's qualified groundings satisfy ``threshold``.

    ``threshold[1][1]`` selects the denominator:

    - ``'total'``: every candidate edge in ``grounding``;
    - ``'available'``: candidate edges that pass ``is_satisfied_edge`` for
      ``clause_label`` with the bound [0, 1].

    The count of ``qualified_grounding`` is then judged by
    ``_satisfies_threshold``.
    """
    quantifier_type = threshold[1][1]
    if quantifier_type == 'total':
        denominator = len(grounding)
    elif quantifier_type == 'available':
        available = get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, interval.closed(0, 1), closed_world_predicates)
        denominator = len(available)
    return _satisfies_threshold(denominator, len(qualified_grounding), threshold)
@numba.njit(cache=True)
def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes):
    """Return the candidate nodes for a node clause over ``clause_var_1``.

    Preference order: the variable's existing groundings, then the nodes
    indexed under label ``l`` in ``predicate_map``, then every node.
    """
    if clause_var_1 in groundings:
        candidates = groundings[clause_var_1]
    elif l in predicate_map:
        candidates = predicate_map[l]
    else:
        candidates = nodes
    return candidates
@numba.njit(cache=True)
def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges):
    """Return candidate edges for an edge clause ``predicate(Y, Z)``.

    ``Y`` is ``clause_var_1`` (source) and ``Z`` is ``clause_var_2`` (target).
    The result depends on which variables already have groundings:

    1. Neither: every edge indexed under ``l`` in ``predicate_map``, or every
       edge if ``l`` is not indexed. The stored list itself is returned, not
       a copy.
    2. Only Z: incoming edges of each grounded Z, via ``reverse_neighbors``.
    3. Only Y: outgoing edges of each grounded Y, via ``neighbors``.
    4. Both: the stored ``groundings_edges[(Y, Z)]`` if present (not a
       copy); otherwise outgoing edges of each Y whose target is among Z's
       groundings.
    """
    source_bound = clause_var_1 in groundings
    target_bound = clause_var_2 in groundings
    edge_groundings = numba.typed.List.empty_list(edge_type)
    if not source_bound and not target_bound:
        # Case 1: fall back to the label index, or the full edge list
        if l in predicate_map:
            edge_groundings = predicate_map[l]
        else:
            edge_groundings = edges
    elif not source_bound:
        # Case 2: walk backwards from every known target
        for tgt in groundings[clause_var_2]:
            for src in reverse_neighbors[tgt]:
                edge_groundings.append((src, tgt))
    elif not target_bound:
        # Case 3: walk forwards from every known source
        for src in groundings[clause_var_1]:
            for tgt in neighbors[src]:
                edge_groundings.append((src, tgt))
    elif (clause_var_1, clause_var_2) in groundings_edges:
        # Case 4a: the pair was already joined by an earlier edge clause
        edge_groundings = groundings_edges[(clause_var_1, clause_var_2)]
    else:
        # Case 4b: both bound separately; keep only edges linking the two sets
        allowed_targets = set(groundings[clause_var_2])
        for src in groundings[clause_var_1]:
            for tgt in neighbors[src]:
                if tgt in allowed_targets:
                    edge_groundings.append((src, tgt))
    return edge_groundings
@numba.njit(cache=True)
def get_qualified_node_groundings(interpretations_node, grounding, clause_l, clause_bnd, closed_world_predicates):
    """Return the nodes of ``grounding`` for which ``is_satisfied_node`` holds for ``(clause_l, clause_bnd)``.

    Input order is preserved.
    """
    required = (clause_l, clause_bnd)
    passing = numba.typed.List.empty_list(node_type)
    for candidate in grounding:
        if is_satisfied_node(interpretations_node, candidate, required, closed_world_predicates):
            passing.append(candidate)
    return passing
@numba.njit(cache=True)
def get_qualified_edge_groundings(interpretations_edge, grounding, clause_l, clause_bnd, closed_world_predicates):
    """Return the edges of ``grounding`` for which ``is_satisfied_edge`` holds for ``(clause_l, clause_bnd)``.

    Input order is preserved.
    """
    required = (clause_l, clause_bnd)
    passing = numba.typed.List.empty_list(edge_type)
    for candidate in grounding:
        if is_satisfied_edge(interpretations_edge, candidate, required, closed_world_predicates):
            passing.append(candidate)
    return passing
@numba.njit(cache=True)
def _compare_threshold_values(comparison, lhs, rhs):
    """Apply the comparison named by ``comparison`` to ``lhs`` and ``rhs``.

    Unknown comparison names yield False.
    """
    if comparison == 'greater_equal':
        return lhs >= rhs
    if comparison == 'greater':
        return lhs > rhs
    if comparison == 'less_equal':
        return lhs <= rhs
    if comparison == 'less':
        return lhs < rhs
    if comparison == 'equal':
        return lhs == rhs
    return False


@numba.njit(cache=True)
def _satisfies_threshold(num_neigh, num_qualified_component, threshold):
    """Check whether one clause's qualified count satisfies its threshold.

    ``threshold`` is read as ``(comparison, (quantifier, quantifier_type), value)``.
    ``comparison`` is one of ``'greater_equal'``, ``'greater'``,
    ``'less_equal'``, ``'less'`` or ``'equal'``.

    - ``quantifier == 'number'``: ``num_qualified_component`` is compared
      with ``value``.
    - ``quantifier == 'percent'``: the fraction
      ``num_qualified_component / num_neigh`` is compared with ``value``
      percent. It is always False when ``num_neigh == 0``.

    The percent case cross-multiplies (``100 * qualified`` vs
    ``value * num_neigh``) instead of comparing ``qualified / num_neigh``
    with ``value * 0.01``. The scaled form is not exact in binary floating
    point: ``70 * 0.01 == 0.7000000000000001`` while ``7 / 10 == 0.7``. That
    made exactly-70% fail ``greater_equal``/``equal`` and pass ``less``.
    Both sides of the cross-multiplied form are exact for integral
    percentages.

    Unknown quantifier or comparison strings yield False. Previously they
    left the result unassigned.
    """
    comparison = threshold[0]
    quantifier = threshold[1][0]
    value = threshold[2]
    result = False
    if quantifier == 'number':
        result = _compare_threshold_values(comparison, num_qualified_component, value)
    elif quantifier == 'percent':
        # num_neigh is a length (>= 0); 0 has no meaningful fraction.
        if num_neigh != 0:
            result = _compare_threshold_values(comparison, 100 * num_qualified_component, value * num_neigh)
    return result
@numba.njit(cache=True)
def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, num_ga, mode, override=False):
    """Apply annotation ``na = (label, bound)`` to node ``comp`` and propagate it.

    Steps, in order:
      1. If ``label`` is not yet in the node's world, it is added with bound
         [0, 1], ``num_ga[t_cnt]`` is incremented and ``comp`` is appended to
         ``predicate_map[label]``.
      2. The bound is overwritten (``override=True``) or merged through
         ``World.update`` (``override=False``), and its static flag is set to
         ``static``.
      3. If the bound changed, a 9-field record is appended to ``rule_trace``.
         This requires ``store_interpretation_changes``; graph-attribute
         facts are recorded only when ``save_graph_attributes_to_rule_trace``
         is set. When ``atom_trace`` is set, the previous bound is also
         logged to ``rule_trace_atoms``.
      4. If the bound changed, for every ``(p1, p2)`` in ``ipl`` involving
         ``label``, the other label's bound is tightened to
         ``[max(own.lower, 1 - label.upper), min(own.upper, 1 - label.lower)]``.
         Its static flag is set and, if storing changes, it is recorded as
         ``'IPL'``.
      5. A convergence ``change`` value is computed.

    Args:
        interpretations: mapping from node to its world object (exposes a
            ``.world`` label->interval dict).
        predicate_map: mapping label -> list of nodes carrying that label;
            extended whenever a label is newly added to ``comp``'s world.
        comp: node being updated.
        na: ``(label, bound)`` pair to apply.
        ipl: iterable of ``(p1, p2)`` label pairs updated as complements.
        rule_trace: receives change records when storing changes.
        fp_cnt, t_cnt: fixed-point / timestep counters written (as uint16)
            into records; ``t_cnt`` also indexes ``num_ga``.
        static: static flag applied to every bound touched here.
        convergence_mode: ``'delta_bound'`` selects the bound-delta measure;
            any other value selects the update-count measure.
        atom_trace: when True, prior bounds are logged to ``rule_trace_atoms``.
        save_graph_attributes_to_rule_trace: see step 3.
        rules_to_be_applied_trace, facts_to_be_applied_trace, idx: source of
            the trace name (and, for rules, qualified nodes/edges).
        rule_trace_atoms: atom-trace list written via ``_update_rule_trace``.
        store_interpretation_changes: gate for writing to ``rule_trace``.
        num_ga: per-timestep ground-atom counter.
        mode: ``'fact'``, ``'graph-attribute-fact'`` or ``'rule'``.
        override: write the bound directly instead of via ``World.update``.

    Returns:
        ``(updated, change)``. ``updated`` is True when the label's bound
        differs from its value before this call. ``change`` is 0 unless the
        label was updated and its bound also differs from its
        ``prev_lower``/``prev_upper``. In that case it is:
          - for ``'delta_bound'``, the largest absolute lower/upper move
            (against each interval's own prev values) over the label and the
            IPL complements touched;
          - otherwise, ``1 + number of IPL complements touched``.
        Any exception raised inside the body is swallowed and
        ``(False, 0)`` is returned.
    """
    updated = False
    # This is to prevent a key error in case the label is a specific label
    try:
        # NOTE: this local `world` shadows the module-level `world` import
        # for the rest of the function (the module is not used here).
        world = interpretations[comp]
        l, bnd = na
        # Every interval touched by this update, for the delta_bound measure
        updated_bnds = numba.typed.List.empty_list(interval.interval_type)

        # Add label to world if it is not there
        if l not in world.world:
            world.world[l] = interval.closed(0, 1)
            num_ga[t_cnt] += 1
            if l in predicate_map:
                predicate_map[l].append(comp)
            else:
                predicate_map[l] = numba.typed.List([comp])

        # Check if update is necessary with previous bnd
        # (copy, because the interval object is mutated in place below)
        prev_bnd = world.world[l].copy()

        # override will not check for inconsistencies
        if override:
            world.world[l].set_lower_upper(bnd.lower, bnd.upper)
        else:
            world.update(l, bnd)
        world.world[l].set_static(static)
        if world.world[l]!=prev_bnd:
            updated = True
            updated_bnds.append(world.world[l])

            # Add to rule trace if update happened and add to atom trace if necessary
            if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes:
                # Determine triggered_by and meta_name for the trace tuple
                if mode == 'fact' or mode == 'graph-attribute-fact':
                    triggered_by = 'Fact'
                else:
                    triggered_by = 'Rule'
                # The fact/rule name is only looked up when atom tracing is on
                meta_name = ''
                if atom_trace:
                    if mode=='fact' or mode=='graph-attribute-fact':
                        meta_name = facts_to_be_applied_trace[idx]
                    elif mode=='rule':
                        meta_name = rules_to_be_applied_trace[idx][2]
                # Record: (t, fp, comp, label, new bound, consistent=True, triggered_by, name, message)
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy(), True, triggered_by, meta_name, ''))
                if atom_trace:
                    # Facts log empty qualified lists; rules log their qualified nodes/edges.
                    # The bound recorded is the one *before* this update.
                    if mode=='fact' or mode=='graph-attribute-fact':
                        qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
                        qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
                        name = facts_to_be_applied_trace[idx]
                        _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name)
                    elif mode=='rule':
                        qn, qe, name = rules_to_be_applied_trace[idx]
                        _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name)

        # Update complement of predicate (if exists) based on new knowledge of predicate
        if updated:
            ip_update_cnt = 0
            for p1, p2 in ipl:
                if p1 == l:
                    # NOTE(review): unlike the branch for `l` above, adding the
                    # complement label here does not increment num_ga — confirm
                    # whether IPL-created atoms should be counted.
                    if p2 not in world.world:
                        world.world[p2] = interval.closed(0, 1)
                        if p2 in predicate_map:
                            predicate_map[p2].append(comp)
                        else:
                            predicate_map[p2] = numba.typed.List([comp])
                    if atom_trace:
                        # Log the complement's bound before it is tightened
                        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}')
                    # Complement bound: p2 in [1 - p1.upper, 1 - p1.lower], intersected with p2's bound
                    lower = max(world.world[p2].lower, 1 - world.world[p1].upper)
                    upper = min(world.world[p2].upper, 1 - world.world[p1].lower)
                    world.world[p2].set_lower_upper(lower, upper)
                    world.world[p2].set_static(static)
                    ip_update_cnt += 1
                    updated_bnds.append(world.world[p2])
                    if store_interpretation_changes:
                        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper), True, 'IPL', f'IPL: {l.get_value()}', ''))
                if p2 == l:
                    # Mirror of the branch above with the roles of p1/p2 swapped
                    if p1 not in world.world:
                        world.world[p1] = interval.closed(0, 1)
                        if p1 in predicate_map:
                            predicate_map[p1].append(comp)
                        else:
                            predicate_map[p1] = numba.typed.List([comp])
                    if atom_trace:
                        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}')
                    lower = max(world.world[p1].lower, 1 - world.world[p2].upper)
                    upper = min(world.world[p1].upper, 1 - world.world[p2].lower)
                    world.world[p1].set_lower_upper(lower, upper)
                    world.world[p1].set_static(static)
                    ip_update_cnt += 1
                    updated_bnds.append(world.world[p1])
                    if store_interpretation_changes:
                        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper), True, 'IPL', f'IPL: {l.get_value()}', ''))

        # Gather convergence data
        change = 0
        if updated:
            # Find out if it has changed from previous interp
            current_bnd = world.world[l]
            prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper)
            if current_bnd != prev_t_bnd:
                if convergence_mode=='delta_bound':
                    for i in updated_bnds:
                        # Use each bound's own previous values instead of L's previous
                        prev_i_bnd = interval.closed(i.prev_lower, i.prev_upper)
                        lower_delta = abs(i.lower - prev_i_bnd.lower)
                        upper_delta = abs(i.upper - prev_i_bnd.upper)
                        max_delta = max(lower_delta, upper_delta)
                        change = max(change, max_delta)
                else:
                    # Count-based measure: this atom plus every complement touched
                    change = 1 + ip_update_cnt

        return (updated, change)
    except Exception:
        # Deliberate best-effort: any failure is reported as "no update"
        return (False, 0)
@numba.njit(cache=True)
def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, num_ga, mode, override=False):
    """Apply annotation ``na = (label, bound)`` to edge ``comp`` and propagate it.

    Edge counterpart of ``_update_node``; the statements are the same.
    Steps, in order:
      1. If ``label`` is not yet in the edge's world, it is added with bound
         [0, 1], ``num_ga[t_cnt]`` is incremented and ``comp`` is appended to
         ``predicate_map[label]``.
      2. The bound is overwritten (``override=True``) or merged through
         ``World.update`` (``override=False``), and its static flag is set to
         ``static``.
      3. If the bound changed, a 9-field record is appended to ``rule_trace``.
         This requires ``store_interpretation_changes``; graph-attribute
         facts are recorded only when ``save_graph_attributes_to_rule_trace``
         is set. When ``atom_trace`` is set, the previous bound is also
         logged to ``rule_trace_atoms``.
      4. If the bound changed, for every ``(p1, p2)`` in ``ipl`` involving
         ``label``, the other label's bound is tightened to
         ``[max(own.lower, 1 - label.upper), min(own.upper, 1 - label.lower)]``.
         Its static flag is set and, if storing changes, it is recorded as
         ``'IPL'``.
      5. A convergence ``change`` value is computed.

    Args:
        interpretations: mapping from edge to its world object (exposes a
            ``.world`` label->interval dict).
        predicate_map: mapping label -> list of edges carrying that label;
            extended whenever a label is newly added to ``comp``'s world.
        comp: edge being updated.
        na: ``(label, bound)`` pair to apply.
        ipl: iterable of ``(p1, p2)`` label pairs updated as complements.
        rule_trace: receives change records when storing changes.
        fp_cnt, t_cnt: fixed-point / timestep counters written (as uint16)
            into records; ``t_cnt`` also indexes ``num_ga``.
        static: static flag applied to every bound touched here.
        convergence_mode: ``'delta_bound'`` selects the bound-delta measure;
            any other value selects the update-count measure.
        atom_trace: when True, prior bounds are logged to ``rule_trace_atoms``.
        save_graph_attributes_to_rule_trace: see step 3.
        rules_to_be_applied_trace, facts_to_be_applied_trace, idx: source of
            the trace name (and, for rules, qualified nodes/edges).
        rule_trace_atoms: atom-trace list written via ``_update_rule_trace``.
        store_interpretation_changes: gate for writing to ``rule_trace``.
        num_ga: per-timestep ground-atom counter.
        mode: ``'fact'``, ``'graph-attribute-fact'`` or ``'rule'``.
        override: write the bound directly instead of via ``World.update``.

    Returns:
        ``(updated, change)``, with the same meaning as in ``_update_node``.
        Any exception raised inside the body is swallowed and ``(False, 0)``
        is returned.
    """
    updated = False
    # This is to prevent a key error in case the label is a specific label
    try:
        # NOTE: this local `world` shadows the module-level `world` import
        # for the rest of the function (the module is not used here).
        world = interpretations[comp]
        l, bnd = na
        # Every interval touched by this update, for the delta_bound measure
        updated_bnds = numba.typed.List.empty_list(interval.interval_type)

        # Add label to world if it is not there
        if l not in world.world:
            world.world[l] = interval.closed(0, 1)
            num_ga[t_cnt] += 1
            if l in predicate_map:
                predicate_map[l].append(comp)
            else:
                predicate_map[l] = numba.typed.List([comp])

        # Check if update is necessary with previous bnd
        # (copy, because the interval object is mutated in place below)
        prev_bnd = world.world[l].copy()

        # override will not check for inconsistencies
        if override:
            world.world[l].set_lower_upper(bnd.lower, bnd.upper)
        else:
            world.update(l, bnd)
        world.world[l].set_static(static)
        if world.world[l]!=prev_bnd:
            updated = True
            updated_bnds.append(world.world[l])

            # Add to rule trace if update happened and add to atom trace if necessary
            if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes:
                # Determine triggered_by and meta_name for the trace tuple
                if mode == 'fact' or mode == 'graph-attribute-fact':
                    triggered_by = 'Fact'
                else:
                    triggered_by = 'Rule'
                # The fact/rule name is only looked up when atom tracing is on
                meta_name = ''
                if atom_trace:
                    if mode=='fact' or mode=='graph-attribute-fact':
                        meta_name = facts_to_be_applied_trace[idx]
                    elif mode=='rule':
                        meta_name = rules_to_be_applied_trace[idx][2]
                # Record: (t, fp, comp, label, new bound, consistent=True, triggered_by, name, message)
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy(), True, triggered_by, meta_name, ''))
                if atom_trace:
                    # Facts log empty qualified lists; rules log their qualified nodes/edges.
                    # The bound recorded is the one *before* this update.
                    if mode=='fact' or mode=='graph-attribute-fact':
                        qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
                        qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
                        name = facts_to_be_applied_trace[idx]
                        _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name)
                    elif mode=='rule':
                        qn, qe, name = rules_to_be_applied_trace[idx]
                        _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name)

        # Update complement of predicate (if exists) based on new knowledge of predicate
        if updated:
            ip_update_cnt = 0
            for p1, p2 in ipl:
                if p1 == l:
                    # NOTE(review): unlike the branch for `l` above, adding the
                    # complement label here does not increment num_ga — confirm
                    # whether IPL-created atoms should be counted.
                    if p2 not in world.world:
                        world.world[p2] = interval.closed(0, 1)
                        if p2 in predicate_map:
                            predicate_map[p2].append(comp)
                        else:
                            predicate_map[p2] = numba.typed.List([comp])
                    if atom_trace:
                        # Log the complement's bound before it is tightened
                        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}')
                    # Complement bound: p2 in [1 - p1.upper, 1 - p1.lower], intersected with p2's bound
                    lower = max(world.world[p2].lower, 1 - world.world[p1].upper)
                    upper = min(world.world[p2].upper, 1 - world.world[p1].lower)
                    world.world[p2].set_lower_upper(lower, upper)
                    world.world[p2].set_static(static)
                    ip_update_cnt += 1
                    updated_bnds.append(world.world[p2])
                    if store_interpretation_changes:
                        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper), True, 'IPL', f'IPL: {l.get_value()}', ''))
                if p2 == l:
                    # Mirror of the branch above with the roles of p1/p2 swapped
                    if p1 not in world.world:
                        world.world[p1] = interval.closed(0, 1)
                        if p1 in predicate_map:
                            predicate_map[p1].append(comp)
                        else:
                            predicate_map[p1] = numba.typed.List([comp])
                    if atom_trace:
                        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}')
                    lower = max(world.world[p1].lower, 1 - world.world[p2].upper)
                    upper = min(world.world[p1].upper, 1 - world.world[p2].lower)
                    world.world[p1].set_lower_upper(lower, upper)
                    world.world[p1].set_static(static)
                    ip_update_cnt += 1
                    updated_bnds.append(world.world[p1])
                    if store_interpretation_changes:
                        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper), True, 'IPL', f'IPL: {l.get_value()}', ''))

        # Gather convergence data
        change = 0
        if updated:
            # Find out if it has changed from previous interp
            current_bnd = world.world[l]
            prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper)
            if current_bnd != prev_t_bnd:
                if convergence_mode=='delta_bound':
                    for i in updated_bnds:
                        # Use each bound's own previous values instead of L's previous
                        prev_i_bnd = interval.closed(i.prev_lower, i.prev_upper)
                        lower_delta = abs(i.lower - prev_i_bnd.lower)
                        upper_delta = abs(i.upper - prev_i_bnd.upper)
                        max_delta = max(lower_delta, upper_delta)
                        change = max(change, max_delta)
                else:
                    # Count-based measure: this atom plus every complement touched
                    change = 1 + ip_update_cnt

        return (updated, change)
    except Exception:
        # Deliberate best-effort: any failure is reported as "no update"
        return (False, 0)
@numba.njit(cache=True)
def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name):
    """Append the atom-trace record ``(qn, qe, copy of prev_bnd, name)`` to ``rule_trace``.

    The bound is copied because callers go on to mutate the live interval in
    place (e.g. via ``set_lower_upper``) after logging it.
    """
    snapshot = prev_bnd.copy()
    record = (qn, qe, snapshot, name)
    rule_trace.append(record)
@numba.njit(cache=True)
def are_satisfied_node(interpretations, comp, nas, closed_world_predicates):
    """Return True if node ``comp`` satisfies every ``(label, bound)`` pair in ``nas``.

    Checking stops at the first unsatisfied pair. An empty ``nas`` is
    vacuously satisfied.
    """
    for (l, bnd) in nas:
        if not is_satisfied_node(interpretations, comp, (l, bnd), closed_world_predicates):
            return False
    return True
@numba.njit(cache=True)
def is_satisfied_node(interpretations, comp, na, closed_world_predicates):
    """Return whether node ``comp`` satisfies the annotation ``na = (label, bound)``.

    - A ``None`` label or bound is treated as trivially satisfied.
    - For labels in ``closed_world_predicates``, a label missing from the
      node's world, or one still at the uninformative [0, 1], is read as
      [0, 0]. The check is then whether [0, 0] lies in ``bound``.
    - Otherwise the node's ``World.is_satisfied`` decides.

    Any exception (e.g. ``comp`` not present) yields False.
    """
    lbl = na[0]
    bnd = na[1]
    if lbl is None or bnd is None:
        return True
    try:
        node_world = interpretations[comp]
        if lbl in closed_world_predicates:
            if lbl not in node_world.world:
                return interval.closed(0, 0) in bnd
            current = node_world.world[lbl]
            if current.lower == 0.0 and current.upper == 1.0:
                return interval.closed(0, 0) in bnd
        return node_world.is_satisfied(lbl, bnd)
    except Exception:
        return False
@numba.njit(cache=True)
def is_satisfied_node_comparison(interpretations, comp, na):
    """Find the numeric-suffixed variant of label ``na[0]`` on node ``comp``.

    Scans the node's world for a label whose string starts with ``l`` and
    whose text after ``len(l) + 1`` characters is numeric once ``.`` and
    ``-`` are removed. The character right after ``l`` is skipped as a
    separator (presumably ``_`` or similar — it is not checked). The first
    such label decides the result.

    The match used to be a substring test (``l in label``), which disagreed
    with the prefix-based suffix slice. For ``l='a'`` the label ``'ba12'``
    matched and yielded 12. It is now a prefix test.

    Returns:
        ``(satisfied, number)``. ``satisfied`` is ``World.is_satisfied`` for
        the matched label and ``na[1]``, or False if nothing matched or an
        exception occurred. ``number`` is the parsed suffix, or 0 when
        nothing matched. A ``None`` label or bound yields ``(True, 0)``.
    """
    result = False
    number = 0
    l, bnd = na
    if not (l is None or bnd is None):
        # Dereference only after the None guard
        l_str = l.value
        suffix_start = len(l_str) + 1
        # This is to prevent a key error in case the label is a specific label
        try:
            comp_world = interpretations[comp]
            for world_l in comp_world.world.keys():
                world_l_str = world_l.value
                if not world_l_str.startswith(l_str):
                    continue
                suffix = world_l_str[suffix_start:]
                if suffix.replace('.', '').replace('-', '').isdigit():
                    # The label is contained in the world
                    result = comp_world.is_satisfied(world_l, bnd)
                    # Find the suffix number
                    number = str_to_float(suffix)
                    break
        except Exception:
            result = False
    else:
        result = True
    return result, number
@numba.njit(cache=True)
def are_satisfied_edge(interpretations, comp, nas, closed_world_predicates):
    """Return True if edge ``comp`` satisfies every ``(label, bound)`` pair in ``nas``.

    Checking stops at the first unsatisfied pair. An empty ``nas`` is
    vacuously satisfied.
    """
    for (l, bnd) in nas:
        if not is_satisfied_edge(interpretations, comp, (l, bnd), closed_world_predicates):
            return False
    return True
@numba.njit(cache=True)
def is_satisfied_edge(interpretations, comp, na, closed_world_predicates):
    """Return whether edge ``comp`` satisfies the annotation ``na = (label, bound)``.

    - A ``None`` label or bound is treated as trivially satisfied.
    - For labels in ``closed_world_predicates``, a label missing from the
      edge's world, or one still at the uninformative [0, 1], is read as
      [0, 0]. The check is then whether [0, 0] lies in ``bound``.
    - Otherwise the edge's ``World.is_satisfied`` decides.

    Any exception (e.g. ``comp`` not present) yields False.
    """
    lbl = na[0]
    bnd = na[1]
    if lbl is None or bnd is None:
        return True
    try:
        edge_world = interpretations[comp]
        if lbl in closed_world_predicates:
            if lbl not in edge_world.world:
                return interval.closed(0, 0) in bnd
            current = edge_world.world[lbl]
            if current.lower == 0.0 and current.upper == 1.0:
                return interval.closed(0, 0) in bnd
        return edge_world.is_satisfied(lbl, bnd)
    except Exception:
        return False
@numba.njit(cache=True)
def is_satisfied_edge_comparison(interpretations, comp, na):
    """Find the numeric-suffixed variant of label ``na[0]`` on edge ``comp``.

    Scans the edge's world for a label whose string starts with ``l`` and
    whose text after ``len(l) + 1`` characters is numeric once ``.`` and
    ``-`` are removed. The character right after ``l`` is skipped as a
    separator (presumably ``_`` or similar — it is not checked). The first
    such label decides the result.

    The match used to be a substring test (``l in label``), which disagreed
    with the prefix-based suffix slice. For ``l='a'`` the label ``'ba12'``
    matched and yielded 12. It is now a prefix test.

    Returns:
        ``(satisfied, number)``. ``satisfied`` is ``World.is_satisfied`` for
        the matched label and ``na[1]``, or False if nothing matched or an
        exception occurred. ``number`` is the parsed suffix, or 0 when
        nothing matched. A ``None`` label or bound yields ``(True, 0)``.
    """
    result = False
    number = 0
    l, bnd = na
    if not (l is None or bnd is None):
        # Dereference only after the None guard
        l_str = l.value
        suffix_start = len(l_str) + 1
        # This is to prevent a key error in case the label is a specific label
        try:
            comp_world = interpretations[comp]
            for world_l in comp_world.world.keys():
                world_l_str = world_l.value
                if not world_l_str.startswith(l_str):
                    continue
                suffix = world_l_str[suffix_start:]
                if suffix.replace('.', '').replace('-', '').isdigit():
                    # The label is contained in the world
                    result = comp_world.is_satisfied(world_l, bnd)
                    # Find the suffix number
                    number = str_to_float(suffix)
                    break
        except Exception:
            result = False
    else:
        result = True
    return result, number
@numba.njit(cache=True)
def annotate(annotation_functions, rule, annotations, qualified_nodes, qualified_edges, clause_labels, clause_variables, weights):
    """Compute the head bound ``(lower, upper)`` for a fired rule.

    When ``rule.get_annotation_function()`` is the empty string, the rule's
    own bound (``rule.get_bnd()``) is returned as-is. Otherwise, each entry
    of ``annotation_functions`` whose ``__name__`` equals that string is
    called inside ``numba.objmode``, and the result of the last such call is
    returned. The call shape is picked from the positional-argument count of
    the underlying Python function (``py_func`` for jitted callables):

    - exactly 6 arguments (extended)::

        fn(annotations, weights,
           qualified_nodes, qualified_edges,
           clause_labels, clause_variables) -> Tuple[float, float]

    - anything else is called in the 2-argument legacy form::

        fn(annotations, weights) -> Tuple[float, float]

    Arity is expected to be validated when the function is registered
    (``add_annotation_function``). Raising from inside ``numba.objmode`` is
    avoided, so no check is made here.
    NOTE(review): if no registered name matches, ``annotation`` is never
    assigned inside the objmode block — confirm registration guarantees a
    match.

    Per-clause alignment for the extended form: the five lists share the
    clause index ``i``::

        annotations[i]       # List[Interval] — bounds of qualifying atoms
        qualified_nodes[i]   # List[Node]     — nodes that qualified
        qualified_edges[i]   # List[Edge]     — edges that qualified
        clause_labels[i]     # Label          — predicate of this clause
        clause_variables[i]  # List[str]      — variable names of this clause

    For a node clause, ``qualified_edges[i]`` is empty and
    ``clause_variables[i]`` holds one name. For an edge clause,
    ``qualified_nodes[i]`` is empty and it holds two.

    Comparison clauses are not collected, so ``len(clause_labels)`` can be
    smaller than ``len(rule.get_clauses())``. Identify clauses by predicate
    name and variable role rather than by position in the rule body; this
    also keeps user code robust to ``reorder_clauses``.
    """
    func_name = rule.get_annotation_function()
    if func_name == '':
        return rule.get_bnd().lower, rule.get_bnd().upper

    with numba.objmode(annotation='Tuple((float64, float64))'):
        for candidate in annotation_functions:
            if candidate.__name__ != func_name:
                continue
            # Jitted callables expose the original Python function as py_func
            target = getattr(candidate, 'py_func', candidate)
            if target.__code__.co_argcount == 6:
                annotation = candidate(annotations, weights, qualified_nodes, qualified_edges, clause_labels, clause_variables)
            else:
                annotation = candidate(annotations, weights)
    return annotation
@numba.njit(cache=True)
def check_consistent_node(interpretations, comp, na):
    """Return whether bound ``na[1]`` overlaps node ``comp``'s current bound for label ``na[0]``.

    A label absent from the node's world is treated as having bound [0, 1].
    The check is False exactly when the two intervals are disjoint.
    """
    labels = interpretations[comp].world
    lbl = na[0]
    proposed = na[1]
    current = labels[lbl] if lbl in labels else interval.closed(0, 1)
    disjoint = (proposed.lower > current.upper) or (current.lower > proposed.upper)
    return not disjoint
@numba.njit(cache=True)
def check_consistent_edge(interpretations, comp, na):
    """Return whether bound ``na[1]`` overlaps edge ``comp``'s current bound for label ``na[0]``.

    A label absent from the edge's world is treated as having bound [0, 1].
    The check is False exactly when the two intervals are disjoint.
    """
    labels = interpretations[comp].world
    lbl = na[0]
    proposed = na[1]
    current = labels[lbl] if lbl in labels else interval.closed(0, 1)
    disjoint = (proposed.lower > current.upper) or (current.lower > proposed.upper)
    return not disjoint
@numba.njit(cache=True)
[docs]
def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode):
    """Reset node ``comp``'s bound for ``na[0]``, and its IPL complements, to static [0, 1].

    Presumably called after ``check_consistent_node`` reports that ``na[1]``
    conflicts with the current bound (confirm at the caller).

    Trace side effects:
      - With ``store_interpretation_changes``, a record
        ``(t, fp, comp, label, [0,1], False, triggered_by, name, msg)`` is
        appended to ``rule_trace`` for ``na[0]``, plus one ``'IPL'`` record
        per complement reset. ``triggered_by`` is ``'Fact'`` for
        fact/graph-attribute-fact modes, else ``'Rule'``.
      - With ``atom_trace``, the pre-reset bound of ``na[0]`` and of each
        complement is logged to ``rule_trace_atoms``. In rule mode this
        includes the rule's qualified nodes/edges.
      - The fact/rule name and the descriptive ``msg`` are only built when
        ``atom_trace`` is True; otherwise both are ``''``.

    Assumes ``na[0]`` and every IPL complement of it are already present in
    the node's world; ``world.world[...]`` is indexed without a guard.
    """
    # NOTE: this local `world` shadows the module-level `world` import.
    world = interpretations[comp]
    # Determine triggered_by and actual_name
    if mode == 'fact' or mode == 'graph-attribute-fact':
        triggered_by = 'Fact'
    else:
        triggered_by = 'Rule'
    actual_name = ''
    if atom_trace:
        if mode == 'fact' or mode == 'graph-attribute-fact':
            actual_name = facts_to_be_applied_trace[idx]
        elif mode == 'rule':
            actual_name = rules_to_be_applied_trace[idx][2]

    # Build descriptive inconsistency message
    msg = ''
    if atom_trace:
        # If na[0] belongs to an IPL pair, name its complement in the message
        comp_label_value = ''
        for _p1, _p2 in ipl:
            if _p1 == na[0]:
                comp_label_value = _p2.get_value()
                break
            if _p2 == na[0]:
                comp_label_value = _p1.get_value()
                break
        if comp_label_value != '':
            msg = f'Inconsistency occurred. Grounding {na[0].get_value()}({comp}) conflicts with grounding {comp_label_value}({comp}). Setting bounds to [0,1] and static=True for this timestep.'
        else:
            msg = f'Inconsistency occurred. Conflicting bounds for {na[0].get_value()}({comp}). Update from [{float_to_str(world.world[na[0]].lower)}, {float_to_str(world.world[na[0]].upper)}] to [{float_to_str(na[1].lower)}, {float_to_str(na[1].upper)}] is not allowed. Setting bounds to [0,1] and static=True for this timestep.'

    if store_interpretation_changes:
        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1), False, triggered_by, actual_name, msg))
    if atom_trace:
        # Rules log their qualified groundings; facts log empty lists
        if mode == 'rule':
            qn, qe, _ = rules_to_be_applied_trace[idx]
        else:
            qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
            qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
        # Logged before the reset below, so the pre-reset bound is recorded
        _update_rule_trace(rule_trace_atoms, qn, qe, world.world[na[0]], actual_name)

    # Resolve inconsistency and set static
    world.world[na[0]].set_lower_upper(0, 1)
    world.world[na[0]].set_static(True)
    # Reset every IPL complement of na[0] the same way
    for p1, p2 in ipl:
        if p1==na[0]:
            if atom_trace:
                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], actual_name)
            world.world[p2].set_lower_upper(0, 1)
            world.world[p2].set_static(True)
            if store_interpretation_changes:
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(0,1), False, 'IPL', actual_name, msg))

        if p2==na[0]:
            if atom_trace:
                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], actual_name)
            world.world[p1].set_lower_upper(0, 1)
            world.world[p1].set_static(True)
            if store_interpretation_changes:
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1), False, 'IPL', actual_name, msg))
@numba.njit(cache=True)
[docs]
def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode):
    """Reset edge ``comp``'s bound for ``na[0]``, and its IPL complements, to static [0, 1].

    Edge counterpart of ``resolve_inconsistency_node``. The only difference
    is that messages render the edge as ``(source,target)``. Presumably
    called after ``check_consistent_edge`` reports a conflict (confirm at
    the caller).

    Trace side effects:
      - With ``store_interpretation_changes``, a record
        ``(t, fp, comp, label, [0,1], False, triggered_by, name, msg)`` is
        appended to ``rule_trace`` for ``na[0]``, plus one ``'IPL'`` record
        per complement reset.
      - With ``atom_trace``, the pre-reset bound of ``na[0]`` and of each
        complement is logged to ``rule_trace_atoms``.
      - The fact/rule name and ``msg`` are only built when ``atom_trace`` is
        True; otherwise both are ``''``.

    Assumes ``na[0]`` and every IPL complement of it are already present in
    the edge's world; ``w.world[...]`` is indexed without a guard.
    """
    w = interpretations[comp]
    # Determine triggered_by and actual_name
    if mode == 'fact' or mode == 'graph-attribute-fact':
        triggered_by = 'Fact'
    else:
        triggered_by = 'Rule'
    actual_name = ''
    if atom_trace:
        if mode == 'fact' or mode == 'graph-attribute-fact':
            actual_name = facts_to_be_applied_trace[idx]
        elif mode == 'rule':
            actual_name = rules_to_be_applied_trace[idx][2]

    # Build descriptive inconsistency message
    msg = ''
    if atom_trace:
        # If na[0] belongs to an IPL pair, name its complement in the message
        comp_label_value = ''
        for _p1, _p2 in ipl:
            if _p1 == na[0]:
                comp_label_value = _p2.get_value()
                break
            if _p2 == na[0]:
                comp_label_value = _p1.get_value()
                break
        if comp_label_value != '':
            msg = f'Inconsistency occurred. Grounding {na[0].get_value()}({comp[0]},{comp[1]}) conflicts with grounding {comp_label_value}({comp[0]},{comp[1]}). Setting bounds to [0,1] and static=True for this timestep.'
        else:
            msg = f'Inconsistency occurred. Conflicting bounds for {na[0].get_value()}({comp[0]},{comp[1]}). Update from [{float_to_str(w.world[na[0]].lower)}, {float_to_str(w.world[na[0]].upper)}] to [{float_to_str(na[1].lower)}, {float_to_str(na[1].upper)}] is not allowed. Setting bounds to [0,1] and static=True for this timestep.'

    if store_interpretation_changes:
        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1), False, triggered_by, actual_name, msg))
    if atom_trace:
        # Rules log their qualified groundings; facts log empty lists
        if mode == 'rule':
            qn, qe, _ = rules_to_be_applied_trace[idx]
        else:
            qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
            qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
        # Logged before the reset below, so the pre-reset bound is recorded
        _update_rule_trace(rule_trace_atoms, qn, qe, w.world[na[0]], actual_name)

    # Resolve inconsistency and set static
    w.world[na[0]].set_lower_upper(0, 1)
    w.world[na[0]].set_static(True)
    # Reset every IPL complement of na[0] the same way
    for p1, p2 in ipl:
        if p1==na[0]:
            if atom_trace:
                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], actual_name)
            w.world[p2].set_lower_upper(0, 1)
            w.world[p2].set_static(True)
            if store_interpretation_changes:
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(0,1), False, 'IPL', actual_name, msg))

        if p2==na[0]:
            if atom_trace:
                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], actual_name)
            w.world[p1].set_lower_upper(0, 1)
            w.world[p1].set_static(True)
            if store_interpretation_changes:
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1), False, 'IPL', actual_name, msg))
@numba.njit(cache=True)
def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node):
    """Register ``node`` in the graph structures with no edges and no labels.

    Appends ``node`` to ``nodes`` and stores fresh, empty entries for it in
    ``neighbors`` (out-neighbours), ``reverse_neighbors`` (in-neighbours) and
    ``interpretations_node``.

    No duplicate check is done here. Calling this for a node that is already
    present would append it to ``nodes`` again and overwrite its adjacency
    lists and interpretation, so callers (e.g. ``_add_edge``) test membership
    first.
    """
    nodes.append(node)
    successors = numba.typed.List.empty_list(node_type)
    predecessors = numba.typed.List.empty_list(node_type)
    neighbors[node] = successors
    reverse_neighbors[node] = predecessors
    no_labels = numba.typed.List.empty_list(label.label_type)
    interpretations_node[node] = world.World(no_labels)
@numba.njit(cache=True)
def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t):
    """Ensure the directed edge ``(source, target)`` exists, optionally labelled ``l``.

    Endpoints that are not in ``nodes`` yet are created with ``_add_node``.

    A brand-new edge is appended to ``edges`` and to both adjacency maps. It
    receives an interpretation built from ``[l]``, or an empty one when
    ``l.value`` is ''. The initial bounds come from ``world.World``.

    For an edge that already exists, ``l`` is attached with bounds [0, 1]
    only if ``l.value`` is non-empty and the edge does not carry ``l`` yet.
    An existing label is never overwritten.

    Whenever ``l`` is attached, ``num_ga[t]`` is incremented and the edge is
    recorded under ``l`` in ``predicate_map``.

    Returns:
        ``(edge, changed)``. ``changed`` is True when the edge was created, or
        when an existing edge gained the label ``l``.
    """
    if source not in nodes:
        _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node)
    if target not in nodes:
        _add_node(target, neighbors, reverse_neighbors, nodes, interpretations_node)

    edge = (source, target)
    has_label = l.value != ''
    label_attached = False

    if edge in edges:
        # Existing edge: only attach a label it does not carry yet.
        if l not in interpretations_edge[edge].world and has_label:
            interpretations_edge[edge].world[l] = interval.closed(0, 1)
            label_attached = True
        changed = label_attached
    else:
        edges.append(edge)
        neighbors[source].append(target)
        reverse_neighbors[target].append(source)
        if has_label:
            interpretations_edge[edge] = world.World(numba.typed.List([l]))
            label_attached = True
        else:
            interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type))
        changed = True

    # Bookkeeping shared by both paths that attach the label.
    if label_attached:
        num_ga[t] += 1
        if l in predicate_map:
            predicate_map[l].append(edge)
        else:
            predicate_map[l] = numba.typed.List([edge])

    return edge, changed
@numba.njit(cache=True)
def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t):
    """Add an edge from every node in ``sources`` to every node in ``targets``.

    Each pair is handled by ``_add_edge`` with the same label ``l`` and time ``t``.

    Returns:
        ``(edges_added, changes)``.
        ``edges_added`` lists every pair visited, in source-major order,
        including pairs that already existed.
        ``changes`` counts the pairs for which ``_add_edge`` reported a change.
    """
    visited = numba.typed.List.empty_list(edge_type)
    num_changed = 0
    for src in sources:
        for dst in targets:
            e, is_new = _add_edge(src, dst, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t)
            visited.append(e)
            if is_new:
                num_changed += 1
    return visited, num_changed
@numba.njit(cache=True)
def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge, predicate_map, num_ga):
    """Remove ``edge`` and everything recorded about it from the graph structures.

    The function:

    - drops the edge from ``edges``;
    - subtracts its number of labels from ``num_ga[-1]``;
    - deletes its interpretation;
    - removes it from every ``predicate_map`` entry that lists it;
    - unlinks the endpoints in ``neighbors`` / ``reverse_neighbors``.

    The edge is assumed to be present. ``remove``/``del`` on a missing edge
    will raise.
    """
    src, dst = edge
    edges.remove(edge)
    num_ga[-1] -= len(interpretations_edge[edge].world)
    del interpretations_edge[edge]
    for grounded in predicate_map.values():
        if edge in grounded:
            grounded.remove(edge)
    neighbors[src].remove(dst)
    reverse_neighbors[dst].remove(src)
@numba.njit(cache=True)
def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node, predicate_map, num_ga):
    """Remove ``node`` and its bookkeeping from the graph structures.

    The function:

    - drops the node from ``nodes``;
    - subtracts its number of labels from ``num_ga[-1]``;
    - deletes its interpretation and its own adjacency entries;
    - removes it from every ``predicate_map`` list;
    - scrubs it out of every other node's adjacency lists.

    Edge structures are not touched here. NOTE(review): callers presumably
    delete incident edges separately; confirm.
    """
    nodes.remove(node)
    num_ga[-1] -= len(interpretations_node[node].world)
    del interpretations_node[node]
    del neighbors[node]
    del reverse_neighbors[node]
    for grounded in predicate_map.values():
        if node in grounded:
            grounded.remove(node)
    # Any remaining adjacency list may still reference the deleted node.
    for successors in neighbors.values():
        if node in successors:
            successors.remove(node)
    for predecessors in reverse_neighbors.values():
        if node in predecessors:
            predecessors.remove(node)
@numba.njit(cache=True)
[docs]
def float_to_str(value):
    """Format ``value`` with exactly three decimal places, e.g. 0.5 -> '0.500'.

    The value is rounded as a whole to thousandths before being split into
    integer and fractional parts. A fractional part that rounds up to 1.000
    therefore carries into the integer part (0.9999 -> '1.000').

    Negative values keep their sign even when they round to zero
    (-0.0001 -> '-0.000').
    """
    thousandths = int(round(abs(value) * 1000))
    whole = thousandths // 1000
    frac = thousandths % 1000
    # Manual zero-padding (numba may not support :03d in f-strings)
    if frac < 10:
        frac_str = f'00{frac}'
    elif frac < 100:
        frac_str = f'0{frac}'
    else:
        frac_str = f'{frac}'
    sign = '-' if value < 0 else ''
    return f'{sign}{whole}.{frac_str}'
@numba.njit(cache=True)
[docs]
def str_to_float(value):
    """Parse a plain decimal string such as ``'0.25'`` or ``'-3'`` into a float.

    The integer and fractional digits are joined into one integer string and
    parsed with ``str_to_int``. The result is then scaled by ``10**k``, where
    ``k`` is the number of digits after the decimal point. Sign and digit
    handling is delegated to ``str_to_int``. Exponent notation such as
    ``'1e-05'`` is not handled here.

    Raises:
        ValueError: if ``value`` contains more than one ``'.'``.
    """
    dot = value.find('.')
    if dot == -1:
        frac_len = 0
        digits = value
    else:
        fraction = value[dot + 1:]
        # A second '.' would otherwise be silently dropped ('1.2.3' -> 0.123).
        if fraction.find('.') != -1:
            raise ValueError('str_to_float: more than one decimal point')
        frac_len = len(fraction)
        digits = value[:dot] + fraction
    return str_to_int(digits) / 10 ** frac_len
@numba.njit(cache=True)
[docs]
def str_to_int(value):
    """Parse a base-10 integer string with an optional leading ``+``/``-`` sign.

    Hand-written rather than using ``int(value)``, presumably because the
    function is compiled with ``numba.njit``. Digits are accumulated left to
    right (``result = result * 10 + digit``).

    Raises:
        ValueError: if ``value`` is empty, has no digits after the sign, or
            contains any character other than ``0``-``9`` after the optional
            leading sign.
    """
    n = len(value)
    if n == 0:
        raise ValueError('str_to_int: empty string')
    first = value[0]
    negative = first == '-'
    # Only a single *leading* sign is accepted; '-' elsewhere is an error.
    start = 1 if (negative or first == '+') else 0
    if start == n:
        raise ValueError('str_to_int: no digits after sign')
    result = 0
    for i in range(start, n):
        digit = ord(value[i]) - 48  # 48 == ord('0')
        if digit < 0 or digit > 9:
            raise ValueError('str_to_int: invalid character')
        result = result * 10 + digit
    return -result if negative else result
@numba.njit(cache=True)
def _determine_node_head_vars(head_fns, head_fns_vars, groundings, head_functions):
    """
    Compute the head grounding of a node rule whose head variable is a function call.

    Only entry 0 of ``head_fns`` / ``head_fns_vars`` is consulted, because a
    node rule head has a single variable.

    Args:
        head_fns: Function name for each head variable ('' when it is not a function call)
        head_fns_vars: For each head variable, the names of the variables passed to its function
        groundings: Dictionary mapping variable names to their grounded node lists
        head_functions: Tuple of available head functions, looked up by name

    Returns:
        Tuple ``(head_groundings, is_func)``.

        When the head has a function with at least one argument:
        ``head_groundings`` is the node list returned by that function and
        ``is_func`` is True.

        Otherwise: ``head_groundings`` is an empty list and ``is_func`` is False.
        NOTE(review): the caller presumably falls back to the head variable's
        own grounding in that case; confirm.
    """
    head_groundings = numba.typed.List.empty_list(node_type)
    is_func = False
    fn_name = head_fns[0]
    fn_vars = head_fns_vars[0]

    if fn_name != '' and len(fn_vars) > 0:
        # Build one list of candidate nodes per function argument.
        fn_arg_values = numba.typed.List.empty_list(list_of_nodes)
        for fn_var in fn_vars:
            if fn_var in groundings:
                fn_arg_values.append(groundings[fn_var])
            else:
                # An argument that is not a grounded variable is passed through literally.
                fn_arg_values.append(numba.typed.List([fn_var]))
        head_groundings = _call_head_function(fn_name, fn_arg_values, head_functions)
        is_func = True

    return head_groundings, is_func
@numba.njit(cache=True)
def _determine_edge_head_vars(head_fns, head_fns_vars, groundings, head_functions):
    """
    Compute the source/target head groundings of an edge rule whose head uses functions.

    Entries 0 (source) and 1 (target) of ``head_fns`` / ``head_fns_vars`` are
    handled independently.

    Args:
        head_fns: Function name for each head variable ('' when it is not a function call)
        head_fns_vars: For each head variable, the names of the variables passed to its function
        groundings: Dictionary mapping variable names to their grounded node lists
        head_functions: Tuple of available head functions, looked up by name

    Returns:
        Tuple ``(head_groundings, is_func)``. ``head_groundings`` is a list of two
        node lists (source, target). ``is_func`` is a list of two booleans.

        For each position whose head has a function with at least one argument,
        the slot holds the nodes returned by that function and the flag is True.

        Otherwise the slot is an empty list and the flag is False.
        NOTE(review): the caller presumably falls back to the variable's own
        grounding in that case; confirm.
    """
    head_groundings = numba.typed.List.empty_list(list_of_nodes)
    head_groundings.append(numba.typed.List.empty_list(node_type))  # For source
    head_groundings.append(numba.typed.List.empty_list(node_type))  # For target
    is_func = numba.typed.List([False, False])

    # An edge head always has exactly two positions.
    for i in range(2):
        fn_name = head_fns[i]
        fn_vars = head_fns_vars[i]
        # Without a function (or without arguments) slot i stays empty and is_func[i] False.
        if fn_name != '' and len(fn_vars) > 0:
            # Build one list of candidate nodes per function argument.
            fn_arg_values = numba.typed.List.empty_list(list_of_nodes)
            for fn_var in fn_vars:
                if fn_var in groundings:
                    fn_arg_values.append(groundings[fn_var])
                else:
                    # An argument that is not a grounded variable is passed through literally.
                    fn_arg_values.append(numba.typed.List([fn_var]))
            head_groundings[i] = _call_head_function(fn_name, fn_arg_values, head_functions)
            is_func[i] = True

    return head_groundings, is_func
@numba.njit(cache=True)
def _call_head_function(fn_name, fn_arg_values, head_functions):
    """
    Call the head function named ``fn_name`` and return its result as a typed list.

    The lookup and call run in numba object mode, since head functions are
    Python callables matched by their ``__name__``. The first entry of
    ``head_functions`` whose ``__name__`` equals ``fn_name`` is called once
    with ``fn_arg_values``.

    Args:
        fn_name: Name of the function to call
        fn_arg_values: One list of node names per function argument
        head_functions: Tuple of available head functions

    Returns:
        ``numba.typed.List`` of node-name strings holding the items the function
        returned, in order. The function may return any iterable of strings
        (plain Python list or typed list). The items are copied into a fresh
        typed list so the value always matches the declared objmode output type.
        Returns an empty list when no function matches ``fn_name``.
    """
    with numba.objmode(func_result='types.ListType(types.unicode_type)'):
        # Build the output here (in Python) so it is always a typed list of str.
        collected = numba.typed.List.empty_list(numba.types.unicode_type)
        for func in head_functions:
            if hasattr(func, '__name__') and func.__name__ == fn_name:
                for node in func(fn_arg_values):
                    collected.append(node)
                break
        func_result = collected
    return func_result