# Source code for pyreason.scripts.learning.classification.hf_classifier

from typing import List, Any

import torch
import torch.nn.functional as F

from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.learning.classification.logic_integration_base import LogicIntegrationBase
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions


class HuggingFaceLogicIntegratedClassifier(LogicIntegrationBase):
    """
    Integrates a HuggingFace image classification model with PyReason.

    Extends LogicIntegrationBase by implementing _infer, _postprocess, and _pred_to_facts.
    """

    def __init__(
        self,
        model,
        class_names: List[str],
        identifier: str = 'hf_classifier',
        interface_options: ModelInterfaceOptions = None,
        limit_classes: bool = False
    ):
        """
        :param model: A HuggingFace model (e.g. AutoModelForImageClassification).
        :param class_names: List of class names for the model output.
        :param identifier: Identifier for the model, used as the constant in the facts.
        :param interface_options: Options for the model interface, including threshold and snapping behavior.
        :param limit_classes: If True, filter output probabilities to only the classes in class_names
            using the model's id2label config, renormalize, and order the resulting labels by probability.
        """
        super().__init__(model, class_names, interface_options, identifier)
        self.limit_classes = limit_classes
        # Labels produced by the most recent _postprocess call when limit_classes is on.
        # Initialised here so _pred_to_facts never hits a missing attribute.
        self._filtered_labels = None

    def _infer(self, x: Any) -> Any:
        """
        Run the wrapped model and return its logits.

        :param x: Mapping of keyword inputs; it is unpacked into the model call.
        :return: The ``logits`` attribute of the model output.
        """
        return self.model(**x).logits

    def _postprocess(self, raw_output: Any) -> Any:
        """
        Convert logits to probabilities, optionally restricted to ``class_names``.

        Also records the labels matching the returned probabilities in
        ``self._filtered_labels`` (``None`` when ``limit_classes`` is off).

        :param raw_output: Logits as returned by ``_infer``; softmax is taken over dim 1.
        :return: Probability tensor with singleton dimensions squeezed out.
        """
        probabilities = F.softmax(raw_output, dim=1).squeeze()
        if self.limit_classes:
            probabilities, self._filtered_labels = self._filter_to_allowed_classes(probabilities)
        else:
            self._filtered_labels = None
        return probabilities

    def _filter_to_allowed_classes(self, probabilities: torch.Tensor):
        """
        Filter probabilities to only the allowed class_names using model.config.id2label.

        A label matches when the text before its first comma, stripped and
        lower-cased, equals a lower-cased entry of ``class_names``.
        Does not mutate ``self.class_names``.

        :param probabilities: Class probabilities indexed by the keys of ``id2label``
            (assumes a 1-D tensor for a single sample -- TODO confirm against callers).
        :return: ``(top_probs, top_labels)``: renormalized probabilities of the matched
            classes in descending order and their labels. At most ``len(class_names)``
            entries; both are empty when nothing matches.
        """
        id2label = self.model.config.id2label
        # Build the lookup set once instead of once per id2label entry.
        allowed_names = {name.lower() for name in self.class_names}
        allowed_indices = [
            i for i, label in id2label.items()
            if label.split(",")[0].strip().lower() in allowed_names
        ]
        if not allowed_indices:
            # Nothing to report; avoids a 0/0 renormalization that would yield NaNs.
            return probabilities.new_empty(0), []

        # Work on the allowed subset only so a disallowed class can never be selected.
        allowed_probs = probabilities[allowed_indices]
        total = allowed_probs.sum()
        if total > 0:
            allowed_probs = allowed_probs / total

        # Never request more entries than actually matched.
        k = min(len(self.class_names), len(allowed_indices))
        top_probs, top_positions = allowed_probs.topk(k)
        top_labels = [
            id2label[allowed_indices[pos]].split(",")[0].strip()
            for pos in top_positions.tolist()
        ]
        return top_probs, top_labels

    def _pred_to_facts(
        self,
        raw_output: Any,
        probabilities: Any,
        t1: int = 0,
        t2: int = 0
    ) -> List[Fact]:
        """
        Turn class probabilities into one PyReason fact per label.

        Entries above ``interface_options.threshold`` get bounds taken from the
        probability (or from ``snap_value`` when set) on the sides enabled by
        ``set_lower_bound`` / ``set_upper_bound``; all other entries get ``[0, 1]``.

        :param raw_output: Raw model output (unused here).
        :param probabilities: Probability tensor as returned by ``_postprocess``.
        :param t1: Start time of the generated facts.
        :param t2: End time of the generated facts.
        :return: List of facts, one per label.
        """
        opts = self.interface_options
        dtype, device = probabilities.dtype, probabilities.device
        zeros = torch.zeros_like(probabilities)
        ones = torch.ones_like(probabilities)

        threshold = torch.tensor(opts.threshold, dtype=dtype, device=device)
        condition = probabilities > threshold

        if opts.snap_value is not None:
            # A scalar snap value broadcasts against the probability tensor in torch.where.
            active_val = torch.tensor(opts.snap_value, dtype=dtype, device=device)
        else:
            active_val = probabilities
        lower_val = active_val if opts.set_lower_bound else zeros
        upper_val = active_val if opts.set_upper_bound else ones

        lower_bounds = torch.where(condition, lower_val, zeros)
        upper_bounds = torch.where(condition, upper_val, ones)

        labels = self._filtered_labels if self._filtered_labels is not None else self.class_names
        facts = []
        for i, label in enumerate(labels):
            lower = lower_bounds[i].item()
            upper = upper_bounds[i].item()
            fact_str = f'{label}({self.identifier}) : [{lower:.3f}, {upper:.3f}]'
            fact = Fact(fact_str, name=f'{self.identifier}-{label}-fact', start_time=t1, end_time=t2)
            facts.append(fact)
        return facts