Source code for pyk.kast.pretty

  1from __future__ import annotations
  2
  3import logging
  4from collections.abc import Callable
  5from functools import cached_property
  6from typing import TYPE_CHECKING
  7
  8from ..prelude.kbool import TRUE
  9from .att import Atts, KAtt
 10from .inner import KApply, KAs, KInner, KLabel, KRewrite, KSequence, KSort, KToken, KVariable
 11from .manip import flatten_label, sort_ac_collections, undo_aliases
 12from .outer import (
 13    KBubble,
 14    KClaim,
 15    KContext,
 16    KDefinition,
 17    KFlatModule,
 18    KImport,
 19    KNonTerminal,
 20    KOuter,
 21    KProduction,
 22    KRegexTerminal,
 23    KRequire,
 24    KRule,
 25    KRuleLike,
 26    KSortSynonym,
 27    KSyntaxAssociativity,
 28    KSyntaxLexical,
 29    KSyntaxPriority,
 30    KSyntaxSort,
 31    KTerminal,
 32)
 33
 34if TYPE_CHECKING:
 35    from collections.abc import Iterable
 36    from typing import Any, Final, TypeVar
 37
 38    from .kast import KAst
 39
 40    RL = TypeVar('RL', bound='KRuleLike')
 41
 42_LOGGER: Final = logging.getLogger(__name__)
 43
 44SymbolTable = dict[str, Callable[..., str]]
 45
 46
