Skip to content

API Reference

Construct a shield factory.

Construction modes:

1. Path + property (builds model and runs model checking)
2. Pre-built model + prism program + property (skips model build, runs model checking)
3. Pre-built model + prism program + check result (skips model build and model checking)

Source code in tempestpy/shielding/shield_factory.py (lines 37–404)
class ShieldFactory:
    """
    Construct a shield factory.

    Construction modes:
    1. Path + property  (builds model and runs model checking)
    2. Pre-built model + prism program + property  (skips model build, runs model checking)
    3. Pre-built model + prism program + check result  (skips model build and model checking)

    In every mode, model checking itself is deferred until ``build()`` is called;
    it is skipped entirely when a check result is already stored.
    """

    def __init__(
        self,
        model: Any,
        *,
        property: Optional[str] = None,
        constants: Optional[Dict[str, Any]] = None,
        prism_program: Optional[Any] = None,
        check_result: Optional[Any] = None,
        engine: str = "sparse",
    ) -> None:
        """
        Args:
            model: Path to a PRISM file (mode 1), or a pre-built stormpy model
                (modes 2/3).
            property: Property string, parsed against the PRISM program.
            constants: Values for undefined PRISM constants. Only applied when a
                PRISM program is parsed from a path string.
            prism_program: Parsed PRISM program or a path string. Required when
                ``model`` is pre-built.
            check_result: Pre-computed model-checking result (mode 3).
            engine: ``"sparse"`` or ``"dd"``.

        Raises:
            ValueError: if ``engine`` is not supported.
            TypeError: if ``prism_program`` is missing for a pre-built model, or
                the model is not an MDP/SMG supported by the chosen engine.
        """
        self._property = property
        self._check_result = check_result
        if engine not in {"sparse", "dd"}:
            raise ValueError(f"Selected engine {engine} is not supported")
        self._sparse = engine == "sparse"

        if isinstance(model, str):
            # --- Mode 1: path string ---
            logger.trace(f"ShieldFactory.__init__: mode 1 (path), {model=}")
            self._model_path = str(model)
            self._prism_program = self._parse_prism_program(self._model_path, constants)
            # Parse the property before building: _build_model_from_prism derives
            # its builder options from the parsed formula when one exists.
            if property is not None:
                self._prepare_property(property)
            self._model = self._build_model_from_prism()
            self._assert_supported_model(self._model)
        else:
            # --- Mode 2 / 3: pre-built model ---
            logger.trace("ShieldFactory.__init__: mode 2/3 (pre-built model)")
            if prism_program is None:
                raise TypeError(
                    "prism_program is required when passing a pre-built model. "
                    "Use shield.prism_program to obtain it from an existing ShieldFactory, "
                    "or pass a path string."
                )
            # Accept either a parsed program or a path string
            if isinstance(prism_program, str):
                self._prism_program = self._parse_prism_program(prism_program, constants)
            else:
                self._prism_program = prism_program

            # Validate before reading labeling attributes so an unsupported model
            # (or None) raises the documented TypeError, not an AttributeError.
            self._assert_supported_model(model)
            self._model = model
            if property is not None:
                self._prepare_property(property)
            # Store labeling references for consistency where sparse shielding needs them.
            if self.is_sparse:
                self._state_valuations = model.state_valuations
                self._choice_labeling = model.choice_labeling

        if property is not None:
            self._assert_supported_property(property)

    # ---- Alternative constructors ----

    @classmethod
    def from_check_result(
        cls,
        model: Any,
        *,
        prism_program: Any,
        result: Any,
        engine: str = "sparse",
    ) -> "ShieldFactory":
        """
        Build a ShieldFactory from a pre-built model and an already-computed check result.
        No model building or model checking will be performed; compute_mask can be
        called directly.

        prism_program may be a parsed storm::prism::Program or a path string.

        Example::

            shield2 = ShieldFactory.from_check_result(
                shield.model,
                prism_program=shield.prism_program,
                result=my_result,
            )
            masked = shield2.compute_mask(ShieldConfig(...))
        """
        logger.trace("ShieldFactory.from_check_result called")
        return cls(
            model,
            prism_program=prism_program,
            check_result=result,
            engine=engine,
        )

    # ---- Public properties ----

    @property
    def model(self) -> Any:
        """The stormpy model the shield is computed on."""
        return self._model

    @property
    def prism_program(self) -> Any:
        """The parsed PRISM program (with constants defined, if any were given)."""
        return self._prism_program

    @property
    def prop(self) -> Optional[str]:
        """The raw property string, or None if none was set."""
        return self._property

    @prop.setter
    def prop(self, value: str) -> None:
        """
        Replace the property and invalidate cached model-checking results.

        The new property is parsed first, so an invalid property raises without
        modifying the factory's state.
        """
        logger.trace(f"ShieldFactory.prop.setter: setting property to {value!r}")
        if value == self._property:
            logger.trace("ShieldFactory.prop.setter: property unchanged, skipping")
            return
        self._prepare_property(value)
        self._property = value
        # Invalidate cached results since the property changed
        self._check_result = None
        self._choice_values = None
        self._choice_offsets = None
        self._choice_counts = None
        # NOTE(review): in mode 1 the model was built with builder options derived
        # from the original formula; confirm the new property only uses labels
        # that model preserves.

    @property
    def formula(self) -> Any:
        """The parsed stormpy property, or None if no property was prepared."""
        return getattr(self, "_formula", None)

    @property
    def optimality_type(self) -> Any:
        """Optimality type of the formula's operator (inner operator for game formulas)."""
        return getattr(self, "_optimality_type", None)

    # ---- Build ----

    def build(self, config: ShieldConfig) -> Any:
        """
        Run model checking (if not already done) and compute the action mask.

        Raises:
            RuntimeError: if there is neither a stored check result nor a parsed
                formula to model-check.
        """
        logger.trace(f"ShieldFactory.build called with {config=}")
        if self._check_result is None and self.formula is None:
            raise RuntimeError(
                "Cannot build shield: no formula and no check result available. "
                "Pass a property string or use ShieldFactory.from_check_result."
            )
        if self._check_result is None:
            self._run_model_check()
        return self.compute_mask(config)

    def compute_mask(self, config: ShieldConfig) -> Any:
        """
        Compute the action mask for ``config`` from the stored check result.

        This method does not run model checking itself; it delegates to
        ``compute_action_mask(self, config)``. Choice values are extracted by
        ``_extract_choice_values()``, which caches them.

        Raises:
            RuntimeError: if no check result is available yet.
        """
        logger.trace(f"ShieldFactory.compute_mask called with {config=}")
        if self._check_result is None:
            raise RuntimeError(
                "Cannot compute mask: no check result available. "
                "Call build() first or use ShieldFactory.from_check_result."
            )
        return compute_action_mask(self, config)

    # ---- Prism / model building internals ----

    @staticmethod
    def _parse_prism_program(path: str, constants: Optional[Dict[str, Any]]) -> Any:
        """Parse the PRISM file at ``path`` and define ``constants`` (if any) on it."""
        logger.trace(f"ShieldFactory._parse_prism_program: {path=}, {constants=}")
        program = stormpy.parse_prism_program(path)
        if constants:
            definition_string = _build_constant_definitions(constants)
            logger.info(f"defined constants: {definition_string}")
            constant_map = stormpy.parse_constants_string(
                program.expression_manager, definition_string
            )
            program = program.define_constants(constant_map)
        return program

    def _build_model_from_prism(self) -> Any:
        """
        Build the model from the stored PRISM program using the configured engine.

        Sparse builds request state valuations and choice labels. They are
        restricted to the prepared formula when one exists; otherwise all labels
        and reward models are built (``BuilderOptions(True, True)``).
        """
        logger.trace("ShieldFactory._build_model_from_prism called")
        if self.formula is not None:
            options = stormpy.BuilderOptions([self.formula.raw_formula])
        else:
            options = stormpy.BuilderOptions(True, True)
        options.set_build_state_valuations()
        options.set_build_choice_labels()
        if self.is_sparse:
            model = stormpy.build_sparse_model_with_options(self._prism_program, options)
            self._state_valuations = model.state_valuations
            self._choice_labeling = model.choice_labeling
        else:
            properties = [self.formula] if self.formula is not None else None
            model = stormpy.build_symbolic_model(self._prism_program, properties)
        return model

    def _prepare_property(self, property: str) -> None:
        """
        Parse ``property`` against the PRISM program and store the formula and
        its optimality type.

        Both attributes are assigned only after parsing succeeds, so a failure
        leaves any previously prepared property intact.

        Raises:
            ValueError: if the string does not contain exactly one property.
        """
        logger.trace(f"ShieldFactory._prepare_property: {property!r}")
        formulas = stormpy.parse_properties_for_prism_program(
            _sanitize_property(property), self._prism_program
        )
        if len(formulas) != 1:
            raise ValueError(f"Expected exactly one property, got {len(formulas)}")
        formula = formulas[0]
        raw = formula.raw_formula
        # For GameFormula (SMGs) the optimality type sits on the inner operator
        if getattr(raw, "is_game_formula", False):
            optimality_type = raw.subformula.optimality_type
        else:
            optimality_type = raw.optimality_type
        self._formula = formula
        self._optimality_type = optimality_type

    # ---- Model checking internals ----

    def _run_model_check(self) -> None:
        """
        Model-check the prepared formula with choice-value production enabled
        and store the result.

        Raises:
            RuntimeError: if no property is set, or the result lacks choice values.
        """
        if self.formula is None:
            raise RuntimeError(
                "Cannot run model checking: no property set. "
                "Assign a property via ShieldFactory.prop = '...'"
            )
        logger.info(f"ShieldFactory._run_model_check: {str(self.formula.raw_formula)=}")
        env = Environment()
        env.modelchecker_environment.set_produce_choice_values()
        result = stormpy.model_checking(self._model, self._formula, environment=env)
        # Explicit check instead of `assert`: must also hold under `python -O`.
        if not result.has_choice_values:
            raise RuntimeError(
                f"Choice values were not computed for '{self._formula.raw_formula}'"
            )
        self._check_result = result

    def _extract_choice_values(self) -> None:
        """
        Extract choice values, offsets and counts from the stored check result.
        Guarded: no-op if already extracted.

        ``_choice_offsets`` has one more entry than there are states; the choice
        values of state ``sid`` are ``_choice_values[offsets[sid]:offsets[sid + 1]]``.

        Raises:
            RuntimeError: if there is no check result, or the number of choice
                values does not match the model's per-state action counts.
        """
        if getattr(self, "_choice_values", None) is not None:
            logger.trace("ShieldFactory._extract_choice_values: already extracted, skipping")
            return
        result = self._check_result
        if result is None:
            raise RuntimeError(
                "Cannot extract choice values: no check result available. "
                "Call build() first or use ShieldFactory.from_check_result."
            )
        logger.trace("ShieldFactory._extract_choice_values: extracting from check result")
        choice_vals = np.asarray(result.get_choice_values(), dtype=float)
        counts = np.fromiter(
            (len(s.actions) for s in self._model.states), dtype=int
        )
        offsets = np.concatenate(([0], np.cumsum(counts)))
        if offsets[-1] != choice_vals.size:
            raise RuntimeError(
                f"Choice value count mismatch: {choice_vals.size} values vs "
                f"{offsets[-1]} expected from state action counts"
            )
        self._choice_values = choice_vals
        self._choice_offsets = offsets
        self._choice_counts = counts

    # ---- Assertions ----

    def _assert_supported_model(self, model: Any) -> None:
        """
        Raise TypeError unless ``model`` is an MDP/SMG compatible with the engine.

        Sparse models must additionally carry choice labeling and state valuations.
        """
        if model is None:
            raise TypeError("model must not be None")
        if model.model_type not in {stormpy.ModelType.MDP, stormpy.ModelType.SMG}:
            raise TypeError("the model must be an MDP or SMG")
        if self.is_sparse:
            if not model.is_sparse_model:
                raise TypeError("engine='sparse' requires a sparse model")
            if not model.has_choice_labeling():
                raise TypeError("the model must be built with choice labeling")
            if not model.has_state_valuations():
                raise TypeError("the model must be built with state valuations")
        elif not model.is_symbolic_model:
            raise TypeError("engine='dd' requires a symbolic model")

    def _assert_supported_property(self, prop: Any) -> None:
        """Raise TypeError if ``prop`` is None."""
        if prop is None:
            raise TypeError("property must not be None")

    @property
    def is_sparse(self) -> bool:
        """True when the factory was created with ``engine="sparse"``."""
        return self._sparse

    # ---- State lookup ----

    @property
    def state_lookup(self) -> StateValuationLookup:
        """Lazily built valuation <-> state-id lookup for the model."""
        if getattr(self, "_state_lookup", None) is None:
            self._state_lookup = StateValuationLookup.from_model(
                self._model,
                self._prism_program,
            )
        return self._state_lookup

    def get_choice_values_for_state(self, sid: int) -> np.ndarray:
        """
        Return the choice values of all local actions of state ``sid``.

        Extracts choice values from the check result on first use.
        """
        # Lazy extraction: also covers callers (e.g. exports) invoked before
        # compute_mask, and re-extraction after a property change reset the cache.
        if getattr(self, "_choice_offsets", None) is None:
            self._extract_choice_values()
        offsets = self._choice_offsets
        return self._choice_values[offsets[sid]:offsets[sid + 1]]

    def get_state_id(self, values: dict) -> int:
        """Return the state id whose valuation matches ``values``."""
        return self.state_lookup.get_state_id(values)

    # ---- Dangerous / Critical / Safe / Unsafe helpers

    def _states_with_valuations(self, sids: list[int]) -> list[dict]:
        """Map state ids to their variable valuations."""
        return [self.state_lookup.values_for_state_id(sid) for sid in sids]

    def dangerous_states(self, *, mask: ShieldResult) -> list[dict]:
        """Valuations of states where no action meets the threshold."""
        return self._states_with_valuations(mask.dangerous_states())

    def critical_states(self, *, mask: ShieldResult) -> list[dict]:
        """Valuations of states where some actions are blocked."""
        return self._states_with_valuations(mask.critical_states())

    def safe_states(self, *, mask: ShieldResult) -> list[dict]:
        """Valuations of states where all actions are permitted."""
        return self._states_with_valuations(mask.safe_states())

    def states_by_class(self, *, mask: ShieldResult) -> dict[str, list[dict]]:
        """Valuations grouped by the class names returned by ``mask.states_by_class()``."""
        raw = mask.states_by_class()
        return {cls: self._states_with_valuations(sids) for cls, sids in raw.items()}

    # ---- Export helpers ----

    def to_bitmask(self, mask: Any) -> Any:
        return export_mod.state_id_to_bitmask(mask)

    def to_allowed_action_indices(self, mask: Any) -> Any:
        return export_mod.state_id_to_allowed_action_indices(self._model, mask)

    def to_allowed_action_labels(self, mask: Any) -> Any:
        return export_mod.state_id_to_allowed_action_labels(self._model, mask)

    def to_allowed_state_action_pairs(self, mask: Any) -> Any:
        return export_mod.allowed_state_action_pairs(self._model, mask)

    def to_valuation_bitmask(self, mask: Any) -> Any:
        return export_mod.valuation_to_bitmask(
            self._model,
            mask,
            self.state_lookup,
        )

    def to_valuation_allowed_action_labels(self, mask: Any) -> Any:
        return export_mod.valuation_to_allowed_action_labels(
            self._model,
            mask,
            self.state_lookup,
        )

    def to_valuation_action_label_probability(self) -> dict:
        """
        { valuation_key(state) -> { frozenset(action_labels): probability } }
        for ALL actions, not just shield-allowed ones.
        """
        return export_mod.valuation_to_action_label_probability(
            self._model,
            self.state_lookup,
            lambda sid, i: float(self.get_choice_values_for_state(sid)[i]),
        )

    def pretty_mask(
        self,
        mask: Any,
        *,
        max_states: int = 25,
        show_valuations: bool = True,
        action_ref: str = "labels",
    ) -> str:
        """Human-readable rendering of ``mask`` (at most ``max_states`` states)."""
        return export_mod.pretty(
            self._model,
            mask,
            self.state_lookup,
            max_states=max_states,
            show_valuations=show_valuations,
            action_ref=action_ref,
        )

    def to_storm_format(self, mask: Any) -> str:
        return export_mod.storm_format(self, mask)
