Source code for pyk.kore.match

  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, overload
  4
  5from ..dequote import bytes_encode
  6from ..utils import case, check_type
  7from .prelude import BOOL, BYTES, ID, INT, STRING
  8from .syntax import DV, App, LeftAssoc
  9
 10if TYPE_CHECKING:
 11    from collections.abc import Callable
 12    from typing import Any, TypeVar
 13
 14    from .syntax import Pattern, Sort
 15
 16    T = TypeVar('T')
 17    K = TypeVar('K')
 18    V = TypeVar('V')
 19
 20
def match_dv(pattern: Pattern, sort: Sort | None = None) -> DV:
    """Return ``pattern`` as a domain value, optionally checking its sort.

    Args:
        pattern: The pattern to match. It is narrowed to ``DV`` via ``check_type``
            (presumably ``check_type`` raises when the type does not match — see ``..utils``).
        sort: If given, the sort the domain value is required to have.

    Returns:
        ``pattern`` as a ``DV``.

    Raises:
        ValueError: If ``sort`` is given and the domain value's sort differs from it.
    """
    dv = check_type(pattern, DV)
    # ``None`` is the "no constraint" sentinel; compare against it explicitly,
    # as ``match_app`` and ``match_left_assoc`` do for ``symbol``.
    if sort is not None and dv.sort != sort:
        raise ValueError(f'Expected sort {sort.text}, found: {dv.sort.text}')
    return dv
26 27
def match_symbol(actual: str, expected: str) -> None:
    """Check that the symbol ``actual`` is exactly ``expected``.

    Args:
        actual: The symbol found in a pattern.
        expected: The symbol that is required.

    Raises:
        ValueError: If the two symbols differ.
    """
    if actual == expected:
        return
    raise ValueError(f'Expected symbol {expected}, found: {actual}')
31 32
def match_app(pattern: Pattern, symbol: str | None = None) -> App:
    """Return ``pattern`` as an application, optionally checking its symbol.

    Args:
        pattern: The pattern to match; narrowed to ``App`` via ``check_type``.
        symbol: If not ``None``, the symbol the application must have.

    Returns:
        ``pattern`` as an ``App``.

    Raises:
        ValueError: If ``symbol`` is given and differs from the application's symbol
            (raised by ``match_symbol``).
    """
    application = check_type(pattern, App)
    if symbol is None:
        return application
    match_symbol(application.symbol, symbol)
    return application
38 39
def match_inj(pattern: Pattern) -> App:
    """Return ``pattern`` as an application of the ``inj`` symbol.

    Raises:
        ValueError: If ``pattern`` is an application of a different symbol.
    """
    return match_app(pattern, symbol='inj')
42 43
def match_left_assoc(pattern: Pattern, symbol: str | None = None) -> LeftAssoc:
    """Return ``pattern`` as a ``LeftAssoc``, optionally checking its symbol.

    Args:
        pattern: The pattern to match; narrowed to ``LeftAssoc`` via ``check_type``.
        symbol: If not ``None``, the symbol the left-associative application must have.

    Returns:
        ``pattern`` as a ``LeftAssoc``.

    Raises:
        ValueError: If ``symbol`` is given and differs from the pattern's symbol
            (raised by ``match_symbol``).
    """
    left_assoc = check_type(pattern, LeftAssoc)
    if symbol is None:
        return left_assoc
    match_symbol(left_assoc.symbol, symbol)
    return left_assoc
49 50
def match_list(pattern: Pattern) -> tuple[Pattern, ...]:
    """Return the elements of a KORE list pattern.

    A pattern whose type is exactly ``App`` must be the empty list ``Lbl'Stop'List``.
    Any other pattern must be a ``LeftAssoc`` of ``Lbl'Unds'List'Unds'`` whose arguments
    are all ``LblListItem`` applications. The first argument of each item is returned.

    Raises:
        ValueError: If a symbol does not match the expected one.
    """
    if type(pattern) is App:
        match_app(pattern, "Lbl'Stop'List")
        return ()

    concat = match_left_assoc(pattern, "Lbl'Unds'List'Unds'")
    return tuple(match_app(element, 'LblListItem').args[0] for element in concat.args)
60 61
def match_set(pattern: Pattern) -> tuple[Pattern, ...]:
    """Return the elements of a KORE set pattern.

    A pattern whose type is exactly ``App`` must be the empty set ``Lbl'Stop'Set``.
    Any other pattern must be a ``LeftAssoc`` of ``Lbl'Unds'Set'Unds'`` whose arguments
    are all ``LblSetItem`` applications. The first argument of each item is returned.

    Raises:
        ValueError: If a symbol does not match the expected one.
    """
    if type(pattern) is App:
        match_app(pattern, "Lbl'Stop'Set")
        return ()

    concat = match_left_assoc(pattern, "Lbl'Unds'Set'Unds'")
    return tuple(match_app(element, 'LblSetItem').args[0] for element in concat.args)
