1from __future__ import annotations
2
3import logging
4from collections.abc import Callable
5from functools import cached_property
6from typing import TYPE_CHECKING
7
8from ..prelude.kbool import TRUE
9from .att import Atts, KAtt
10from .inner import KApply, KAs, KInner, KLabel, KRewrite, KSequence, KSort, KToken, KVariable
11from .manip import flatten_label, sort_ac_collections, undo_aliases
12from .outer import (
13 KBubble,
14 KClaim,
15 KContext,
16 KDefinition,
17 KFlatModule,
18 KImport,
19 KNonTerminal,
20 KOuter,
21 KProduction,
22 KRegexTerminal,
23 KRequire,
24 KRule,
25 KRuleLike,
26 KSortSynonym,
27 KSyntaxAssociativity,
28 KSyntaxLexical,
29 KSyntaxPriority,
30 KSyntaxSort,
31 KTerminal,
32)
33
34if TYPE_CHECKING:
35 from collections.abc import Iterable
36 from typing import Any, Final, TypeVar
37
38 from .kast import KAst
39
40 RL = TypeVar('RL', bound='KRuleLike')
41
# Module-level logger shared by the printer's debug tracing.
_LOGGER: Final = logging.getLogger(name=__name__)

# Maps a label name (or a production's ``symbol`` attribute) to an unparser:
# a callable that receives the already-printed argument strings and returns text.
SymbolTable = dict[str, Callable[..., str]]
45
46
[docs]
47class PrettyPrinter:
48 definition: KDefinition
49 _extra_unparsing_modules: Iterable[KFlatModule]
50 _patch_symbol_table: Callable[[SymbolTable], None] | None
51 _unalias: bool
52 _sort_collections: bool
53
54 def __init__(
55 self,
56 definition: KDefinition,
57 extra_unparsing_modules: Iterable[KFlatModule] = (),
58 patch_symbol_table: Callable[[SymbolTable], None] | None = None,
59 unalias: bool = True,
60 sort_collections: bool = False,
61 ):
62 self.definition = definition
63 self._extra_unparsing_modules = extra_unparsing_modules
64 self._patch_symbol_table = patch_symbol_table
65 self._unalias = unalias
66 self._sort_collections = sort_collections
67
68 @cached_property
69 def symbol_table(self) -> SymbolTable:
70 symb_table = build_symbol_table(
71 self.definition,
72 extra_modules=self._extra_unparsing_modules,
73 opinionated=True,
74 )
75 if self._patch_symbol_table is not None:
76 self._patch_symbol_table(symb_table)
77 return symb_table
78
[docs]
79 def print(self, kast: KAst) -> str:
80 """Print out KAST terms/outer syntax.
81
82 Args:
83 kast: KAST term to print.
84
85 Returns:
86 Best-effort string representation of KAST term.
87 """
88 _LOGGER.debug(f'Unparsing: {kast}')
89 if type(kast) is KAtt:
90 return self._print_katt(kast)
91 if type(kast) is KSort:
92 return self._print_ksort(kast)
93 if type(kast) is KLabel:
94 return self._print_klabel(kast)
95 elif isinstance(kast, KOuter):
96 return self._print_kouter(kast)
97 elif isinstance(kast, KInner):
98 if self._unalias:
99 kast = undo_aliases(self.definition, kast)
100 if self._sort_collections:
101 kast = sort_ac_collections(kast)
102 return self._print_kinner(kast)
103 raise AssertionError(f'Error unparsing: {kast}')
104
105 def _print_kouter(self, kast: KOuter) -> str:
106 match kast:
107 case KTerminal():
108 return self._print_kterminal(kast)
109 case KRegexTerminal():
110 return self._print_kregexterminal(kast)
111 case KNonTerminal():
112 return self._print_knonterminal(kast)
113 case KProduction():
114 return self._print_kproduction(kast)
115 case KSyntaxSort():
116 return self._print_ksyntaxsort(kast)
117 case KSortSynonym():
118 return self._print_ksortsynonym(kast)
119 case KSyntaxLexical():
120 return self._print_ksyntaxlexical(kast)
121 case KSyntaxAssociativity():
122 return self._print_ksyntaxassociativity(kast)
123 case KSyntaxPriority():
124 return self._print_ksyntaxpriority(kast)
125 case KBubble():
126 return self._print_kbubble(kast)
127 case KRule():
128 return self._print_krule(kast)
129 case KClaim():
130 return self._print_kclaim(kast)
131 case KContext():
132 return self._print_kcontext(kast)
133 case KImport():
134 return self._print_kimport(kast)
135 case KFlatModule():
136 return self._print_kflatmodule(kast)
137 case KRequire():
138 return self._print_krequire(kast)
139 case KDefinition():
140 return self._print_kdefinition(kast)
141 case _:
142 raise AssertionError(f'Error unparsing: {kast}')
143
144 def _print_kinner(self, kast: KInner) -> str:
145 match kast:
146 case KVariable():
147 return self._print_kvariable(kast)
148 case KToken():
149 return self._print_ktoken(kast)
150 case KApply():
151 return self._print_kapply(kast)
152 case KAs():
153 return self._print_kas(kast)
154 case KRewrite():
155 return self._print_krewrite(kast)
156 case KSequence():
157 return self._print_ksequence(kast)
158 case _:
159 raise AssertionError(f'Error unparsing: {kast}')
160
161 def _print_ksort(self, ksort: KSort) -> str:
162 return ksort.name
163
164 def _print_klabel(self, klabel: KLabel) -> str:
165 return klabel.name
166
167 def _print_kvariable(self, kvariable: KVariable) -> str:
168 sort = kvariable.sort
169 if not sort:
170 return kvariable.name
171 return kvariable.name + ':' + sort.name
172
173 def _print_ktoken(self, ktoken: KToken) -> str:
174 return ktoken.token
175
176 def _print_kapply(self, kapply: KApply) -> str:
177 label = kapply.label.name
178 args = kapply.args
179 unparsed_args = [self._print_kinner(arg) for arg in args]
180 if kapply.is_cell:
181 cell_contents = '\n'.join(unparsed_args).rstrip()
182 cell_str = label + '\n' + indent(cell_contents) + '\n</' + label[1:]
183 return cell_str.rstrip()
184 unparser = self._applied_label_str(label) if label not in self.symbol_table else self.symbol_table[label]
185 return unparser(*unparsed_args)
186
187 def _print_kas(self, kas: KAs) -> str:
188 pattern_str = self._print_kinner(kas.pattern)
189 alias_str = self._print_kinner(kas.alias)
190 return pattern_str + ' #as ' + alias_str
191
192 def _print_krewrite(self, krewrite: KRewrite) -> str:
193 lhs_str = self._print_kinner(krewrite.lhs)
194 rhs_str = self._print_kinner(krewrite.rhs)
195 return '( ' + lhs_str + ' => ' + rhs_str + ' )'
196
197 def _print_ksequence(self, ksequence: KSequence) -> str:
198 if ksequence.arity == 0:
199 # TODO: Would be nice to say `return self._print_kinner(EMPTY_K)`
200 return '.K'
201 if ksequence.arity == 1:
202 return self._print_kinner(ksequence.items[0]) + ' ~> .K'
203 unparsed_k_seq = '\n~> '.join([self._print_kinner(item) for item in ksequence.items[0:-1]])
204 if ksequence.items[-1] == KToken('...', KSort('K')):
205 unparsed_k_seq = unparsed_k_seq + '\n' + self._print_kinner(KToken('...', KSort('K')))
206 else:
207 unparsed_k_seq = unparsed_k_seq + '\n~> ' + self._print_kinner(ksequence.items[-1])
208 return unparsed_k_seq
209
210 def _print_kterminal(self, kterminal: KTerminal) -> str:
211 return '"' + kterminal.value + '"'
212
213 def _print_kregexterminal(self, kregexterminal: KRegexTerminal) -> str:
214 return 'r"' + kregexterminal.regex + '"'
215
216 def _print_knonterminal(self, knonterminal: KNonTerminal) -> str:
217 return self.print(knonterminal.sort)
218
219 def _print_kproduction(self, kproduction: KProduction) -> str:
220 syntax_str = 'syntax ' + self.print(kproduction.sort)
221 if kproduction.items:
222 syntax_str += ' ::= ' + ' '.join([self._print_kouter(pi) for pi in kproduction.items])
223 att_str = self.print(kproduction.att)
224 if att_str:
225 syntax_str += ' ' + att_str
226 return syntax_str
227
228 def _print_ksyntaxsort(self, ksyntaxsort: KSyntaxSort) -> str:
229 sort_str = self.print(ksyntaxsort.sort)
230 att_str = self.print(ksyntaxsort.att)
231 return 'syntax ' + sort_str + ' ' + att_str
232
233 def _print_ksortsynonym(self, ksortsynonym: KSortSynonym) -> str:
234 new_sort_str = self.print(ksortsynonym.new_sort)
235 old_sort_str = self.print(ksortsynonym.old_sort)
236 att_str = self.print(ksortsynonym.att)
237 return 'syntax ' + new_sort_str + ' = ' + old_sort_str + ' ' + att_str
238
239 def _print_ksyntaxlexical(self, ksyntaxlexical: KSyntaxLexical) -> str:
240 name_str = ksyntaxlexical.name
241 regex_str = ksyntaxlexical.regex
242 att_str = self.print(ksyntaxlexical.att)
243 # todo: proper escaping
244 return 'syntax lexical ' + name_str + ' = r"' + regex_str + '" ' + att_str
245
246 def _print_ksyntaxassociativity(self, ksyntaxassociativity: KSyntaxAssociativity) -> str:
247 assoc_str = ksyntaxassociativity.assoc.value
248 tags_str = ' '.join(ksyntaxassociativity.tags)
249 att_str = self.print(ksyntaxassociativity.att)
250 return 'syntax associativity ' + assoc_str + ' ' + tags_str + ' ' + att_str
251
252 def _print_ksyntaxpriority(self, ksyntaxpriority: KSyntaxPriority) -> str:
253 priorities_str = ' > '.join([' '.join(group) for group in ksyntaxpriority.priorities])
254 att_str = self.print(ksyntaxpriority.att)
255 return 'syntax priority ' + priorities_str + ' ' + att_str
256
257 def _print_kbubble(self, kbubble: KBubble) -> str:
258 body = '// KBubble(' + kbubble.sentence_type + ', ' + kbubble.contents + ')'
259 att_str = self.print(kbubble.att)
260 return body + ' ' + att_str
261
262 def _print_krule(self, kterm: KRule) -> str:
263 body = '\n '.join(self.print(kterm.body).split('\n'))
264 rule_str = 'rule '
265 if Atts.LABEL in kterm.att:
266 rule_str = rule_str + '[' + kterm.att[Atts.LABEL] + ']:'
267 rule_str = rule_str + ' ' + body
268 atts_str = self.print(kterm.att)
269 if kterm.requires != TRUE:
270 requires_str = 'requires ' + '\n '.join(self._print_kast_bool(kterm.requires).split('\n'))
271 rule_str = rule_str + '\n ' + requires_str
272 if kterm.ensures != TRUE:
273 ensures_str = 'ensures ' + '\n '.join(self._print_kast_bool(kterm.ensures).split('\n'))
274 rule_str = rule_str + '\n ' + ensures_str
275 return rule_str + '\n ' + atts_str
276
277 def _print_kclaim(self, kterm: KClaim) -> str:
278 body = '\n '.join(self.print(kterm.body).split('\n'))
279 rule_str = 'claim '
280 if Atts.LABEL in kterm.att:
281 rule_str = rule_str + '[' + kterm.att[Atts.LABEL] + ']:'
282 rule_str = rule_str + ' ' + body
283 atts_str = self.print(kterm.att)
284 if kterm.requires != TRUE:
285 requires_str = 'requires ' + '\n '.join(self._print_kast_bool(kterm.requires).split('\n'))
286 rule_str = rule_str + '\n ' + requires_str
287 if kterm.ensures != TRUE:
288 ensures_str = 'ensures ' + '\n '.join(self._print_kast_bool(kterm.ensures).split('\n'))
289 rule_str = rule_str + '\n ' + ensures_str
290 return rule_str + '\n ' + atts_str
291
292 def _print_kcontext(self, kcontext: KContext) -> str:
293 body = indent(self.print(kcontext.body))
294 context_str = 'context alias ' + body
295 requires_str = ''
296 atts_str = self.print(kcontext.att)
297 if kcontext.requires != TRUE:
298 requires_str = self.print(kcontext.requires)
299 requires_str = 'requires ' + indent(requires_str)
300 return context_str + '\n ' + requires_str + '\n ' + atts_str
301
302 def _print_katt(self, katt: KAtt) -> str:
303 return katt.pretty
304
305 def _print_kimport(self, kimport: KImport) -> str:
306 return ' '.join(['imports', ('public' if kimport.public else 'private'), kimport.name])
307
308 def _print_kflatmodule(self, kflatmodule: KFlatModule) -> str:
309 name = kflatmodule.name
310 imports = '\n'.join([self._print_kouter(kimport) for kimport in kflatmodule.imports])
311 sentences = '\n\n'.join([self._print_kouter(sentence) for sentence in kflatmodule.sentences])
312 contents = imports + '\n\n' + sentences
313 return 'module ' + name + '\n ' + '\n '.join(contents.split('\n')) + '\n\nendmodule'
314
315 def _print_krequire(self, krequire: KRequire) -> str:
316 return 'requires "' + krequire.require + '"'
317
318 def _print_kdefinition(self, kdefinition: KDefinition) -> str:
319 requires = '\n'.join([self._print_kouter(require) for require in kdefinition.requires])
320 modules = '\n\n'.join([self._print_kouter(module) for module in kdefinition.all_modules])
321 return requires + '\n\n' + modules
322
323 def _print_kast_bool(self, kast: KAst) -> str:
324 """Print out KAST requires/ensures clause.
325
326 Args:
327 kast: KAST Bool for requires/ensures clause.
328
329 Returns:
330 Best-effort string representation of KAST term.
331 """
332 _LOGGER.debug(f'_print_kast_bool: {kast}')
333 if type(kast) is KApply and kast.label.name in ['_andBool_', '_orBool_']:
334 clauses = [self._print_kast_bool(c) for c in flatten_label(kast.label.name, kast)]
335 head = kast.label.name.replace('_', ' ')
336 if head == ' orBool ':
337 head = ' orBool '
338 separator = ' ' * (len(head) - 7)
339 spacer = ' ' * len(head)
340
341 def join_sep(s: str) -> str:
342 return ('\n' + separator).join(s.split('\n'))
343
344 clauses = (
345 ['( ' + join_sep(clauses[0])]
346 + [head + '( ' + join_sep(c) for c in clauses[1:]]
347 + [spacer + (')' * len(clauses))]
348 )
349 return '\n'.join(clauses)
350 else:
351 return self.print(kast)
352
353 def _applied_label_str(self, symbol: str) -> Callable[..., str]:
354 return lambda *args: symbol + ' ( ' + ' , '.join(args) + ' )'
355
356
[docs]
def build_symbol_table(
    definition: KDefinition,
    extra_modules: Iterable[KFlatModule] = (),
    opinionated: bool = False,
) -> SymbolTable:
    """Build the unparsing symbol table for a K definition.

    Each syntax production found in ``definition.all_modules``, followed by those in
    ``extra_modules``, is registered under its KLabel name and, if it carries a
    ``symbol`` attribute, under that symbol as well. When the same key is registered
    more than once, the last registration wins.

    Args:
        definition: K definition whose modules provide the syntax productions.
        extra_modules: Additional modules scanned after the definition's own modules.
            ``None`` is accepted and treated as no extra modules.
        opinionated: If ``True``, install multi-line layouts for ``#And`` and ``#Or``,
            overriding any entries produced from the productions.

    Returns:
        Dictionary mapping label names (and symbols) to unparser callables that take
        the already-printed argument strings.
    """
    symbol_table: SymbolTable = {}
    all_modules = list(definition.all_modules) + ([] if extra_modules is None else list(extra_modules))
    for module in all_modules:
        for prod in module.syntax_productions:
            # Productions registered here are expected to carry a klabel.
            assert prod.klabel
            unparser = unparser_for_production(prod)
            symbol_table[prod.klabel.name] = unparser
            if Atts.SYMBOL in prod.att:
                symbol_table[prod.att[Atts.SYMBOL]] = unparser

    if opinionated:
        symbol_table['#And'] = lambda c1, c2: c1 + '\n#And ' + c2
        symbol_table['#Or'] = lambda c1, c2: c1 + '\n#Or\n' + indent(c2, size=4)

    return symbol_table