build(config)

Run model checking (if not already done) and compute the action mask.

Source code in tempestpy/shielding/shield_factory.py (lines 171–183)
def build(self, config: ShieldConfig) -> Any:
    """
    Compute the action mask for ``config``, running model checking first when
    no check result is stored yet.

    Raises:
        RuntimeError: if there is neither a stored check result nor a parsed
            formula to model-check.
    """
    logger.trace(f"ShieldFactory.build called with {config=}")
    if self._check_result is None:
        if getattr(self, "_formula", None) is None:
            raise RuntimeError(
                "Cannot build shield: no formula and no check result available. "
                "Pass a property string or use ShieldFactory.from_check_result."
            )
        # No cached result yet: model-check before computing the mask.
        self._run_model_check()
    return self.compute_mask(config)

compute_mask(config)

Compute the action mask from already-extracted choice values. Calls _extract_choice_values() if not already done.

Source code in tempestpy/shielding/shield_factory.py (lines 185–196)
def compute_mask(self, config: ShieldConfig) -> Any:
    """
    Compute the action mask for ``config`` from the stored check result.

    This function does not run model checking itself; the work is delegated to
    ``compute_action_mask(self, config)``.

    Raises:
        RuntimeError: if no check result is available yet.
    """
    logger.trace(f"ShieldFactory.compute_mask called with {config=}")
    if self._check_result is not None:
        return compute_action_mask(self, config)
    raise RuntimeError(
        "Cannot compute mask: no check result available. "
        "Call build() first or use ShieldFactory.from_check_result."
    )

