1from __future__ import annotations
2
3from typing import TYPE_CHECKING
4
5from .syntax import And, App, EVar, MLQuant, Top
6
7if TYPE_CHECKING:
8 from collections.abc import Collection, Mapping
9
10 from .syntax import Pattern
11
12
def conjuncts(pattern: Pattern) -> tuple[Pattern, ...]:
    r"""Flatten a pattern into its conjuncts.

    Nested ``\and`` patterns are flattened left to right and every ``\top``
    encountered is dropped, so ``\top`` on its own yields the empty tuple.
    Any other pattern is returned as a single conjunct.

    Args:
        pattern: Pattern to decompose.

    Returns:
        The conjuncts of ``pattern`` in left-to-right order.
    """
    found: list[Pattern] = []
    # Explicit work stack instead of recursion; operands are pushed in reverse
    # so they are popped (and emitted) in their original left-to-right order.
    work: list[Pattern] = [pattern]
    while work:
        current = work.pop()
        if isinstance(current, Top):
            continue
        if isinstance(current, And):
            work.extend(reversed(current.ops))
        else:
            found.append(current)
    return tuple(found)


def free_occs(pattern: Pattern, *, bound_vars: Collection[str] = ()) -> dict[str, list[EVar]]:
    """Collect the free element-variable occurrences of a pattern.

    An ``EVar`` occurrence is free when its name is neither in ``bound_vars``
    nor bound by an enclosing ``MLQuant`` (whose ``var.name`` is bound in its
    body ``pattern``).

    Args:
        pattern: Pattern to scan.
        bound_vars: Names to treat as already bound; their occurrences are skipped.

    Returns:
        A mapping from each free variable name to the list of its ``EVar``
        occurrences, in left-to-right pre-order of discovery.
    """
    found: dict[str, list[EVar]] = {}
    # Each stack entry pairs a subpattern with the names bound at that point.
    todo: list[tuple[Pattern, frozenset[str]]] = [(pattern, frozenset(bound_vars))]
    while todo:
        current, bound = todo.pop()
        if type(current) is EVar and current.name not in bound:
            found.setdefault(current.name, []).append(current)
        elif isinstance(current, MLQuant):
            todo.append((current.pattern, bound | {current.var.name}))
        else:
            # Reverse so the leftmost child is popped first (pre-order).
            todo.extend((child, bound) for child in reversed(current.patterns))
    return found


def collect_symbols(pattern: Pattern) -> set[str]:
    """Return the set of all symbols referred to in a pattern.

    Every subpattern passed to the callback of ``Pattern.collect`` that is an
    application (``App``) contributes its ``symbol``.

    Args:
        pattern: Pattern to collect symbols from.

    Returns:
        The application symbols occurring in ``pattern``.
    """
    symbols: set[str] = set()

    def record(node: Pattern) -> None:
        if isinstance(node, App):
            symbols.add(node.symbol)

    pattern.collect(record)
    return symbols


def substitute_vars(pattern: Pattern, subst_map: Mapping[EVar, Pattern]) -> Pattern:
    """Substitute element variables in a pattern using a bottom-up traversal.

    Each ``EVar`` node handed to the ``Pattern.bottom_up`` callback is replaced
    by its entry in ``subst_map`` when one exists and kept otherwise; all other
    nodes are returned unchanged. There is no binder-specific handling here
    (e.g. for ``MLQuant``), so do not assume this is capture-avoiding.

    Args:
        pattern: The pattern containing variables to be substituted.
        subst_map: A mapping from variables to their replacement patterns.

    Returns:
        The pattern with mapped variables replaced.
    """

    def replace(node: Pattern) -> Pattern:
        if not isinstance(node, EVar):
            return node
        return subst_map.get(node, node)

    return pattern.bottom_up(replace)


[docs]
78def elim_aliases(pattern: Pattern) -> Pattern:
79 r"""Eliminate subpatterns of the form ``\and{S}(p, X : S)``.
80
81 Both the ``\and`` and instances of ``X : S`` are replaced by the definition ``p``.
82 """
83 aliases = {}
84
85 def inline_aliases(pattern: Pattern) -> Pattern:
86 match pattern:
87 case And(_, (p, EVar() as var)):
88 aliases[var] = p
89 return p
90 case _:
91 return pattern
92
93 pattern = pattern.bottom_up(inline_aliases)
94 pattern = substitute_vars(pattern, aliases)
95 return pattern