1from __future__ import annotations
2
3from typing import TYPE_CHECKING, overload
4
5from ..dequote import bytes_encode
6from ..utils import case, check_type
7from .prelude import BOOL, BYTES, ID, INT, STRING
8from .syntax import DV, App, LeftAssoc
9
10if TYPE_CHECKING:
11 from collections.abc import Callable
12 from typing import Any, TypeVar
13
14 from .syntax import Pattern, Sort
15
16 T = TypeVar('T')
17 K = TypeVar('K')
18 V = TypeVar('V')
19
20
[docs]
def match_dv(pattern: Pattern, sort: Sort | None = None) -> DV:
    """Return ``pattern`` as a domain value, optionally checking its sort.

    Args:
        pattern: The pattern to match; it is passed through ``check_type``
            against ``DV`` (presumably that raises for non-``DV`` patterns —
            confirm in ``..utils``).
        sort: If not ``None``, the sort the domain value is required to have.

    Returns:
        The pattern, typed as ``DV``.

    Raises:
        ValueError: If ``sort`` is given and the domain value's sort differs.
    """
    dv = check_type(pattern, DV)
    # Test against None explicitly (as match_app / match_left_assoc do for
    # their optional symbol) instead of relying on the truthiness of Sort.
    if sort is not None and dv.sort != sort:
        raise ValueError(f'Expected sort {sort.text}, found: {dv.sort.text}')
    return dv
26
27
[docs]
def match_symbol(actual: str, expected: str) -> None:
    """Check that a symbol name equals the expected one.

    Args:
        actual: The symbol found on the pattern.
        expected: The symbol that is required.

    Raises:
        ValueError: If ``actual`` differs from ``expected``.
    """
    if actual == expected:
        return
    raise ValueError(f'Expected symbol {expected}, found: {actual}')
31
32
[docs]
def match_app(pattern: Pattern, symbol: str | None = None) -> App:
    """Return ``pattern`` as an application, optionally checking its symbol.

    Args:
        pattern: The pattern to match; passed through ``check_type`` against ``App``.
        symbol: If not ``None``, the symbol the application must have.

    Returns:
        The pattern, typed as ``App``.

    Raises:
        ValueError: If ``symbol`` is given and differs from the application's symbol.
    """
    application = check_type(pattern, App)
    if symbol is None:
        return application
    match_symbol(application.symbol, symbol)
    return application
38
39
[docs]
def match_inj(pattern: Pattern) -> App:
    """Return ``pattern`` as an application whose symbol must be ``inj``."""
    return match_app(pattern, symbol='inj')
42
43
[docs]
def match_left_assoc(pattern: Pattern, symbol: str | None = None) -> LeftAssoc:
    """Return ``pattern`` as a ``LeftAssoc``, optionally checking its symbol.

    Args:
        pattern: The pattern to match; passed through ``check_type`` against ``LeftAssoc``.
        symbol: If not ``None``, the symbol the left-associative node must have.

    Raises:
        ValueError: If ``symbol`` is given and differs from the node's symbol.
    """
    left_assoc = check_type(pattern, LeftAssoc)
    if symbol is None:
        return left_assoc
    match_symbol(left_assoc.symbol, symbol)
    return left_assoc
49
50
[docs]
def match_list(pattern: Pattern) -> tuple[Pattern, ...]:
    """Return the element patterns of a Kore list pattern.

    A plain ``App`` must be the ``Lbl'Stop'List`` symbol and yields ``()``.
    Otherwise the pattern must be a ``LeftAssoc`` of ``Lbl'Unds'List'Unds'``
    whose operands are each ``LblListItem`` applications; the first argument
    of every item is returned, in order.

    Raises:
        ValueError: If any of the symbols above does not match.
    """
    if type(pattern) is App:
        match_app(pattern, "Lbl'Stop'List")
        return ()

    concat = match_left_assoc(pattern, "Lbl'Unds'List'Unds'")
    return tuple(match_app(operand, 'LblListItem').args[0] for operand in concat.args)
60
61
[docs]
def match_set(pattern: Pattern) -> tuple[Pattern, ...]:
    """Return the element patterns of a Kore set pattern.

    A plain ``App`` must be the ``Lbl'Stop'Set`` symbol and yields ``()``.
    Otherwise the pattern must be a ``LeftAssoc`` of ``Lbl'Unds'Set'Unds'``
    whose operands are each ``LblSetItem`` applications; the first argument
    of every item is returned, in operand order.

    Raises:
        ValueError: If any of the symbols above does not match.
    """
    if type(pattern) is App:
        match_app(pattern, "Lbl'Stop'Set")
        return ()

    concat = match_left_assoc(pattern, "Lbl'Unds'Set'Unds'")
    return tuple(match_app(operand, 'LblSetItem').args[0] for operand in concat.args)