71 72
def match_map(pattern: Pattern, *, cell: str | None = None) -> tuple[tuple[Pattern, Pattern], ...]:
    """Return the ``(key, value)`` entries of a KORE map pattern.

    Args:
        pattern: The map pattern. A pattern whose type is exactly ``App`` must be the
            empty map; anything else must be a ``LeftAssoc`` of the map concatenation symbol.
        cell: Optional cell name spliced into the symbol names. When empty or ``None``,
            the plain ``Map`` symbols are used and items are ``Lbl'UndsPipe'-'-GT-Unds'``;
            otherwise items are ``Lbl{cell}MapItem``.

    Returns:
        The first two arguments of each map item, as pairs.

    Raises:
        ValueError: If a symbol does not match the expected one.
    """
    prefix = cell or ''
    if prefix:
        item_symbol = f'Lbl{prefix}MapItem'
    else:
        item_symbol = "Lbl'UndsPipe'-'-GT-Unds'"

    if type(pattern) is App:
        match_app(pattern, f"Lbl'Stop'{prefix}Map")
        return ()

    concat = match_left_assoc(pattern, f"Lbl'Unds'{prefix}Map'Unds'")
    collected = []
    for element in concat.args:
        entry = match_app(element, item_symbol)
        collected.append((entry.args[0], entry.args[1]))
    return tuple(collected)
87 88
def match_rangemap(pattern: Pattern) -> tuple[tuple[tuple[Pattern, Pattern], Pattern], ...]:
    """Return the ``((range_start, range_end), value)`` entries of a KORE ``RangeMap`` pattern.

    A pattern whose type is exactly ``App`` must be the empty range map
    ``Lbl'Stop'RangeMap``. Any other pattern must be a ``LeftAssoc`` of
    ``Lbl'Unds'RangeMap'Unds'`` whose arguments are all ``Lbl'Unds'r'Pipe'-'-GT-Unds'``
    items. Each item's first argument is decomposed with ``_match_range``; its second
    argument is the value.

    Raises:
        ValueError: If a symbol does not match the expected one.
    """
    stop_symbol = "Lbl'Stop'RangeMap"
    cons_symbol = "Lbl'Unds'RangeMap'Unds'"
    item_symbol = "Lbl'Unds'r'Pipe'-'-GT-Unds'"

    if type(pattern) is App:
        match_app(pattern, stop_symbol)
        return ()

    # Check the concatenation symbol on the ``LeftAssoc`` itself and iterate its
    # arguments, exactly as ``match_list``, ``match_set`` and ``match_map`` do,
    # rather than re-matching ``assoc.app`` as a separate step.
    assoc = match_left_assoc(pattern, cons_symbol)
    items = (match_app(element, item_symbol) for element in assoc.args)
    entries = ((_match_range(item.args[0]), item.args[1]) for item in items)
    return tuple(entries)
def _match_range(pattern: Pattern) -> tuple[Pattern, Pattern]:
    """Return the first two arguments of a ``LblRangeMap'Coln'Range`` application.

    Raises:
        ValueError: If ``pattern`` is an application of a different symbol.
    """
    range_symbol = "LblRangeMap'Coln'Range"
    # Named ``range_app`` (not ``range``) to avoid shadowing the builtin.
    range_app = match_app(pattern, range_symbol)
    return (range_app.args[0], range_app.args[1])
def kore_bool(pattern: Pattern) -> bool:
    """Decode a ``BOOL`` domain value into a Python ``bool``.

    Raises:
        ValueError: If the sort is not ``BOOL`` (via ``match_dv``), or if the literal
            is neither ``'true'`` nor ``'false'``.
    """
    dv = match_dv(pattern, BOOL)
    literal = dv.value.value
    if literal == 'true':
        return True
    if literal == 'false':
        return False
    raise ValueError(f'Invalid Boolean domain value: {dv.text}')
120 121
def kore_int(pattern: Pattern) -> int:
    """Decode an ``INT`` domain value into a Python ``int`` using ``int()`` on its literal.

    Raises:
        ValueError: If the sort is not ``INT`` (via ``match_dv``), or if ``int()``
            rejects the literal.
    """
    literal = match_dv(pattern, INT).value.value
    return int(literal)
125 126
def kore_bytes(pattern: Pattern) -> bytes:
    """Decode a ``BYTES`` domain value into ``bytes`` by passing its literal to ``bytes_encode``.

    Raises:
        ValueError: If the sort is not ``BYTES`` (via ``match_dv``).
    """
    literal = match_dv(pattern, BYTES).value.value
    return bytes_encode(literal)
