pyreason.scripts.learning.classification.classifier
Module Contents
- class LogicIntegratedClassifier(model: torch.nn.Module, class_names: List[str], identifier: str = 'classifier', interface_options: pyreason.scripts.learning.utils.model_interface.ModelInterfaceOptions = None)
Bases: pyreason.scripts.learning.classification.logic_integration_base.LogicIntegrationBase
Class to integrate a PyTorch model with PyReason. The model's output is returned to the user as PyReason facts, which the user can then add to the logic program and reason over. Wraps any torch.nn.Module whose forward(x) returns [N, C] logits (multi-class). Implements _infer, _postprocess, and _pred_to_facts to replace the original forward().
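The flow described above (raw [N, C] logits, then post-processing, then one fact per class) can be illustrated with a minimal pure-Python sketch. This is not PyReason's implementation: it uses neither torch nor pyreason, the function names `softmax` and `logits_to_facts` are hypothetical, and both the use of softmax as the post-processing step and the (class name, identifier, lower bound, upper bound) tuple layout are assumptions made only for illustration.

```python
import math
from typing import List, Tuple

# Hypothetical stand-in for a fact: (class_name, identifier, lower, upper).
FactTuple = Tuple[str, str, float, float]


def softmax(logits: List[float]) -> List[float]:
    """Turn one row of logits into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_facts(
    logits_batch: List[List[float]],
    class_names: List[str],
    identifier: str = "classifier",
) -> List[FactTuple]:
    """Map [N, C] logits to one fact-like tuple per sample and class.

    Assumption: the class probability is used as the lower bound and 1.0
    as the upper bound. The real class may choose bounds differently.
    """
    facts: List[FactTuple] = []
    for row in logits_batch:
        if len(row) != len(class_names):
            raise ValueError("each logits row needs one value per class name")
        for name, prob in zip(class_names, softmax(row)):
            facts.append((name, identifier, prob, 1.0))
    return facts


# One sample, two classes.
example_facts = logits_to_facts([[2.0, 0.0]], ["cat", "dog"])
```

In this sketch, each tuple plays the role of a fact that would be added to the logic program; how the actual class constructs and names its facts is determined by _pred_to_facts and the interface options.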