71
72
[docs]
def match_map(pattern: Pattern, *, cell: str | None = None) -> tuple[tuple[Pattern, Pattern], ...]:
    """Return the ``(key, value)`` pattern pairs of a Kore map pattern.

    Args:
        pattern: The map pattern.
        cell: Optional name inserted into the map symbols
            (``Lbl'Stop'<cell>Map``, ``Lbl'Unds'<cell>Map'Unds'``). When given
            and non-empty, items are ``Lbl<cell>MapItem``; otherwise
            ``Lbl'UndsPipe'-'-GT-Unds'``.

    Returns:
        ``()`` for the ``Stop`` application; otherwise the first two
        arguments of every item, in operand order.

    Raises:
        ValueError: If any expected symbol does not match.
    """
    prefix = cell or ''
    if prefix:
        item_symbol = f'Lbl{prefix}MapItem'
    else:
        item_symbol = "Lbl'UndsPipe'-'-GT-Unds'"

    if type(pattern) is App:
        match_app(pattern, f"Lbl'Stop'{prefix}Map")
        return ()

    concat = match_left_assoc(pattern, f"Lbl'Unds'{prefix}Map'Unds'")
    pairs = []
    for operand in concat.args:
        entry = match_app(operand, item_symbol)
        pairs.append((entry.args[0], entry.args[1]))
    return tuple(pairs)
87
88
[docs]
def match_rangemap(pattern: Pattern) -> tuple[tuple[tuple[Pattern, Pattern], Pattern], ...]:
    """Return the ``((range_arg0, range_arg1), value)`` entries of a RangeMap pattern.

    A plain ``App`` must be ``Lbl'Stop'RangeMap`` and yields ``()``. Otherwise
    the pattern must be a ``LeftAssoc`` whose underlying application has
    symbol ``Lbl'Unds'RangeMap'Unds'``; each of its arguments must be an
    ``Lbl'Unds'r'Pipe'-'-GT-Unds'`` item whose first argument is split by
    ``_match_range`` and whose second argument is the value.

    Raises:
        ValueError: If any expected symbol does not match.
    """
    if type(pattern) is App:
        match_app(pattern, "Lbl'Stop'RangeMap")
        return ()

    # NOTE(review): unlike match_list/match_set/match_map, the cons symbol is
    # checked on `assoc.app` rather than via match_left_assoc(pattern, symbol);
    # presumably equivalent — confirm before unifying.
    assoc = match_left_assoc(pattern)
    cons = match_app(assoc.app, "Lbl'Unds'RangeMap'Unds'")
    entries = []
    for operand in cons.args:
        entry = match_app(operand, "Lbl'Unds'r'Pipe'-'-GT-Unds'")
        entries.append((_match_range(entry.args[0]), entry.args[1]))
    return tuple(entries)
103
104
def _match_range(pattern: Pattern) -> tuple[Pattern, Pattern]:
    """Return the first two arguments of a ``LblRangeMap'Coln'Range`` application.

    Helper for ``match_rangemap``.

    Raises:
        ValueError: If ``pattern`` is an application with a different symbol.
    """
    range_symbol = "LblRangeMap'Coln'Range"
    # Named `range_app` (not `range`) so the builtin `range` is not shadowed.
    range_app = match_app(pattern, range_symbol)
    return (range_app.args[0], range_app.args[1])
109
110
[docs]
def kore_bool(pattern: Pattern) -> bool:
    """Convert a ``BOOL`` domain value to a Python ``bool``.

    Raises:
        ValueError: If the sort is not ``BOOL`` or the value is neither
            ``'true'`` nor ``'false'``.
    """
    dv = match_dv(pattern, BOOL)
    literal = dv.value.value
    if literal == 'true':
        return True
    if literal == 'false':
        return False
    raise ValueError(f'Invalid Boolean domain value: {dv.text}')
120
121
[docs]
def kore_int(pattern: Pattern) -> int:
    """Convert an ``INT`` domain value to a Python ``int`` via ``int()``."""
    return int(match_dv(pattern, INT).value.value)
125
126
[docs]
def kore_bytes(pattern: Pattern) -> bytes:
    """Convert a ``BYTES`` domain value to ``bytes`` using ``bytes_encode``."""
    return bytes_encode(match_dv(pattern, BYTES).value.value)
130
131
[docs]
def kore_str(pattern: Pattern) -> str:
    """Return the raw value of a ``STRING`` domain value."""
    string_value = match_dv(pattern, STRING).value
    return string_value.value