from_check_result(model, *, prism_program, result, engine='sparse') classmethod

Build a ShieldFactory from a pre-built model and an already-computed check result. No model building or model checking will be performed; compute_mask can be called directly.

prism_program may be a parsed storm::prism::Program or a path string.

Example::

# Example: reuse the model and PRISM program of an existing factory `shield`
# together with a precomputed check result `my_result`; no model building or
# model checking is performed.
shield2 = ShieldFactory.from_check_result(
    shield.model,
    prism_program=shield.prism_program,
    result=my_result,
)
# A check result is already stored, so compute_mask can be called directly.
masked = shield2.compute_mask(ShieldConfig(...))
Source code in tempestpy/shielding/shield_factory.py (lines 100–131)
@classmethod
def from_check_result(
    cls,
    model: Any,
    *,
    prism_program: Any,
    result: Any,
    engine: str = "sparse",
) -> "ShieldFactory":
    """
    Create a factory from a pre-built model plus an already-computed check result.

    Neither model building nor model checking happens here, so compute_mask
    may be called on the returned factory right away.

    prism_program may be a parsed storm::prism::Program or a path string.

    Example::

        shield2 = ShieldFactory.from_check_result(
            shield.model,
            prism_program=shield.prism_program,
            result=my_result,
        )
        masked = shield2.compute_mask(ShieldConfig(...))
    """
    logger.trace("ShieldFactory.from_check_result called")
    options = {
        "prism_program": prism_program,
        "check_result": result,
        "engine": engine,
    }
    return cls(model, **options)