[docs] 47class PrettyPrinter: 48 definition: KDefinition 49 _extra_unparsing_modules: Iterable[KFlatModule] 50 _patch_symbol_table: Callable[[SymbolTable], None] | None 51 _unalias: bool 52 _sort_collections: bool 53 54 def __init__( 55 self, 56 definition: KDefinition, 57 extra_unparsing_modules: Iterable[KFlatModule] = (), 58 patch_symbol_table: Callable[[SymbolTable], None] | None = None, 59 unalias: bool = True, 60 sort_collections: bool = False, 61 ): 62 self.definition = definition 63 self._extra_unparsing_modules = extra_unparsing_modules 64 self._patch_symbol_table = patch_symbol_table 65 self._unalias = unalias 66 self._sort_collections = sort_collections 67 68 @cached_property 69 def symbol_table(self) -> SymbolTable: 70 symb_table = build_symbol_table( 71 self.definition, 72 extra_modules=self._extra_unparsing_modules, 73 opinionated=True, 74 ) 75 if self._patch_symbol_table is not None: 76 self._patch_symbol_table(symb_table) 77 return symb_table 78
[docs] 79 def print(self, kast: KAst) -> str: 80 """Print out KAST terms/outer syntax. 81 82 Args: 83 kast: KAST term to print. 84 85 Returns: 86 Best-effort string representation of KAST term. 87 """ 88 _LOGGER.debug(f'Unparsing: {kast}') 89 if type(kast) is KAtt: 90 return self._print_katt(kast) 91 if type(kast) is KSort: 92 return self._print_ksort(kast) 93 if type(kast) is KLabel: 94 return self._print_klabel(kast) 95 elif isinstance(kast, KOuter): 96 return self._print_kouter(kast) 97 elif isinstance(kast, KInner): 98 if self._unalias: 99 kast = undo_aliases(self.definition, kast) 100 if self._sort_collections: 101 kast = sort_ac_collections(kast) 102 return self._print_kinner(kast) 103 raise AssertionError(f'Error unparsing: {kast}')
104 105 def _print_kouter(self, kast: KOuter) -> str: 106 match kast: 107 case KTerminal(): 108 return self._print_kterminal(kast) 109 case KRegexTerminal(): 110 return self._print_kregexterminal(kast) 111 case KNonTerminal(): 112 return self._print_knonterminal(kast) 113 case KProduction(): 114 return self._print_kproduction(kast) 115 case KSyntaxSort(): 116 return self._print_ksyntaxsort(kast) 117 case KSortSynonym(): 118 return self._print_ksortsynonym(kast) 119 case KSyntaxLexical(): 120 return self._print_ksyntaxlexical(kast) 121 case KSyntaxAssociativity(): 122 return self._print_ksyntaxassociativity(kast) 123 case KSyntaxPriority(): 124 return self._print_ksyntaxpriority(kast) 125 case KBubble(): 126 return self._print_kbubble(kast) 127 case KRule(): 128 return self._print_krule(kast) 129 case KClaim(): 130 return self._print_kclaim(kast) 131 case KContext(): 132 return self._print_kcontext(kast) 133 case KImport(): 134 return self._print_kimport(kast) 135 case KFlatModule(): 136 return self._print_kflatmodule(kast) 137 case KRequire(): 138 return self._print_krequire(kast) 139 case KDefinition(): 140 return self._print_kdefinition(kast) 141 case _: 142 raise AssertionError(f'Error unparsing: {kast}') 143 144 def _print_kinner(self, kast: KInner) -> str: 145 match kast: 146 case KVariable(): 147 return self._print_kvariable(kast) 148 case KToken(): 149 return self._print_ktoken(kast) 150 case KApply(): 151 return self._print_kapply(kast) 152 case KAs(): 153 return self._print_kas(kast) 154 case KRewrite(): 155 return self._print_krewrite(kast) 156 case KSequence(): 157 return self._print_ksequence(kast) 158 case _: 159 raise AssertionError(f'Error unparsing: {kast}') 160 161 def _print_ksort(self, ksort: KSort) -> str: 162 return ksort.name 163 164 def _print_klabel(self, klabel: KLabel) -> str: 165 return klabel.name 166 167 def _print_kvariable(self, kvariable: KVariable) -> str: 168 sort = kvariable.sort 169 if not sort: 170 return kvariable.name 171 
return kvariable.name + ':' + sort.name 172 173 def _print_ktoken(self, ktoken: KToken) -> str: 174 return ktoken.token 175 176 def _print_kapply(self, kapply: KApply) -> str: 177 label = kapply.label.name 178 args = kapply.args 179 unparsed_args = [self._print_kinner(arg) for arg in args] 180 if kapply.is_cell: 181 cell_contents = '\n'.join(unparsed_args).rstrip() 182 cell_str = label + '\n' + indent(cell_contents) + '\n</' + label[1:] 183 return cell_str.rstrip() 184 unparser = self._applied_label_str(label) if label not in self.symbol_table else self.symbol_table[label] 185 return unparser(*unparsed_args) 186 187 def _print_kas(self, kas: KAs) -> str: 188 pattern_str = self._print_kinner(kas.pattern) 189 alias_str = self._print_kinner(kas.alias) 190 return pattern_str + ' #as ' + alias_str 191 192 def _print_krewrite(self, krewrite: KRewrite) -> str: 193 lhs_str = self._print_kinner(krewrite.lhs) 194 rhs_str = self._print_kinner(krewrite.rhs) 195 return '( ' + lhs_str + ' => ' + rhs_str + ' )' 196 197 def _print_ksequence(self, ksequence: KSequence) -> str: 198 if ksequence.arity == 0: 199 # TODO: Would be nice to say `return self._print_kinner(EMPTY_K)` 200 return '.K' 201 if ksequence.arity == 1: 202 return self._print_kinner(ksequence.items[0]) + ' ~> .K' 203 unparsed_k_seq = '\n~> '.join([self._print_kinner(item) for item in ksequence.items[0:-1]]) 204 if ksequence.items[-1] == KToken('...', KSort('K')): 205 unparsed_k_seq = unparsed_k_seq + '\n' + self._print_kinner(KToken('...', KSort('K'))) 206 else: 207 unparsed_k_seq = unparsed_k_seq + '\n~> ' + self._print_kinner(ksequence.items[-1]) 208 return unparsed_k_seq 209 210 def _print_kterminal(self, kterminal: KTerminal) -> str: 211 return '"' + kterminal.value + '"' 212 213 def _print_kregexterminal(self, kregexterminal: KRegexTerminal) -> str: 214 return 'r"' + kregexterminal.regex + '"' 215 216 def _print_knonterminal(self, knonterminal: KNonTerminal) -> str: 217 return self.print(knonterminal.sort) 218 
219 def _print_kproduction(self, kproduction: KProduction) -> str: 220 syntax_str = 'syntax ' + self.print(kproduction.sort) 221 if kproduction.items: 222 syntax_str += ' ::= ' + ' '.join([self._print_kouter(pi) for pi in kproduction.items]) 223 att_str = self.print(kproduction.att) 224 if att_str: 225 syntax_str += ' ' + att_str 226 return syntax_str 227 228 def _print_ksyntaxsort(self, ksyntaxsort: KSyntaxSort) -> str: 229 sort_str = self.print(ksyntaxsort.sort) 230 att_str = self.print(ksyntaxsort.att) 231 return 'syntax ' + sort_str + ' ' + att_str 232 233 def _print_ksortsynonym(self, ksortsynonym: KSortSynonym) -> str: 234 new_sort_str = self.print(ksortsynonym.new_sort) 235 old_sort_str = self.print(ksortsynonym.old_sort) 236 att_str = self.print(ksortsynonym.att) 237 return 'syntax ' + new_sort_str + ' = ' + old_sort_str + ' ' + att_str 238 239 def _print_ksyntaxlexical(self, ksyntaxlexical: KSyntaxLexical) -> str: 240 name_str = ksyntaxlexical.name 241 regex_str = ksyntaxlexical.regex 242 att_str = self.print(ksyntaxlexical.att) 243 # todo: proper escaping 244 return 'syntax lexical ' + name_str + ' = r"' + regex_str + '" ' + att_str 245 246 def _print_ksyntaxassociativity(self, ksyntaxassociativity: KSyntaxAssociativity) -> str: 247 assoc_str = ksyntaxassociativity.assoc.value 248 tags_str = ' '.join(ksyntaxassociativity.tags) 249 att_str = self.print(ksyntaxassociativity.att) 250 return 'syntax associativity ' + assoc_str + ' ' + tags_str + ' ' + att_str 251 252 def _print_ksyntaxpriority(self, ksyntaxpriority: KSyntaxPriority) -> str: 253 priorities_str = ' > '.join([' '.join(group) for group in ksyntaxpriority.priorities]) 254 att_str = self.print(ksyntaxpriority.att) 255 return 'syntax priority ' + priorities_str + ' ' + att_str 256 257 def _print_kbubble(self, kbubble: KBubble) -> str: 258 body = '// KBubble(' + kbubble.sentence_type + ', ' + kbubble.contents + ')' 259 att_str = self.print(kbubble.att) 260 return body + ' ' + att_str 261 262 def 
_print_krule(self, kterm: KRule) -> str: 263 body = '\n '.join(self.print(kterm.body).split('\n')) 264 rule_str = 'rule ' 265 if Atts.LABEL in kterm.att: 266 rule_str = rule_str + '[' + kterm.att[Atts.LABEL] + ']:' 267 rule_str = rule_str + ' ' + body 268 atts_str = self.print(kterm.att) 269 if kterm.requires != TRUE: 270 requires_str = 'requires ' + '\n '.join(self._print_kast_bool(kterm.requires).split('\n')) 271 rule_str = rule_str + '\n ' + requires_str 272 if kterm.ensures != TRUE: 273 ensures_str = 'ensures ' + '\n '.join(self._print_kast_bool(kterm.ensures).split('\n')) 274 rule_str = rule_str + '\n ' + ensures_str 275 return rule_str + '\n ' + atts_str 276 277 def _print_kclaim(self, kterm: KClaim) -> str: 278 body = '\n '.join(self.print(kterm.body).split('\n')) 279 rule_str = 'claim ' 280 if Atts.LABEL in kterm.att: 281 rule_str = rule_str + '[' + kterm.att[Atts.LABEL] + ']:' 282 rule_str = rule_str + ' ' + body 283 atts_str = self.print(kterm.att) 284 if kterm.requires != TRUE: 285 requires_str = 'requires ' + '\n '.join(self._print_kast_bool(kterm.requires).split('\n')) 286 rule_str = rule_str + '\n ' + requires_str 287 if kterm.ensures != TRUE: 288 ensures_str = 'ensures ' + '\n '.join(self._print_kast_bool(kterm.ensures).split('\n')) 289 rule_str = rule_str + '\n ' + ensures_str 290 return rule_str + '\n ' + atts_str 291 292 def _print_kcontext(self, kcontext: KContext) -> str: 293 body = indent(self.print(kcontext.body)) 294 context_str = 'context alias ' + body 295 requires_str = '' 296 atts_str = self.print(kcontext.att) 297 if kcontext.requires != TRUE: 298 requires_str = self.print(kcontext.requires) 299 requires_str = 'requires ' + indent(requires_str) 300 return context_str + '\n ' + requires_str + '\n ' + atts_str 301 302 def _print_katt(self, katt: KAtt) -> str: 303 return katt.pretty 304 305 def _print_kimport(self, kimport: KImport) -> str: 306 return ' '.join(['imports', ('public' if kimport.public else 'private'), kimport.name]) 307 308 
def _print_kflatmodule(self, kflatmodule: KFlatModule) -> str: 309 name = kflatmodule.name 310 imports = '\n'.join([self._print_kouter(kimport) for kimport in kflatmodule.imports]) 311 sentences = '\n\n'.join([self._print_kouter(sentence) for sentence in kflatmodule.sentences]) 312 contents = imports + '\n\n' + sentences 313 return 'module ' + name + '\n ' + '\n '.join(contents.split('\n')) + '\n\nendmodule' 314 315 def _print_krequire(self, krequire: KRequire) -> str: 316 return 'requires "' + krequire.require + '"' 317 318 def _print_kdefinition(self, kdefinition: KDefinition) -> str: 319 requires = '\n'.join([self._print_kouter(require) for require in kdefinition.requires]) 320 modules = '\n\n'.join([self._print_kouter(module) for module in kdefinition.all_modules]) 321 return requires + '\n\n' + modules 322 323 def _print_kast_bool(self, kast: KAst) -> str: 324 """Print out KAST requires/ensures clause. 325 326 Args: 327 kast: KAST Bool for requires/ensures clause. 328 329 Returns: 330 Best-effort string representation of KAST term. 331 """ 332 _LOGGER.debug(f'_print_kast_bool: {kast}') 333 if type(kast) is KApply and kast.label.name in ['_andBool_', '_orBool_']: 334 clauses = [self._print_kast_bool(c) for c in flatten_label(kast.label.name, kast)] 335 head = kast.label.name.replace('_', ' ') 336 if head == ' orBool ': 337 head = ' orBool ' 338 separator = ' ' * (len(head) - 7) 339 spacer = ' ' * len(head) 340 341 def join_sep(s: str) -> str: 342 return ('\n' + separator).join(s.split('\n')) 343 344 clauses = ( 345 ['( ' + join_sep(clauses[0])] 346 + [head + '( ' + join_sep(c) for c in clauses[1:]] 347 + [spacer + (')' * len(clauses))] 348 ) 349 return '\n'.join(clauses) 350 else: 351 return self.print(kast) 352 353 def _applied_label_str(self, symbol: str) -> Callable[..., str]: 354 return lambda *args: symbol + ' ( ' + ' , '.join(args) + ' )'
355 356
def build_symbol_table(
    definition: KDefinition,
    extra_modules: Iterable[KFlatModule] | None = (),
    opinionated: bool = False,
) -> SymbolTable:
    """Build the unparsing symbol table for a K definition.

    Every syntax production of every module is registered under its klabel name. If the
    production has a ``symbol`` attribute, it is also registered under that value.

    Args:
        definition: K definition whose modules' syntax productions are registered.
        extra_modules: Additional modules processed after those of ``definition``. Their
            entries overwrite earlier ones with the same key. ``None`` is treated as empty.
        opinionated: If set, also install multi-line unparsers for ``#And`` and ``#Or``.
            These overwrite any production-derived entries for those labels.

    Returns:
        Dictionary mapping label names to unparsers generated by ``unparser_for_production``.
    """
    symbol_table: SymbolTable = {}
    all_modules = [*definition.all_modules, *(() if extra_modules is None else extra_modules)]
    for module in all_modules:
        for prod in module.syntax_productions:
            # Internal invariant: syntax productions reaching this point carry a klabel.
            assert prod.klabel
            label = prod.klabel.name
            unparser = unparser_for_production(prod)

            symbol_table[label] = unparser
            if Atts.SYMBOL in prod.att:
                symbol_table[prod.att[Atts.SYMBOL]] = unparser

    if opinionated:
        symbol_table['#And'] = lambda c1, c2: c1 + '\n#And ' + c2
        symbol_table['#Or'] = lambda c1, c2: c1 + '\n#Or\n' + indent(c2, size=4)

    return symbol_table