135
136
[docs]
def kore_id(pattern: Pattern) -> str:
    """Return the raw value of an ``ID`` domain value."""
    id_value = match_dv(pattern, ID).value
    return id_value.value
140
141
142# Higher-order functions
143
144
[docs]
def app(symbol: str | None = None) -> Callable[[Pattern], App]:
    """Return a matcher that applies ``match_app`` with a fixed ``symbol``.

    Args:
        symbol: Symbol the matched application must have, or ``None`` to
            accept any application.
    """

    def matcher(pattern: Pattern) -> App:
        return match_app(pattern, symbol)

    return matcher
150
151
# Typing overloads for `arg`: an int id selects an argument by position
# (typed as any Pattern); a str id selects the first argument that is an App
# with that symbol (typed as App). Both ids are positional-only here.
@overload
def arg(n: int, /) -> Callable[[App], Pattern]: ...


@overload
def arg(symbol: str, /) -> Callable[[App], App]: ...
158
159
[docs]
def arg(id: int | str) -> Callable[[App], Pattern | App]:
    """Return a function that selects one argument of an application.

    Args:
        id: Either an ``int`` position into ``app.args`` (negative positions
            count from the end, as with sequence indexing) or a ``str``
            symbol, in which case the first argument that is an ``App`` with
            that symbol is selected.

    Returns:
        A callable that takes an ``App`` and returns the selected argument.

    Raises:
        ValueError: Raised by the returned callable when the position is out
            of range (in either direction) or no argument has the symbol.
    """
    # `id` shadows the builtin but is kept: renaming it would break callers
    # that pass it by keyword.

    def res(app: App) -> Pattern | App:
        if type(id) is int:
            n_args = len(app.args)
            # Check both bounds so a too-negative index reports the same
            # ValueError as a too-large one instead of leaking an IndexError.
            if not -n_args <= id < n_args:
                raise ValueError('Argument index is out of range')
            return app.args[id]

        # Stop at the first matching argument instead of collecting them all.
        found = next((a for a in app.args if type(a) is App and a.symbol == id), None)
        if found is None:
            raise ValueError(f'No matching argument found for symbol: {id}')
        return found

    return res
175
176
# Typing overloads for `args`. With no ids the returned callable yields an
# empty tuple; with int ids each id selects an argument by position (typed as
# any Pattern); with str ids each id selects the first argument that is an
# App with that symbol (typed as App). Up to four ids get a fixed-length tuple
# type; more ids fall back to a variadic tuple. No overload mixes ints and strs.
@overload
def args() -> Callable[[App], tuple[()]]: ...


@overload
def args(n1: int, /) -> Callable[[App], tuple[Pattern]]: ...


@overload
def args(n1: int, n2: int, /) -> Callable[[App], tuple[Pattern, Pattern]]: ...


@overload
def args(n1: int, n2: int, n3: int, /) -> Callable[[App], tuple[Pattern, Pattern, Pattern]]: ...


@overload
def args(n1: int, n2: int, n3: int, n4: int, /) -> Callable[[App], tuple[Pattern, Pattern, Pattern, Pattern]]: ...


@overload
def args(*ns: int) -> Callable[[App], tuple[Pattern, ...]]: ...


@overload
def args(s1: str, /) -> Callable[[App], tuple[App]]: ...


@overload
def args(s1: str, s2: str, /) -> Callable[[App], tuple[App, App]]: ...


@overload
def args(s1: str, s2: str, s3: str, /) -> Callable[[App], tuple[App, App, App]]: ...


@overload
def args(s1: str, s2: str, s3: str, s4: str, /) -> Callable[[App], tuple[App, App, App, App]]: ...


@overload
def args(*ss: str) -> Callable[[App], tuple[App, ...]]: ...
219
220
[docs]
def args(*ids: Any) -> Callable[[App], tuple]:
    """Return a function that selects several arguments of an application.

    Args:
        ids: Either ``int`` positions (each resolved like ``arg(n)``) or
            ``str`` symbols (each resolved to the first argument that is an
            ``App`` with that symbol). The kind is decided by the first id.

    Returns:
        A callable that takes an ``App`` and returns the selected arguments
        as a tuple in the order of ``ids``; ``()`` when no ids are given.

    Raises:
        ValueError: Raised by the returned callable when a position is out
            of range, or when some symbols have no matching argument. In the
            latter case, the message lists the unmatched symbols in the order
            they were requested, without duplicates.
    """

    def res(app: App) -> tuple[Pattern, ...]:
        if not ids:
            return ()

        if type(ids[0]) is int:
            return tuple(arg(n)(app) for n in ids)

        symbols = set(ids)
        symbol_match: dict[str, App] = {}
        for _arg in app.args:
            if type(_arg) is App and _arg.symbol in symbols and _arg.symbol not in symbol_match:
                symbol_match[_arg.symbol] = _arg
                if len(symbol_match) == len(symbols):
                    break  # every requested symbol found; no need to scan further

        # dict.fromkeys keeps request order and drops duplicates, so the error
        # message is deterministic (iterating a set of str is not across runs).
        unmatched = [symbol for symbol in dict.fromkeys(ids) if symbol not in symbol_match]
        if not unmatched:
            return tuple(symbol_match[symbol] for symbol in ids)

        # str() guards against non-str ids mixed in after a str first id.
        unmatched_symbol_str = ', '.join(map(str, unmatched))
        raise ValueError(f'No matching arguments found for symbols: {unmatched_symbol_str}')

    return res