to_valuation_action_label_probability()

{ valuation_key(state) -> { frozenset(action_labels): probability } } for ALL actions, not just shield-allowed ones.

Source code in tempestpy/shielding/shield_factory.py (lines 375–384)
def to_valuation_action_label_probability(self) -> dict:
    """
    Map each state's valuation key to ``{frozenset(action_labels): probability}``.

    Every action of every state is included, not only the shield-allowed ones.
    The value for local action ``i`` of state ``sid`` is its choice value.
    """
    def choice_value(sid, i):
        return float(self.get_choice_values_for_state(sid)[i])

    return export_mod.valuation_to_action_label_probability(
        self._model,
        self.state_lookup,
        choice_value,
    )
Source code in tempestpy/shielding/masking.py (lines 18–32)
@dataclass(frozen=True, slots=True)
class ShieldConfig:
    """
    Immutable settings used when computing a shield mask.

    Attributes:
        threshold: value each action is compared against; default 1.0.
            NOTE(review): exact semantics depend on ``comparison`` — confirm
            in compute_action_mask.
        comparison: comparison mode, ``"absolute"`` by default.
        post_selector: callable used by post-shielding to replace a blocked action.
        action_lookup: when given, bitmasks use global RL action indices
            ("global" view); otherwise local per-state indices ("local" view).
    """
    threshold: float = 1.0
    comparison: ComparisonMode = "absolute"

    post_selector: Optional[PostSelectorFn] = None
    action_lookup: Optional[ActionLookup] = None

    # Derived in __post_init__, never passed by callers.
    _bitmask_view: BitmaskView = field(init=False)

    def __post_init__(self):
        view = "local" if self.action_lookup is None else "global"
        # The dataclass is frozen, so bypass its generated __setattr__ guard.
        object.__setattr__(self, "_bitmask_view", view)