387 388
def unparser_for_production(prod: KProduction) -> Callable[..., str]:
    """Generate an unparser for applications of a syntax production.

    The returned function walks the production's items in order. Terminals are emitted
    as their literal value, and each non-terminal consumes the next supplied argument.
    The emitted pieces are joined with single spaces. Non-terminals left over once the
    arguments are exhausted are skipped.

    If every non-terminal of the production is named, each argument is preceded by
    ``name:``, and the first one additionally by ``...`` (presumably K's named-argument
    syntax).

    Args:
        prod: Production to generate an unparser for.

    Returns:
        Function taking the already-printed arguments positionally and returning the
        printed application.
    """
    # Whether to use the named form depends only on the production, not on the call
    # arguments, so decide it once here instead of on every invocation of the unparser.
    nonterminals = [item for item in prod.items if type(item) is KNonTerminal]
    use_names = all(item.name is not None for item in nonterminals)

    def _unparser(*args: Any) -> str:
        index = 0
        result = []
        for item in prod.items:
            if type(item) is KTerminal:
                result.append(item.value)
            elif type(item) is KNonTerminal and index < len(args):
                if use_names:
                    if index == 0:
                        result.append('...')
                    result.append(f'{item.name}:')
                result.append(args[index])
                index += 1
        return ' '.join(result)

    return _unparser
