import threading
import time
from datetime import timedelta
from typing import List, Optional, Union, Callable, Any
import torch.nn
import torch.nn.functional as F
import pyreason as pr
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.learning.classification.logic_integration_base import LogicIntegrationBase
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions
class TemporalLogicIntegratedClassifier(LogicIntegrationBase):
    """
    Wraps any torch.nn.Module whose forward(x) returns [N, C] logits (multi-class),
    but additionally polls in the background (either every N timesteps or every N seconds)
    and injects new Facts into a PyReason logic program.
    """

    # Seconds to sleep between checks while waiting for a logic program to appear
    # or for enough PyReason timesteps to elapse (avoids a CPU busy-spin).
    _IDLE_SLEEP_SECS = 0.01

    def __init__(
        self,
        model,
        class_names: List[str],
        identifier: str = 'classifier',
        interface_options: ModelInterfaceOptions = None,
        logic_program=None,
        poll_interval: Optional[Union[int, timedelta]] = None,
        poll_condition: Optional[str] = None,
        input_fn: Optional[Callable[[], Any]] = None,
    ):
        """
        :param model: PyTorch model to be integrated.
        :param class_names: List of class names for the model output.
        :param identifier: Identifier for the model, used as the constant in the facts.
        :param interface_options: Options for the model interface, including threshold and snapping behavior.
        :param logic_program: PyReason logic program. If `None`, the global program returned by
            `pr.get_logic_program()` is used once it has an interpretation.
        :param poll_interval: How often to poll the model, either as:
            - an integer number of PyReason timesteps or
            - a `datetime.timedelta` object representing wall-clock time.
            If `None`, polling is disabled.
        :param poll_condition: The name of the predicate attached to the model that must be true to trigger a poll.
            If `None`, the model will be polled every `poll_interval` time steps/seconds.
        :param input_fn: Function to call to get the input to the model. This function should return a tensor.
            Polling is only started when both `poll_interval` and `input_fn` are given.
        """
        super().__init__(model, class_names, interface_options, identifier)
        self.class_names = class_names
        self.identifier = identifier
        self.interface_options = interface_options
        self.logic_program = logic_program
        self.poll_interval: Union[int, timedelta, None] = poll_interval
        self.poll_condition = poll_condition
        # Must be stored: the start-up check below and _poll_once both read self.input_fn.
        self.input_fn = input_fn

        # Handle to the background polling thread; stays None when polling is disabled.
        self._poll_thread: Optional[threading.Thread] = None
        if self.poll_interval is not None and self.input_fn is not None:
            # Daemon thread so the endless poll loop never blocks interpreter shutdown.
            self._poll_thread = threading.Thread(target=self._poll_loop, daemon=True)
            self._poll_thread.start()

    def _get_current_timestep(self) -> int:
        """
        Get the current timestep from the PyReason logic program.

        Uses `self.logic_program` if it has an interpretation. Otherwise it falls back to
        `pr.get_logic_program()` and caches that program on `self.logic_program`.

        :return: Current timestep (`interp.time`), or -1 if no logic program with an
            interpretation is available yet.
        """
        program = self.logic_program
        if program is None or program.interp is None:
            program = pr.get_logic_program()
            if program is None or program.interp is None:
                return -1
            self.logic_program = program
        return program.interp.time

    def _poll_loop(self) -> None:
        """
        Background loop (run in a daemon thread) that polls every `self.poll_interval`.

        First waits until a logic program with an interpretation exists. It then runs
        forever in one of two modes:

        - `timedelta`: sleep for the wall-clock interval, then poll at `timestep + 1`.
        - otherwise (int): wait until `poll_interval` timesteps have elapsed since the
          last poll, then poll at the current timestep.
        """
        # Wait for a logic program, sleeping between checks instead of spinning.
        current_time = self._get_current_timestep()
        while current_time == -1:
            time.sleep(self._IDLE_SLEEP_SECS)
            current_time = self._get_current_timestep()

        if isinstance(self.poll_interval, timedelta):
            self._poll_wall_clock(self.poll_interval.total_seconds())
        else:
            self._poll_timesteps(self.poll_interval, last_step=current_time + 1)

    def _poll_wall_clock(self, interval_secs: float) -> None:
        """Poll forever, once every `interval_secs` seconds, targeting the next timestep."""
        while True:
            time.sleep(interval_secs)
            current_time = self._get_current_timestep()
            if current_time == -1:
                # No logic program/interpretation right now; try again next tick
                # instead of querying a missing interpretation.
                continue
            next_step = current_time + 1
            self._poll_once(next_step, next_step)

    def _poll_timesteps(self, step_interval: int, last_step: int) -> None:
        """Poll forever, each time `step_interval` timesteps have passed since `last_step`."""
        while True:
            # Wait until enough timesteps have passed.
            while self._get_current_timestep() - last_step < step_interval:
                time.sleep(self._IDLE_SLEEP_SECS)
            current = self._get_current_timestep()
            # Advance the reference point even if the poll condition fails below,
            # so the next attempt is a full interval later.
            last_step = current
            self._poll_once(current, current)

    def _poll_once(self, t1: int, t2: int) -> None:
        """
        Perform a single poll over [t1, t2].

        If `poll_condition` is set and `<poll_condition>(<identifier>)` does not hold in
        the current interpretation, nothing happens. Otherwise the model is run on
        `input_fn()`, the resulting facts are added via `pr.add_fact`, and `pr.reason` is
        called again with `again=True, restart=False`.
        """
        if self.poll_condition:
            condition = pr.Query(f"{self.poll_condition}({self.identifier})")
            if not self.logic_program.interp.query(condition):
                return
        x = self.input_fn()
        _, _, facts = self.forward(x, t1, t2)
        for f in facts:
            pr.add_fact(f)
        # Run the reasoning with the newly added facts.
        pr.reason(again=True, restart=False)

    def get_class_facts(self, t1: int, t2: int) -> List[Fact]:
        """
        Return PyReason facts that create one node per class.

        Each fact is `<class_name>(<identifier>)`, named `<identifier>-<class_name>-fact`.
        No explicit bounds are given, so the Fact default applies (intended to be `[1,1]`).

        :param t1: Start time for the facts
        :param t2: End time for the facts
        :return: List of PyReason facts, one per entry in `class_names`
        """
        return [
            Fact(f'{c}({self.identifier})', name=f'{self.identifier}-{c}-fact', start_time=t1, end_time=t2)
            for c in self.class_names
        ]

    def _infer(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the underlying model to get raw logits [N, C].
        """
        return self.model(x)

    def _postprocess(self, raw_output: torch.Tensor) -> torch.Tensor:
        """
        raw_output should be a [N, C] logits tensor. Check that C == len(class_names),
        then apply softmax over dim=1 to get [N, C] probabilities.

        :raises ValueError: if the tensor is not 2-D or C does not match `class_names`.
        """
        logits = raw_output
        if logits.dim() != 2 or logits.size(1) != len(self.class_names):
            raise ValueError(
                f"Expected logits of shape [N, C] with C={len(self.class_names)}, got {tuple(logits.shape)}"
            )
        return F.softmax(logits, dim=1)

    def _pred_to_facts(
        self,
        raw_output: torch.Tensor,
        probabilities: torch.Tensor,
        t1: int,
        t2: int
    ) -> List[Fact]:
        """
        Given a [N, C] probability tensor, build a flat List[Fact] of N * C facts.

        For each entry with probability > `threshold`:
          - if `snap_value` is set, the lower bound is `snap_value` when `set_lower_bound`
            (else 0) and the upper bound is `snap_value` when `set_upper_bound` (else 1);
          - otherwise, the lower bound is the probability when `set_lower_bound` (else 0)
            and the upper bound is the probability when `set_upper_bound` (else 1).
        Entries at or below the threshold get bounds [0, 1].

        `raw_output` is accepted for interface compatibility and is not used.

        :raises ValueError: if `probabilities` is not [N, len(class_names)].
        """
        opts = self.interface_options
        prob = probabilities  # [N, C]
        if prob.dim() != 2 or prob.size(1) != len(self.class_names):
            raise ValueError(
                f"Expected probabilities of shape [N, C] with C={len(self.class_names)}, got {tuple(prob.shape)}"
            )

        threshold = torch.tensor(opts.threshold, dtype=prob.dtype, device=prob.device)
        condition = prob > threshold  # [N, C] boolean mask

        # Determine lower/upper bounds for entries above the threshold.
        if opts.snap_value is not None:
            snap_val = torch.tensor(opts.snap_value, dtype=prob.dtype, device=prob.device)
            lower_if_true = (snap_val if opts.set_lower_bound
                             else torch.tensor(0.0, dtype=prob.dtype, device=prob.device))
            upper_if_true = (snap_val if opts.set_upper_bound
                             else torch.tensor(1.0, dtype=prob.dtype, device=prob.device))
        else:
            lower_if_true = prob if opts.set_lower_bound else torch.zeros_like(prob)
            upper_if_true = prob if opts.set_upper_bound else torch.ones_like(prob)

        # 0-d snap tensors broadcast against the [N, C] mask.
        lower_bounds = torch.where(condition, lower_if_true, torch.zeros_like(prob))  # [N, C]
        upper_bounds = torch.where(condition, upper_if_true, torch.ones_like(prob))  # [N, C]

        # Convert to Python floats once instead of calling .item() per element.
        all_facts: List[Fact] = []
        for lower_row, upper_row in zip(lower_bounds.tolist(), upper_bounds.tolist()):
            for class_name, lower_val, upper_val in zip(self.class_names, lower_row, upper_row):
                fact_str = f"{class_name}({self.identifier}) : [{lower_val:.3f}, {upper_val:.3f}]"
                fact_name = f"{self.identifier}-{class_name}-fact"
                all_facts.append(Fact(fact_str, name=fact_name, start_time=t1, end_time=t2))
        return all_facts