130 131
def kore_str(pattern: Pattern) -> str:
    """Return the literal of a ``STRING`` domain value.

    Raises:
        ValueError: If the sort is not ``STRING`` (via ``match_dv``).
    """
    string_dv = match_dv(pattern, STRING)
    return string_dv.value.value
135 136
def kore_id(pattern: Pattern) -> str:
    """Return the literal of an ``ID`` domain value.

    Raises:
        ValueError: If the sort is not ``ID`` (via ``match_dv``).
    """
    id_dv = match_dv(pattern, ID)
    return id_dv.value.value
140 141 142# Higher-order functions 143 144
def app(symbol: str | None = None) -> Callable[[Pattern], App]:
    """Return a matcher that calls ``match_app`` with the given ``symbol``.

    Args:
        symbol: Passed through to ``match_app``; ``None`` accepts any application symbol.

    Returns:
        A function from a pattern to that pattern narrowed to ``App``.
    """

    def matcher(pattern: Pattern) -> App:
        return match_app(pattern, symbol)

    return matcher
@overload
def arg(n: int, /) -> Callable[[App], Pattern]: ...


@overload
def arg(symbol: str, /) -> Callable[[App], App]: ...


def arg(id: int | str) -> Callable[[App], Pattern | App]:
    """Return a function that selects one argument of an application.

    Args:
        id: If an ``int``, the position of the argument to select. Negative values
            index from the end, as for Python sequences. If a ``str``, the first
            argument that is an ``App`` with that symbol is selected.

    Returns:
        A function from an ``App`` to the selected argument.

    Raises:
        ValueError: Raised by the returned function if the index is out of range
            (in either direction), or if no argument has the requested symbol.
    """

    def res(application: App) -> Pattern | App:
        if type(id) is int:
            # Index directly so negative indices keep working, but report every
            # out-of-range index (including too-negative ones) as ValueError.
            try:
                return application.args[id]
            except IndexError:
                raise ValueError('Argument index is out of range') from None

        # ``next`` stops at the first matching argument instead of scanning them all.
        found = next(
            (candidate for candidate in application.args if type(candidate) is App and candidate.symbol == id),
            None,
        )
        if found is None:
            raise ValueError(f'No matching argument found for symbol: {id}')
        return found

    return res
@overload
def args() -> Callable[[App], tuple[()]]: ...


@overload
def args(n1: int, /) -> Callable[[App], tuple[Pattern]]: ...


@overload
def args(n1: int, n2: int, /) -> Callable[[App], tuple[Pattern, Pattern]]: ...


@overload
def args(n1: int, n2: int, n3: int, /) -> Callable[[App], tuple[Pattern, Pattern, Pattern]]: ...


@overload
def args(n1: int, n2: int, n3: int, n4: int, /) -> Callable[[App], tuple[Pattern, Pattern, Pattern, Pattern]]: ...


@overload
def args(*ns: int) -> Callable[[App], tuple[Pattern, ...]]: ...


@overload
def args(s1: str, /) -> Callable[[App], tuple[App]]: ...


@overload
def args(s1: str, s2: str, /) -> Callable[[App], tuple[App, App]]: ...


@overload
def args(s1: str, s2: str, s3: str, /) -> Callable[[App], tuple[App, App, App]]: ...


@overload
def args(s1: str, s2: str, s3: str, s4: str, /) -> Callable[[App], tuple[App, App, App, App]]: ...


@overload
def args(*ss: str) -> Callable[[App], tuple[App, ...]]: ...


def args(*ids: Any) -> Callable[[App], tuple]:
    """Return a function that selects several arguments of an application at once.

    Args:
        ids: Either all argument positions (if the first one is an ``int``, each is
            resolved with ``arg``), or symbols. For each symbol, the first argument
            that is an ``App`` with that symbol is selected. With no ``ids``, the
            returned function yields ``()``.

    Returns:
        A function from an ``App`` to a tuple of the selected arguments, in ``ids`` order.

    Raises:
        ValueError: Raised by the returned function if an index is out of range
            (via ``arg``) or if some requested symbols have no matching argument.
    """

    def res(app: App) -> tuple[Pattern, ...]:
        if not ids:
            return ()

        fst = ids[0]
        if type(fst) is int:
            return tuple(arg(n)(app) for n in ids)

        symbol_match: dict[str, App] = {}
        symbols = set(ids)

        for candidate in app.args:
            if type(candidate) is App and candidate.symbol in symbols and candidate.symbol not in symbol_match:
                symbol_match[candidate.symbol] = candidate
                if len(symbol_match) == len(symbols):
                    break  # every requested symbol is matched; later arguments cannot change the result

        if len(symbol_match) == len(symbols):
            return tuple(symbol_match[symbol] for symbol in ids)

        # List unmatched symbols in the order they were requested (deduplicated), so the
        # message does not depend on set iteration order (string hashes are randomized).
        unmatched_symbols = [symbol for symbol in dict.fromkeys(ids) if symbol not in symbol_match]
        assert unmatched_symbols
        unmatched_symbol_str = ', '.join(unmatched_symbols)
        raise ValueError(f'No matching arguments found for symbols: {unmatched_symbol_str}')

    return res