387
388
[docs]
def unparser_for_production(prod: KProduction) -> Callable[..., str]:
    """Build an unparser that renders arguments with ``prod``'s concrete syntax.

    The returned function walks the production items, emitting each terminal's
    string and, for each nonterminal, the next supplied (already-printed)
    argument, all joined by single spaces. Items that are neither terminals nor
    nonterminals (e.g. regex terminals) are skipped. If every nonterminal is
    named, arguments are rendered in named form: ``...`` before the first
    argument and ``name:`` before each one. Surplus arguments are ignored, and
    nonterminals without a matching argument are omitted.

    Args:
        prod: Production whose items define the layout.

    Returns:
        Callable taking the printed argument strings and returning the rendered text.
    """
    # Computed once per production rather than on every unparser call.
    nonterminals = [item for item in prod.items if type(item) is KNonTerminal]
    named_args = all(item.name is not None for item in nonterminals)

    def _unparser(*args: Any) -> str:
        index = 0
        result: list[str] = []
        for item in prod.items:
            if type(item) is KTerminal:
                result.append(item.value)
            elif type(item) is KNonTerminal and index < len(args):
                if named_args:
                    if index == 0:
                        result.append('...')
                    result.append(f'{item.name}:')
                result.append(args[index])
                index += 1
        return ' '.join(result)

    return _unparser
408
409
[docs]
def indent(text: str, size: int = 2) -> str:
    """Prefix every line of ``text`` (empty lines included) with ``size`` spaces."""
    prefix = ' ' * size
    return '\n'.join(prefix + line for line in text.split('\n'))
412
413
[docs]
def paren(printer: Callable[..., str]) -> Callable[..., str]:
    """Wrap ``printer`` so that its output is enclosed in ``( `` ... `` )``."""

    def _parenthesized(*args: Any) -> str:
        inner = printer(*args)
        return '( ' + inner + ' )'

    return _parenthesized
416
417
[docs]
def assoc_with_unit(assoc_join: str, unit: str) -> Callable[..., str]:
    """Build a printer for an associative operator with a unit element.

    The returned callable drops every argument equal to ``unit`` and joins the
    remaining ones with ``assoc_join``.
    """

    def _join_non_units(*args: str) -> str:
        kept = filter(lambda arg: arg != unit, args)
        return assoc_join.join(kept)

    return _join_non_units