408 409
def indent(text: str, size: int = 2) -> str:
    """Prefix every line of ``text`` (blank ones included) with ``size`` spaces.

    Lines are delimited by ``'\\n'`` only. A trailing newline therefore yields a final
    line that consists solely of the indentation.

    Args:
        text: Text to indent.
        size: Number of spaces to insert at the start of each line.

    Returns:
        The indented text.
    """
    pad = ' ' * size
    # Indent the first line directly, then every line that follows a newline.
    return pad + text.replace('\n', '\n' + pad)
412 413
def paren(printer: Callable[..., str]) -> Callable[..., str]:
    """Wrap an unparser so that its output is enclosed in spaced parentheses.

    Args:
        printer: Unparser whose result should be parenthesized.

    Returns:
        Unparser taking the same positional arguments as ``printer`` and returning
        its result surrounded by ``'( '`` and ``' )'``.
    """

    def _parenthesized(*args: Any) -> str:
        return ''.join(('( ', printer(*args), ' )'))

    return _parenthesized
416 417
def assoc_with_unit(assoc_join: str, unit: str) -> Callable[..., str]:
    """Build an unparser for an associative operator that has a unit element.

    Args:
        assoc_join: Separator placed between consecutive operands.
        unit: Printed form of the unit element. Operands equal to it are dropped.

    Returns:
        Unparser that joins its non-unit arguments with ``assoc_join``.
    """

    def _join_non_units(*args: str) -> str:
        return assoc_join.join(filter(lambda operand: operand != unit, args))

    return _join_non_units