246 247
def inj(pattern: Pattern) -> Pattern:
    """Return the first argument of an ``inj`` application.

    Raises:
        ValueError: If ``pattern`` is an application of another symbol, or has no arguments.
    """
    injection = app('inj')(pattern)
    return arg(0)(injection)
250 251
def kore_list_of(item: Callable[[Pattern], T]) -> Callable[[Pattern], tuple[T, ...]]:
    """Lift an element decoder to a decoder for KORE lists.

    Args:
        item: Applied to each element returned by ``match_list``.

    Returns:
        A function from a list pattern to the tuple of decoded elements, in order.
    """

    def decode(pattern: Pattern) -> tuple[T, ...]:
        return tuple(map(item, match_list(pattern)))

    return decode
257 258
def kore_set_of(item: Callable[[Pattern], T]) -> Callable[[Pattern], tuple[T, ...]]:
    """Lift an element decoder to a decoder for KORE sets.

    Args:
        item: Applied to each element returned by ``match_set``.

    Returns:
        A function from a set pattern to the tuple of decoded elements, in pattern order.
    """

    def decode(pattern: Pattern) -> tuple[T, ...]:
        return tuple(map(item, match_set(pattern)))

    return decode
264 265
def kore_map_of(
    key: Callable[[Pattern], K],
    value: Callable[[Pattern], V],
    *,
    cell: str | None = None,
) -> Callable[[Pattern], tuple[tuple[K, V], ...]]:
    """Lift key and value decoders to a decoder for KORE maps.

    Args:
        key: Applied to each entry's key pattern.
        value: Applied to each entry's value pattern.
        cell: Forwarded to ``match_map`` to select cell-specific map symbols.

    Returns:
        A function from a map pattern to a tuple of decoded ``(key, value)`` pairs.
    """

    def decode(pattern: Pattern) -> tuple[tuple[K, V], ...]:
        entries = match_map(pattern, cell=cell)
        return tuple((key(raw_key), value(raw_value)) for raw_key, raw_value in entries)

    return decode
276 277
def kore_rangemap_of(
    key: Callable[[Pattern], K],
    value: Callable[[Pattern], V],
) -> Callable[[Pattern], tuple[tuple[tuple[K, K], V], ...]]:
    """Lift key and value decoders to a decoder for KORE range maps.

    Args:
        key: Applied to both components of each entry's range.
        value: Applied to each entry's value pattern.

    Returns:
        A function from a range-map pattern to a tuple of
        ``((decoded_start, decoded_end), decoded_value)`` entries.
    """

    def decode(pattern: Pattern) -> tuple[tuple[tuple[K, K], V], ...]:
        decoded = []
        for bounds, raw_value in match_rangemap(pattern):
            decoded_range = (key(bounds[0]), key(bounds[1]))
            decoded.append((decoded_range, value(raw_value)))
        return tuple(decoded)

    return decode
286 287
def case_symbol(
    *cases: tuple[str, Callable[[App], T]],
    default: Callable[[App], T] | None = None,
) -> Callable[[Pattern], T]:
    """Build a dispatcher that picks a handler by the application symbol of a pattern.

    Args:
        cases: ``(symbol, handler)`` pairs; each handler receives the matched ``App``.
        default: Handler passed through to ``pyk.utils.case`` for when no symbol matches.

    Returns:
        A function that narrows its argument with ``match_app`` and delegates to ``case``.
        NOTE(review): presumably ``case`` tries the pairs in order and handles a ``None``
        default by raising — confirm against ``..utils.case``.
    """

    def has_symbol(symbol: str) -> Callable[[App], bool]:
        # Factory so each predicate captures its own ``symbol`` (no late binding).
        def check(application: App) -> bool:
            return application.symbol == symbol

        return check

    def res(pattern: Pattern) -> T:
        application = match_app(pattern)
        dispatch = case(
            cases=((has_symbol(symbol), then) for symbol, then in cases),
            default=default,
        )
        return dispatch(application)

    return res