import math
import re
import numba
import numpy as np
from typing import Union
import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule
# import pyreason.scripts.rules.rule_internal as rule
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
from pyreason.scripts.threshold.threshold import Threshold
_PREDICATE_RE = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_.\-]*$')
_COMPONENT_RE = re.compile(r'^[a-zA-Z0-9_][a-zA-Z0-9_.@\-]*$')
[docs]
def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, dict], infer_edges: bool = False, set_static: bool = False, weights: Union[None, np.ndarray] = None) -> rule.Rule:
# --- Group A: Entry-point validation ---
# V1: rule_text must be a string
if not isinstance(rule_text, str):
raise TypeError(f"rule_text must be a string, got {type(rule_text).__name__}")
# V2: rule_text cannot be empty or whitespace-only
if not rule_text.strip():
raise ValueError("rule_text cannot be empty or whitespace only")
# V3: Must contain exactly one '<-' separator
arrow_count = rule_text.count('<-')
if arrow_count != 1:
raise ValueError(
f"Rule must contain exactly one '<-' separator, found {arrow_count}. "
"Use the format: 'head(X) <- body(X)'"
)
# First remove all spaces from line
rule_str = rule_text.replace(' ', '')
# Separate into head and body
head, body = rule_str.split('<-')
# V4: Head cannot be empty after split
if not head:
raise ValueError("Rule head cannot be empty")
# Handle empty body (valid rule - always fires unconditionally)
if not body:
delta_t = 0
body_clauses = []
body_bounds = []
else:
# Extract delta_t of rule if it exists else set it to 0
delta_t = ''
is_digit = True
while is_digit:
if body[0].isdigit():
delta_t += body[0]
body = body[1:]
else:
is_digit = False
if delta_t == '':
delta_t = 0
else:
delta_t = int(delta_t)
# Split the body into clauses and their bounds
body_clauses, body_bounds = _split_body_into_clauses(body)
# Handle forall quantifier in body clauses
for i, clause_str in enumerate(body_clauses.copy()):
if 'forall(' in clause_str:
# V14: Validate forall syntax — must end with ')' (the outer forall paren)
if not clause_str.endswith(')'):
raise ValueError(f"Malformed forall expression: '{clause_str}'. Expected 'forall(pred(vars))'")
# Extract inner expression and validate it contains a predicate with variables
inner = clause_str[:-1].replace('forall(', '', 1)
if '(' not in inner or ')' not in inner:
raise ValueError(f"forall expression must contain an inner predicate with variables, got 'forall({inner})'")
if not custom_thresholds:
custom_thresholds = {}
custom_thresholds[i] = Threshold("greater_equal", ("percent", "total"), 100)
body_clauses[i] = inner
# Parse the head: target predicate, bound, annotation function, and head-negation flag
head, target_bound, ann_fn, head_negated = _parse_head(head)
idx = head.find('(')
target = head[:idx]
_validate_predicate_name(target, "Head")
target = label.Label(target)
# Variable(s) in the head of the rule - now supports functions like f(X, Y)
# Find the last ')' to handle nested function calls
end_idx = head.rfind(')')
head_args_str = head[idx + 1:end_idx]
# Parse head arguments which can be variables or function calls
head_variables, head_fns, head_fns_vars = _parse_head_arguments(head_args_str)
# Validate head has at least one variable
if len(head_variables) < 1:
raise ValueError("Rule head must contain at least one argument inside parentheses")
# Validate head variable names
for var in head_variables:
_validate_component_name(var, "Head")
# Assign type of rule
rule_type = 'node' if len(head_variables) == 1 else 'edge'
# Get the variables in the body
# If there's an operator in the body then discard anything that comes after the operator, but keep the variables
body_predicates = []
body_variables = []
for clause_str in body_clauses:
# V8: Body clause must contain parentheses
start_idx = clause_str.find('(')
if start_idx == -1:
raise ValueError(f"Body clause '{clause_str}' must contain parentheses around argument")
end_idx = clause_str.find(')')
if end_idx == -1:
raise ValueError(f"Body clause '{clause_str}' is missing closing parenthesis")
pred_name = clause_str[:start_idx]
_validate_predicate_name(pred_name, "Body")
body_predicates.append(pred_name)
# Add body variables
variables = clause_str[start_idx+1:end_idx].split(',')
start_idx = clause_str.find('(', start_idx+1)
end_idx = clause_str.find(')', end_idx+1)
if start_idx != -1 and end_idx != -1:
variables += clause_str[start_idx+1:end_idx].split(',')
# Validate body variable names
for var in variables:
_validate_component_name(var, "Body")
body_variables.append(variables)
# Change infer edge parameter if it's a node rule
if rule_type == 'node':
infer_edges = False
# Start setting up clauses
# clauses = [c1, c2, c3, c4]
# thresholds = [t1, t2, t3, t4]
# Array of thresholds to keep track of for each neighbor criterion. Form [(comparison, (number/percent, total/available), thresh)]
thresholds = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, numba.types.UniTuple(numba.types.string, 2), numba.types.float64)))
# Array to store clauses for nodes: node/edge, [subset]/[subset1, subset2], label, interval, operator
clauses = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string), interval.interval_type, numba.types.string)))
# gather count of clauses for threshold validation
num_clauses = len(body_clauses)
if isinstance(custom_thresholds, list):
if len(custom_thresholds) != num_clauses:
raise ValueError(f'The length of custom thresholds {len(custom_thresholds)} is not equal to number of clauses {num_clauses}')
for threshold in custom_thresholds:
thresholds.append(threshold.to_tuple())
elif isinstance(custom_thresholds, dict):
# V12: Empty dict is not allowed
if len(custom_thresholds) == 0:
raise ValueError("custom_thresholds dict cannot be empty. Use None for default thresholds")
if any(k < 0 for k in custom_thresholds.keys()):
raise ValueError("custom_thresholds dict keys must be non-negative integers")
if max(custom_thresholds.keys()) >= num_clauses:
raise ValueError(f'The max clause index in the custom thresholds map {max(custom_thresholds.keys())} is greater than number of clauses {num_clauses}')
for i in range(num_clauses):
if i in custom_thresholds:
thresholds.append(custom_thresholds[i].to_tuple())
else:
thresholds.append(('greater_equal', ('number', 'total'), 1.0))
# If no custom thresholds provided, use defaults
# otherwise loop through user-defined thresholds and convert to numba compatible format
elif not custom_thresholds:
for _ in range(num_clauses):
thresholds.append(('greater_equal', ('number', 'total'), 1.0))
# # Loop though clauses
for body_clause, predicate, variables, bounds in zip(body_clauses, body_predicates, body_variables, body_bounds):
# Neigh criteria
clause_type = 'node' if len(variables) == 1 else 'edge'
op = _get_operator_from_clause(body_clause)
if op:
clause_type = 'comparison'
subset = numba.typed.List(variables)
label_obj = label.Label(predicate)
bnd = interval.closed(bounds[0], bounds[1])
clauses.append((clause_type, label_obj, subset, bnd, op))
# Assert that there are two variables in the head of the rule if we infer edges
# Add edges between head variables if necessary
if infer_edges:
# var = '__target' if head_variables[0] == head_variables[1] else head_variables[1]
# edges = ('__target', var, target)
edges = (head_variables[0], head_variables[1], target)
else:
edges = ('', '', label.Label(''))
if weights is None:
weights = np.ones(len(body_predicates), dtype=np.float64)
else:
# V13: Validate weights array
# Check if it's array-like
if not isinstance(weights, np.ndarray):
try:
weights = np.array(weights, dtype=np.float64)
except (ValueError, TypeError):
raise TypeError(f"weights must be a numpy array or convertible to one, got {type(weights).__name__}")
# Check length matches number of clauses
if len(weights) != len(body_predicates):
raise ValueError(f'Number of weights {len(weights)} is not equal to number of clauses {len(body_predicates)}')
# Check all values are numeric and finite
if not np.issubdtype(weights.dtype, np.number):
raise TypeError(f"weights must contain numeric values, got dtype {weights.dtype}")
if not np.all(np.isfinite(weights)):
raise ValueError("weights must contain only finite values (no NaN or Inf)")
# Check all values are non-negative
if np.any(weights < 0):
raise ValueError("weights must be non-negative")
# Ensure correct dtype for numba compatibility
weights = weights.astype(np.float64)
head_variables = numba.typed.List(head_variables)
# Convert head functions and their variables to numba types
head_fns_numba = numba.typed.List(head_fns)
head_fns_vars_numba = numba.typed.List.empty_list(numba.types.ListType(numba.types.string))
for vars_list in head_fns_vars:
typed_vars_list = numba.typed.List.empty_list(numba.types.string)
for var in vars_list:
typed_vars_list.append(var)
head_fns_vars_numba.append(typed_vars_list)
result = rule.Rule(name, rule_type, target, head_variables, numba.types.uint16(delta_t), clauses, target_bound, thresholds, ann_fn, weights, head_fns_numba, head_fns_vars_numba, edges, set_static, head_negated)
return result
def _split_body_into_clauses(body):
"""Split the body string into (body_clauses, body_bounds) lists.
Uses a double-character trick to split on clause boundaries without
destroying closing delimiters that are part of the clause content:
1. Double-up ')' and ']' so that splitting on '),' or '],' always
leaves one copy of the delimiter inside the clause string.
2. Split first on '),' then on '],' to handle both unbracketed and
bracketed clause endings.
3. Restore the original single characters.
4. Attach default bounds :[1,1] (or :[0,0] for negated clauses).
5. Split each clause on ':' to separate predicate from bound.
"""
# Convert :True/:False shorthand to numeric bounds before splitting (case-insensitive)
# Must happen before delimiter doubling so clauses end with ']' for correct splitting
body_lower = body.lower()
new_body = []
i = 0
while i < len(body):
if body_lower[i:i + 5] == ':true' and (i + 5 >= len(body) or body[i + 5] in ',)'):
new_body.append(':[1,1]')
i += 5
elif body_lower[i:i + 6] == ':false' and (i + 6 >= len(body) or body[i + 6] in ',)'):
new_body.append(':[0,0]')
i += 6
else:
new_body.append(body[i])
i += 1
body = ''.join(new_body)
# Double-up closing delimiters so splitting on ")," / "]," is safe
body = body.replace(')', '))')
body = body.replace(']', ']]')
# Split on clause boundaries: first '),' then '],'
body = body.split('),')
split_body = []
for part in body:
split_body.extend(part.split('],'))
# Restore original single delimiters
for i in range(len(split_body)):
split_body[i] = split_body[i].replace('))', ')')
split_body[i] = split_body[i].replace(']]', ']')
# V7: Check for empty or malformed clauses (e.g. trailing commas, double commas)
for i, part in enumerate(split_body):
stripped = part.lstrip(',')
if not stripped:
raise ValueError(f"Body clause {i} is empty. Check for trailing commas or double commas in the rule body")
# Leading comma indicates consecutive commas in the original rule
if stripped != part:
raise ValueError(f"Body clause {i} is empty. Check for trailing commas or double commas in the rule body")
# Check for double negation in body clauses
for part in split_body:
if part.startswith('~~'):
raise ValueError(f"Double negation '~~' is not allowed in body clause '{part}'")
# Attach default bounds: negated clauses get [0,0], others get [1,1]
# Track which clauses are negated with explicit bounds (need [1-u, 1-l] transform)
negate_body_flags = [False] * len(split_body)
for i in range(len(split_body)):
if split_body[i][0] == '~':
if split_body[i][-1] == ']':
# ~pred(X):[l,u] — strip negation, keep explicit bound, flag for inversion
split_body[i] = split_body[i][1:]
negate_body_flags[i] = True
else:
# ~pred(X) — strip negation, assign [0,0]
split_body[i] = split_body[i][1:] + ':[0,0]'
elif split_body[i][-1] != ']':
split_body[i] += ':[1,1]'
# Separate each clause into predicate and bound string
body_clauses = []
body_bounds = []
for part in split_body:
# V9: Each clause must split into exactly predicate:bound
parts = part.split(':')
if len(parts) != 2:
raise ValueError(f"Body clause '{part}' has invalid format: expected exactly one ':' separating predicate from bound")
clause_str, bound_str = parts
body_clauses.append(clause_str)
body_bounds.append(bound_str)
# Convert bound strings to [lower, upper] float pairs
for i in range(len(body_bounds)):
bound_str = body_bounds[i]
lower, upper = _str_bound_to_bound(bound_str)
# Apply negation inversion: ~[l,u] = [1-u, 1-l]
if negate_body_flags[i]:
lower, upper = round(1 - upper, 10), round(1 - lower, 10)
body_bounds[i] = [lower, upper]
return body_clauses, body_bounds
def _parse_head(head):
"""Parse the head string into (head_str, target_bound, ann_fn).
Possible head formats:
- pred(x) → default bound [1,1], no annotation
- ~pred(x) → negated bound [0,0]
- pred(x):[l,u] → explicit bound
- ~pred(x):[l,u] → negated explicit bound
- pred(x):fn_name → annotation function with default bound [0,1]
"""
# Check for double negation in head
if head.startswith('~~'):
raise ValueError(f"Double negation '~~' is not allowed in rule head '{head}'")
# V5 (preliminary): head must contain '(' and ')'
if '(' not in head:
raise ValueError(f"Rule head '{head}' must contain parentheses around variables")
if ')' not in head:
raise ValueError(f"Rule head '{head}' is missing closing parenthesis")
# V6: At most one colon allowed in head
colon_count = head.count(':')
if colon_count > 1:
raise ValueError(f"Rule head contains {colon_count} colons, expected at most 1")
# Strip a leading '~' up front so the suffix-handling below is uniform across
# all forms: ~pred(X), ~pred(X):[l,u], ~pred(X):True/False, ~pred(X):ann_fn.
# We apply the [1-u, 1-l] inversion to the resolved target_bound below.
negate_head_interval = False
if head[0] == '~':
head = head[1:]
negate_head_interval = True
# Convert :True/:False shorthand to numeric bounds (case-insensitive)
if colon_count == 1:
colon_idx = head.index(':')
suffix = head[colon_idx + 1:]
if suffix.lower() == 'true':
head = head[:colon_idx] + ':[1,1]'
elif suffix.lower() == 'false':
head = head[:colon_idx] + ':[0,0]'
# If no colon present, attach default bound [1,1] (negation will flip it to [0,0])
if head[-1] == ')':
head += ':[1,1]'
head_str, head_bound_str = head.split(':')
# Determine if head_bound_str is a numeric bound or an annotation function name
if _is_bound(head_bound_str):
target_bound = list(_str_bound_to_bound(head_bound_str))
if negate_head_interval:
target_bound = [round(1 - target_bound[1], 10), round(1 - target_bound[0], 10)]
target_bound = interval.closed(*target_bound)
ann_fn = ''
else:
# If it looks like a bound (has brackets) but failed _is_bound, it's malformed
if '[' in head_bound_str and ']' in head_bound_str:
try:
_str_bound_to_bound(head_bound_str)
except ValueError as e:
raise ValueError(
f"{e}. Note: Annotation function names cannot contain brackets '[' or ']'"
)
# Annotation function: default bound is [0,1]; ~[0,1] = [0,1] is a no-op
# at parse time. To actually invert the ann_fn's runtime output, the
# negate_head_interval flag must be plumbed onto Rule.
target_bound = interval.closed(0, 1)
ann_fn = head_bound_str
return head_str, target_bound, ann_fn, negate_head_interval
def _parse_head_arguments(head_args_str):
"""
Parse head arguments which can be either simple variables or function calls.
Examples:
"X" -> head_variables=['X'], head_fns=[''], head_fns_vars=[[]]
"X, Y" -> head_variables=['X', 'Y'], head_fns=['', ''], head_fns_vars=[[], []]
"f(X, Y)" -> head_variables=['__temp_var_0'], head_fns=['f'], head_fns_vars=[['X', 'Y']]
"f(X, Y), Z" -> head_variables=['__temp_var_0', 'Z'], head_fns=['f', ''], head_fns_vars=[['X', 'Y'], []]
"f(X, Y), g(A, B)" -> head_variables=['__temp_var_0', '__temp_var_1'], head_fns=['f', 'g'], head_fns_vars=[['X', 'Y'], ['A', 'B']]
"""
head_variables = []
head_fns = []
head_fns_vars = []
if not head_args_str:
return head_variables, head_fns, head_fns_vars
# Split arguments by comma, being careful about nested parentheses
args_list = []
current_arg = ''
paren_count = 0
for char in head_args_str:
if char == '(':
paren_count += 1
current_arg += char
elif char == ')':
paren_count -= 1
current_arg += char
elif char == ',' and paren_count == 0:
args_list.append(current_arg.strip())
current_arg = ''
else:
current_arg += char
# Add the last argument
if current_arg.strip():
args_list.append(current_arg.strip())
# Parse each argument
for arg in args_list:
arg = arg.strip()
# Check if it's a function call (contains '(' and ')')
if '(' in arg and ')' in arg:
# Extract function name and arguments
paren_idx = arg.find('(')
fn_name = arg[:paren_idx]
# Extract arguments inside the function
fn_args_str = arg[paren_idx + 1:arg.rfind(')')]
fn_args = [a.strip() for a in fn_args_str.split(',') if a.strip()]
# Create a temporary variable name for this function result
temp_var = f'__temp_var_{len(head_variables)}'
head_variables.append(temp_var)
head_fns.append(fn_name)
head_fns_vars.append(fn_args)
else:
# It's a simple variable
head_variables.append(arg)
head_fns.append('')
head_fns_vars.append([])
return head_variables, head_fns, head_fns_vars
def _validate_predicate_name(pred, context):
"""Validate that a predicate name matches ^[a-zA-Z_][a-zA-Z0-9_.\\-]*$."""
if not _PREDICATE_RE.match(pred):
if pred and pred[0].isdigit():
raise ValueError(f"{context} predicate name '{pred}' cannot start with a digit")
raise ValueError(f"{context} predicate name '{pred}' contains invalid characters. Must match [a-zA-Z_][a-zA-Z0-9_.\\-]*")
def _validate_component_name(var, context):
"""Validate that a component name matches ^[a-zA-Z0-9_][a-zA-Z0-9_.@\\-]*$."""
if not _COMPONENT_RE.match(var):
raise ValueError(f"{context} component name '{var}' contains invalid characters. Must match [a-zA-Z0-9_][a-zA-Z0-9_.@\\-]*")
def _str_bound_to_bound(str_bound):
"""Convert a string bound like '[0.5,0.8]' to (float, float).
Validates that:
- There are exactly 2 comma-separated values
- Both values are numeric
- Both values are in [0, 1]
- Lower <= upper
"""
str_bound = str_bound.replace('[', '')
str_bound = str_bound.replace(']', '')
parts = str_bound.split(',')
# V10: Must have exactly 2 values
if len(parts) != 2:
raise ValueError(f"Bound must contain exactly 2 values, got {len(parts)}: '{str_bound}'")
lower_str, upper_str = parts
# V10: Values must be numeric
try:
lower = float(lower_str)
except ValueError:
raise ValueError(f"Bound lower value must be numeric, got '{lower_str}'")
try:
upper = float(upper_str)
except ValueError:
raise ValueError(f"Bound upper value must be numeric, got '{upper_str}'")
# V10: Values must be finite numbers
if math.isnan(lower):
raise ValueError(f"Bound lower value must be a number, got '{lower_str}'")
if math.isnan(upper):
raise ValueError(f"Bound upper value must be a number, got '{upper_str}'")
# V10: Values must be in [0, 1]
if lower < 0 or lower > 1:
raise ValueError(f"Bound lower value {lower} is out of range [0, 1]")
if upper < 0 or upper > 1:
raise ValueError(f"Bound upper value {upper} is out of range [0, 1]")
# V10: Lower must not exceed upper
if lower > upper:
raise ValueError(f"Bound lower value {lower} is greater than upper value {upper}")
return lower, upper
def _is_bound(str_bound):
"""Check whether str_bound looks like a numeric bound (e.g. '[0.5,0.8]')
rather than an annotation function name.
Uses float() parsing instead of isdigit() to correctly handle negative
numbers, scientific notation, etc.
"""
str_bound = str_bound.replace('[', '')
str_bound = str_bound.replace(']', '')
try:
lower, upper = str_bound.split(',')
# V11: Use float() instead of isdigit() for robust numeric detection
float(lower)
float(upper)
result = True
except (ValueError, AttributeError):
result = False
return result
def _get_operator_from_clause(clause):
operators = ['<=', '>=', '<', '>', '==', '!=']
for op in operators:
if op in clause:
return op
# No operator found in clause
return ''