bitmask_by_state[sid]: integer bitmask where bit i corresponds to:

- local mode: the local action index i in model.states[sid].actions
- global mode: the global RL action index i from ActionLookup

fallback_by_state[sid]: non-zero only for dangerous states (no action meets the threshold). In that case it equals bitmask_by_state[sid], with the argmax action promoted as a last resort. It is zero for safe, critical, and unsafe states.

nr_actions_by_state[sid]:

- local mode: len(state.actions) for state sid
- global mode: ActionLookup.nr_actions (constant across all states)

best_value_by_state[sid]: the highest choice value across all actions in state sid.

State classification (derivable from bitmask + fallback + nr_actions):

- unsafe: bitmask == 0 and fallback == 0 (property violated)
- dangerous: bitmask > 0 and bitmask == fallback (no action meets threshold)
- critical: bitmask > 0 and fallback == 0 and popcount(bitmask) < nr_actions
- safe: bitmask > 0 and fallback == 0 and popcount(bitmask) == nr_actions

Source code in tempestpy/shielding/masking.py (lines 35–155)
@dataclass(frozen=True, slots=True)
class ShieldResult:
    """
    bitmask_by_state[sid]:
      integer bitmask where bit i corresponds to:
        - local mode:  the local action index i in model.states[sid].actions
        - global mode: the global RL action index i from ActionLookup

    fallback_by_state[sid]:
      non-zero only for dangerous states (no action meets threshold).
      equals bitmask_by_state[sid] in that case, argmax action promoted
      as last resort. zero for safe, critical, and unsafe states.

    nr_actions_by_state[sid]:
      - local mode:  len(state.actions) for state sid
      - global mode: ActionLookup.nr_actions (constant across all states)

    best_value_by_state[sid]:
      the highest choice value across all actions in state sid.

    State classification (derivable from bitmask + fallback + nr_actions):
      unsafe:    bitmask == 0 and fallback == 0   (property violated)
      dangerous: bitmask > 0 and bitmask == fallback  (no action meets threshold)
      critical:  bitmask > 0 and fallback == 0 and popcount(bitmask) < nr_actions
      safe:      bitmask > 0 and fallback == 0 and popcount(bitmask) == nr_actions

    For SMGs, states outside ``coalition_states`` are never reported as unsafe.
    """
    config: ShieldConfig
    bitmask_by_state: list[int]
    fallback_by_state: list[int]
    nr_actions_by_state: list[int]
    best_value_by_state: Optional[np.ndarray] = None
    bit_width: Optional[int] = None
    # Non-None only for SMGs: frozenset of state ids belonging to the ego coalition.
    # States absent from this set are adversary/environment states with no mask.
    coalition_states: Optional[frozenset] = None

    # ------------------------------------------------------------------
    # Hot-path helpers (RL loop)
    # ------------------------------------------------------------------

    def query_mask(self, sid: int) -> int:
        """
        Pre-shielding: return the bitmask for the RL wrapper to apply.

        Uses ``fallback_by_state[sid]`` when the regular bitmask is empty.
        """
        bits = self.bitmask_by_state[sid]
        if bits != 0:
            return bits
        return self.fallback_by_state[sid]

    def query_action(self, sid: int, agent_action: int, pvals: np.ndarray, **kwargs) -> int:
        """
        Post-shielding: return agent_action if safe, else invoke post_selector.

        ``pvals`` and ``kwargs`` are forwarded to ``config.post_selector``.

        Raises:
            TypeError: if the action is blocked and no post_selector is configured.
        """
        if self.bitmask_by_state[sid] & (1 << agent_action):
            return agent_action
        selector = self.config.post_selector
        if selector is None:
            # Fail with an actionable message instead of "'NoneType' object is not callable".
            raise TypeError(
                "ShieldConfig.post_selector is None; post-shielding requires a "
                "post_selector to replace blocked actions."
            )
        return selector(pvals, **kwargs)

    def is_action_allowed(self, sid: int, action_index: int) -> bool:
        """True if bit ``action_index`` is set in the state's bitmask."""
        return bool(self.bitmask_by_state[sid] & (1 << action_index))

    def to_bitmask_dict(self) -> dict[int, int]:
        """Return ``{sid: bitmask}`` for every state."""
        return dict(enumerate(self.bitmask_by_state))

    # ------------------------------------------------------------------
    # State classification helpers (offline / debug)
    # ------------------------------------------------------------------

    def is_coalition_state(self, sid: int) -> bool:
        """True for all states in non-SMG models; for SMGs, True only for ego-coalition states."""
        return self.coalition_states is None or sid in self.coalition_states

    def is_unsafe(self, sid: int) -> bool:
        """Property violated; no action is available at all."""
        if self.coalition_states is not None and sid not in self.coalition_states:
            return False
        return self.bitmask_by_state[sid] == 0 and self.fallback_by_state[sid] == 0

    def is_dangerous(self, sid: int) -> bool:
        """No action meets threshold."""
        return self.fallback_by_state[sid] != 0

    def is_critical(self, sid: int) -> bool:
        """Some actions blocked by threshold."""
        bits = self.bitmask_by_state[sid]
        return (
            bits > 0
            and self.fallback_by_state[sid] == 0
            and bits.bit_count() < self.nr_actions_by_state[sid]
        )

    def is_safe(self, sid: int) -> bool:
        """All actions permitted (and at least one action exists)."""
        bits = self.bitmask_by_state[sid]
        # `bits > 0` matches the documented classification: a state with an empty
        # mask and zero actions is unsafe, not safe.
        return (
            bits > 0
            and self.fallback_by_state[sid] == 0
            and bits.bit_count() == self.nr_actions_by_state[sid]
        )

    def classify_state(self, sid: int) -> str:
        """One of 'safe' | 'critical' | 'dangerous' | 'unsafe'."""
        # NOTE(review): SMG states outside the coalition fall through to "safe"
        # here, although is_safe() is False for them when their mask is empty.
        if self.is_unsafe(sid):    return "unsafe"
        if self.is_dangerous(sid): return "dangerous"
        if self.is_critical(sid):  return "critical"
        return "safe"

    def states_by_class(self) -> dict[str, list[int]]:
        """Return {class_name: [sid, ...]} for all four classes."""
        out = {"safe": [], "critical": [], "dangerous": [], "unsafe": []}
        for sid in range(len(self)):
            out[self.classify_state(sid)].append(sid)
        return out

    def dangerous_states(self) -> list[int]:
        """State ids for which is_dangerous() holds."""
        return [s for s in range(len(self)) if self.is_dangerous(s)]

    def critical_states(self) -> list[int]:
        """State ids for which is_critical() holds."""
        return [s for s in range(len(self)) if self.is_critical(s)]

    def safe_states(self) -> list[int]:
        """State ids for which is_safe() holds."""
        return [s for s in range(len(self)) if self.is_safe(s)]

    # ------------------------------------------------------------------
    # Dunder
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Number of states covered by the result."""
        return len(self.bitmask_by_state)

classify_state(sid)

One of 'safe' | 'critical' | 'dangerous' | 'unsafe'.

Source code in tempestpy/shielding/masking.py (lines 128–133)
def classify_state(self, sid: int) -> str:
    """Return the class of state ``sid``: 'unsafe', 'dangerous', 'critical' or 'safe'.

    Predicates are tried in priority order; the first that holds wins and
    'safe' is the fall-through.
    """
    for label, predicate_name in (
        ("unsafe", "is_unsafe"),
        ("dangerous", "is_dangerous"),
        ("critical", "is_critical"),
    ):
        if getattr(self, predicate_name)(sid):
            return label
    return "safe"

is_coalition_state(sid)

True for all states in non-SMG models; for SMGs, True only for ego-coalition states.

Source code in tempestpy/shielding/masking.py (lines 98–100)
def is_coalition_state(self, sid: int) -> bool:
    """Return True if ``sid`` belongs to the ego coalition.

    Without coalition information (non-SMG models) every state qualifies.
    """
    members = self.coalition_states
    if members is None:
        return True
    return sid in members

is_critical(sid)

Some actions blocked by threshold.

Source code in tempestpy/shielding/masking.py (lines 112–119)
def is_critical(self, sid: int) -> bool:
    """Some actions blocked by threshold."""
    bits = self.bitmask_by_state[sid]
    return (
        bits > 0
        and self.fallback_by_state[sid] == 0
        and bits.bit_count() < self.nr_actions_by_state[sid]
    )

is_dangerous(sid)

No action meets threshold.

Source code in tempestpy/shielding/masking.py (lines 108–110)
def is_dangerous(self, sid: int) -> bool:
    """Return True if no action of ``sid`` meets the threshold (a fallback is set)."""
    fallback = self.fallback_by_state[sid]
    return fallback != 0

is_safe(sid)

All actions permitted.

Source code in tempestpy/shielding/masking.py (lines 121–126)
def is_safe(self, sid: int) -> bool:
    """All actions permitted."""
    return (
        self.fallback_by_state[sid] == 0
        and self.bitmask_by_state[sid].bit_count() == self.nr_actions_by_state[sid]
    )

is_unsafe(sid)

Property violated; no action is available at all.

Source code in tempestpy/shielding/masking.py (lines 102–106)
def is_unsafe(self, sid: int) -> bool:
    """Return True if the property is violated in ``sid``: no action is available.

    States outside the SMG ego coalition are never reported as unsafe.
    """
    coalition = self.coalition_states
    if coalition is not None and sid not in coalition:
        return False
    if self.bitmask_by_state[sid] != 0:
        return False
    return self.fallback_by_state[sid] == 0

query_action(sid, agent_action, pvals, **kwargs)

Post-shielding: return agent_action if safe, else invoke post_selector.

Source code in tempestpy/shielding/masking.py (lines 82–86)
def query_action(self, sid: int, agent_action: int, pvals: np.ndarray, **kwargs) -> int:
    """
    Post-shielding: return agent_action if safe, else invoke post_selector.

    ``pvals`` and ``kwargs`` are forwarded to ``config.post_selector``.

    Raises:
        TypeError: if the action is blocked and no post_selector is configured.
    """
    if self.bitmask_by_state[sid] & (1 << agent_action):
        return agent_action
    selector = self.config.post_selector
    if selector is None:
        # Fail with an actionable message instead of "'NoneType' object is not callable".
        raise TypeError(
            "ShieldConfig.post_selector is None; post-shielding requires a "
            "post_selector to replace blocked actions."
        )
    return selector(pvals, **kwargs)

query_mask(sid)

Pre-shielding: return the bitmask for the RL wrapper to apply.

Source code in tempestpy/shielding/masking.py (lines 75–80)
def query_mask(self, sid: int) -> int:
    """Pre-shielding: bitmask the RL wrapper should apply in state ``sid``.

    An empty regular mask is replaced by the state's fallback mask.
    """
    return self.bitmask_by_state[sid] or self.fallback_by_state[sid]

states_by_class()

Return {class_name: [sid, ...]} for all four classes.

Source code in tempestpy/shielding/masking.py (lines 135–140)
def states_by_class(self) -> dict[str, list[int]]:
    """Group every state id under its classify_state() label.

    All four classes are present as keys, even when empty.
    """
    buckets = {label: [] for label in ("safe", "critical", "dangerous", "unsafe")}
    for state_id in range(len(self)):
        buckets[self.classify_state(state_id)].append(state_id)
    return buckets

Bases: Wrapper

Gymnasium wrapper that adds shield-based action masking.

Requires:

- env.action_space is gym.spaces.Discrete
- obs_to_values maps obs -> {prism_var: value} for every PRISM state variable

After reset() / step(), info contains:

- "action_mask" -> np.ndarray[bool], shape (n_actions,)
- "shield_state_id" -> int
- "shield_bitmask" -> int

action_masks() returns the current mask for SB3 MaskablePPO.

Source code in tempestpy/shielding/wrappers.py (lines 29–103)
class PreShieldWrapper(gym.Wrapper):
    """
    Gymnasium wrapper that adds shield-based action masking.

    Requires:
      - env.action_space is gym.spaces.Discrete
      - obs_to_values maps obs -> {prism_var: value} for every PRISM state variable

    After reset() / step(), info contains:
      "action_mask"     -> np.ndarray[bool], shape (n_actions,)
      "shield_state_id" -> int
      "shield_bitmask"  -> int

    action_masks() returns the current mask for SB3 MaskablePPO.

    rebuild() may be called mid-episode (e.g. from a training callback). The
    cached mask is recomputed for the current state under the new shield, so
    action_masks() and step() keep working without an intervening reset().
    """

    def __init__(
        self,
        env: gym.Env,
        factory: ShieldFactory,
        config: ShieldConfig,
        *,
        obs_to_values: ObsToValuesFn,
    ) -> None:
        super().__init__(env)
        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise TypeError("PreShieldWrapper requires a Discrete action space.")

        self._factory = factory
        self._config = config
        self._obs_to_values = obs_to_values
        self._n_actions = int(env.action_space.n)

        self._result: ShieldResult = factory.build(config)
        # Shield state id of the latest observation; None until the first reset().
        # Kept so rebuild() can refresh the mask without needing a new observation.
        self._last_sid: Optional[int] = None
        self._last_mask: Optional[np.ndarray] = None

    def _mask_for_sid(self, sid: int) -> tuple[np.ndarray, int]:
        """Return (boolean mask, raw bitmask) for shield state ``sid`` under the current shield."""
        bits = self._result.query_mask(sid)
        return _bitmask_to_bool_array(bits, self._n_actions), bits

    def _query(self, obs: Any, info: Optional[dict] = None) -> tuple[int, np.ndarray, int]:
        """Map an observation to (shield state id, boolean mask, raw bitmask)."""
        values = self._obs_to_values(obs, info)
        sid = self._factory.state_lookup.get_state_id(values)
        mask, bits = self._mask_for_sid(sid)
        return sid, mask, bits

    def _augment(self, info: Optional[dict], sid: int, mask: np.ndarray, bits: int) -> dict:
        """Return a copy of ``info`` with the shield keys added (the inner env's dict is not mutated)."""
        out = {} if info is None else dict(info)
        out["action_mask"] = mask
        out["shield_state_id"] = sid
        out["shield_bitmask"] = bits
        return out

    def action_masks(self) -> np.ndarray:
        """SB3 MaskablePPO hook — returns current boolean action mask."""
        if self._last_mask is None:
            raise RuntimeError("action_masks() called before reset().")
        return self._last_mask

    def rebuild(self, config: ShieldConfig) -> None:
        """Recompute the shield with a new config (e.g. different threshold).

        If an observation has already been seen, the cached mask is refreshed
        from the new shield for that state. A rebuild between steps therefore
        does not make action_masks()/step() fail until the next reset().
        """
        # Build first so a failing build leaves config and result consistent.
        result = self._factory.build(config)
        self._config = config
        self._result = result
        if self._last_sid is None:
            self._last_mask = None
        else:
            self._last_mask, _ = self._mask_for_sid(self._last_sid)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        sid, mask, bits = self._query(obs, info)
        self._last_sid = sid
        self._last_mask = mask
        return obs, self._augment(info, sid, mask, bits)

    def step(self, action: int):
        if self._last_mask is None:
            raise RuntimeError("step() called before reset().")
        obs, reward, terminated, truncated, info = self.env.step(action)
        sid, mask, bits = self._query(obs, info)
        self._last_sid = sid
        self._last_mask = mask
        return obs, reward, terminated, truncated, self._augment(info, sid, mask, bits)

action_masks()

SB3 MaskablePPO hook — returns current boolean action mask.

Source code in tempestpy/shielding/wrappers.py
79
80
81
82
83
def action_masks(self) -> np.ndarray:
    """SB3 MaskablePPO hook: the boolean action mask for the current state.

    Raises:
        RuntimeError: if no mask is cached, e.g. before the first reset().
    """
    current = self._last_mask
    if current is not None:
        return current
    raise RuntimeError("action_masks() called before reset().")

rebuild(config)

Recompute the shield with a new config (e.g. different threshold).

Source code in tempestpy/shielding/wrappers.py
85
86
87
88
89
def rebuild(self, config: ShieldConfig) -> None:
    """Recompute the shield with a new config (e.g. different threshold).

    Note: this also discards the cached action mask. Afterwards action_masks()
    and step() raise RuntimeError until the next reset().
    """
    self._config = config
    self._result = self._factory.build(config)
    # Drop the mask computed by the previous shield; it is only refreshed by reset()/step().
    # NOTE(review): calling this mid-episode (e.g. from a training callback) makes the
    # next action_masks()/step() call fail until the environment is reset.
    self._last_mask = None

Bases: Wrapper

Gymnasium wrapper that enforces shield safety via post-shielding.

The agent receives the unmasked observation and may select any action. Before env.step, the wrapper calls query_post(sid, action): if the action is allowed, it passes through unchanged; if it is blocked, the post_selector from ShieldConfig is invoked to supply a replacement.

Requires:
  • env.action_space is gym.spaces.Discrete
  • ShieldConfig.post_selector is set (query_post raises otherwise)
  • obs_to_values maps obs -> {prism_var: value} for every PRISM variable

After reset() / step(), info contains:
  • "shield_state_id" -> int: state id used for the next step
  • "shield_safe_action" -> int: action actually passed to env.step
  • "shield_corrected" -> bool: True when the agent's action was replaced

Source code in tempestpy/shielding/wrappers.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
class PostShieldWrapper(gym.Wrapper):
    """
    Gymnasium wrapper that enforces shield safety via post-shielding.

    The agent receives the unmasked observation and selects any action.
    Before env.step the wrapper calls query_post(sid, action): if the action
    is allowed it passes through unchanged; if blocked the post_selector
    from ShieldConfig is invoked to supply a replacement.

    Requires:
      - env.action_space is gym.spaces.Discrete
      - ShieldConfig.post_selector is set (query_post raises otherwise)
      - obs_to_values maps obs -> {prism_var: value} for every PRISM variable

    After reset() / step(), info contains:
      "shield_state_id"    -> int   state id used for the *next* step
      "shield_safe_action" -> int   action actually passed to env.step
      "shield_corrected"   -> bool  True when the agent's action was replaced

    The returned info is a copy; the wrapped env's own info dict is not mutated.
    """

    def __init__(
        self,
        env: gym.Env,
        factory: ShieldFactory,
        config: ShieldConfig,
        *,
        obs_to_values: ObsToValuesFn,
    ) -> None:
        super().__init__(env)
        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise TypeError("PostShieldWrapper requires a Discrete action space.")
        if config.post_selector is None:
            raise ValueError(
                "PostShieldWrapper requires ShieldConfig.post_selector to be set."
            )

        self._factory = factory
        self._config = config
        self._obs_to_values = obs_to_values

        self._result: ShieldResult = factory.build(config)
        # Shield state id of the latest observation; None until the first reset().
        self._last_sid: Optional[int] = None

    def _obs_to_sid(self, obs: Any, info: Optional[dict] = None) -> int:
        """Map an observation (plus info) to its shield state id."""
        values = self._obs_to_values(obs, info)
        return self._factory.state_lookup.get_state_id(values)

    @staticmethod
    def _copy_info(info: Optional[dict]) -> dict:
        """Return a fresh dict so the shield keys never mutate the inner env's info object."""
        return {} if info is None else dict(info)

    def rebuild(self, config: ShieldConfig) -> None:
        """Recompute the shield with a new config (e.g. different threshold)."""
        if config.post_selector is None:
            raise ValueError("post_selector must be set when rebuilding PostShieldWrapper.")
        # Build first so a failing build leaves config and result consistent.
        result = self._factory.build(config)
        self._config = config
        self._result = result

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self._last_sid = self._obs_to_sid(obs, info)
        out = self._copy_info(info)
        out["shield_state_id"] = self._last_sid
        return obs, out

    def step(self, action: int):
        if self._last_sid is None:
            raise RuntimeError("step() called before reset().")

        safe_action = self._result.query_post(self._last_sid, action)
        obs, reward, terminated, truncated, info = self.env.step(safe_action)

        self._last_sid = self._obs_to_sid(obs, info)
        out = self._copy_info(info)
        out["shield_state_id"] = self._last_sid
        out["shield_safe_action"] = safe_action
        # bool(): with NumPy actions (e.g. from SB3) `!=` yields np.bool_, not the documented bool.
        out["shield_corrected"] = bool(safe_action != action)
        return obs, reward, terminated, truncated, out

rebuild(config)

Recompute the shield with a new config (e.g. different threshold).

Source code in tempestpy/shielding/wrappers.py
244
245
246
247
248
249
def rebuild(self, config: ShieldConfig) -> None:
    """Swap in a new ShieldConfig and rebuild the shield (e.g. different threshold).

    Raises:
        ValueError: if ``config.post_selector`` is None; the current shield is kept.
    """
    if config.post_selector is not None:
        self._config = config
        self._result = self._factory.build(config)
        return
    raise ValueError("post_selector must be set when rebuilding PostShieldWrapper.")