246
247
[docs]
def inj(pattern: Pattern) -> Pattern:
    """Unwrap an ``inj`` application, returning its first argument.

    Raises:
        ValueError: If the application's symbol is not ``inj`` or it has no arguments.
    """
    injection = app('inj')(pattern)
    return arg(0)(injection)
250
251
[docs]
def kore_list_of(item: Callable[[Pattern], T]) -> Callable[[Pattern], tuple[T, ...]]:
    """Lift an element converter to a converter for Kore list patterns.

    The returned callable extracts the elements with ``match_list`` and
    applies ``item`` to each, preserving order.
    """

    def convert_list(pattern: Pattern) -> tuple[T, ...]:
        elements = match_list(pattern)
        return tuple(map(item, elements))

    return convert_list
257
258
[docs]
def kore_set_of(item: Callable[[Pattern], T]) -> Callable[[Pattern], tuple[T, ...]]:
    """Lift an element converter to a converter for Kore set patterns.

    The returned callable extracts the elements with ``match_set`` and
    applies ``item`` to each, keeping the order ``match_set`` produced.
    """

    def convert_set(pattern: Pattern) -> tuple[T, ...]:
        elements = match_set(pattern)
        return tuple(map(item, elements))

    return convert_set
264
265
[docs]
def kore_map_of(
    key: Callable[[Pattern], K],
    value: Callable[[Pattern], V],
    *,
    cell: str | None = None,
) -> Callable[[Pattern], tuple[tuple[K, V], ...]]:
    """Lift key/value converters to a converter for Kore map patterns.

    Args:
        key: Converter applied to each entry's key pattern.
        value: Converter applied to each entry's value pattern.
        cell: Forwarded to ``match_map`` to select cell-specific map symbols.
    """

    def convert_map(pattern: Pattern) -> tuple[tuple[K, V], ...]:
        converted = []
        for raw_key, raw_value in match_map(pattern, cell=cell):
            converted.append((key(raw_key), value(raw_value)))
        return tuple(converted)

    return convert_map
276
277
[docs]
def kore_rangemap_of(
    key: Callable[[Pattern], K],
    value: Callable[[Pattern], V],
) -> Callable[[Pattern], tuple[tuple[tuple[K, K], V], ...]]:
    """Lift key/value converters to a converter for Kore RangeMap patterns.

    ``key`` is applied to both components of each entry's range pair, and
    ``value`` to the entry's value, in the order ``match_rangemap`` yields.
    """

    def convert_rangemap(pattern: Pattern) -> tuple[tuple[tuple[K, K], V], ...]:
        result = []
        for (left, right), raw_value in match_rangemap(pattern):
            result.append(((key(left), key(right)), value(raw_value)))
        return tuple(result)

    return convert_rangemap
286
287
[docs]
def case_symbol(
    *cases: tuple[str, Callable[[App], T]],
    default: Callable[[App], T] | None = None,
) -> Callable[[Pattern], T]:
    """Build a dispatcher that chooses a handler by application symbol.

    Args:
        cases: ``(symbol, handler)`` pairs; each handler receives the
            matched ``App``.
        default: Handler passed to ``case`` for applications whose symbol
            matches none of ``cases``.

    Returns:
        A callable that matches its argument with ``match_app`` and then
        delegates to ``case`` with one symbol-equality condition per pair.
    """

    def has_symbol(symbol: str) -> Callable[[App], bool]:
        def check(application: App) -> bool:
            return application.symbol == symbol

        return check

    def res(pattern: Pattern) -> T:
        application = match_app(pattern)
        branches = ((has_symbol(symbol), then) for symbol, then in cases)
        dispatch = case(cases=branches, default=default)
        return dispatch(application)

    return res
302 return res