Skip to content
This repository was archived by the owner on Apr 25, 2024. It is now read-only.

Commit 04c961d

Browse files
committed
ktool/kprove: refactor how claim labels are processed
1 parent 00cf6b7 commit 04c961d

File tree

1 file changed

+38
-44
lines changed

1 file changed

+38
-44
lines changed

src/pyk/ktool/kprove.py

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..kast import kast_term
1717
from ..kast.inner import KInner
1818
from ..kast.manip import extract_subst, flatten_label, free_vars
19-
from ..kast.outer import KClaim, KDefinition, KFlatModule, KFlatModuleList, KImport, KRequire
19+
from ..kast.outer import KDefinition, KFlatModule, KFlatModuleList, KImport, KRequire
2020
from ..prelude.ml import is_top, mlAnd
2121
from ..utils import gen_file_timestamp, run_process, unique
2222
from . import TypeInferenceMode
@@ -27,7 +27,7 @@
2727
from subprocess import CompletedProcess
2828
from typing import Final
2929

30-
from ..kast.outer import KRule, KRuleLike, KSentence
30+
from ..kast.outer import KClaim, KRule, KRuleLike
3131
from ..kast.pretty import SymbolTable
3232
from ..utils import BugReport
3333

@@ -337,27 +337,7 @@ def get_claim_modules(
337337
dry_run=True,
338338
args=['--emit-json-spec', ntf.name],
339339
)
340-
flat_module_list = KFlatModuleList.from_dict(kast_term(json.loads(Path(ntf.name).read_text())))
341-
342-
def _qualify_name(_module_name: str, _claim_label: str) -> str:
343-
return _claim_label if _claim_label.startswith(_module_name) else f'{_module_name}.{_claim_label}'
344-
345-
modules = []
346-
for module in flat_module_list.modules:
347-
sentences: list[KSentence] = []
348-
for sentence in module.sentences:
349-
if type(sentence) is KClaim:
350-
_label = _qualify_name(module.name, sentence.label)
351-
_att = sentence.att.update({'label': _label})
352-
if len(sentence.dependencies) > 0:
353-
_dependencies = [_qualify_name(module.name, dep) for dep in sentence.dependencies]
354-
_att = _att.update({'depends': ','.join(_dependencies)})
355-
sentences.append(sentence.let(att=_att))
356-
else:
357-
sentences.append(sentence)
358-
modules.append(module.let(sentences=sentences))
359-
360-
return flat_module_list.let(modules=modules)
340+
return KFlatModuleList.from_dict(kast_term(json.loads(Path(ntf.name).read_text())))
361341

362342
def get_claims(
363343
self,
@@ -375,35 +355,49 @@ def get_claims(
375355
include_dirs=include_dirs,
376356
md_selector=md_selector,
377357
)
378-
module_names = [module.name for module in flat_module_list.modules]
379358

380-
def _qualify_label(_label: str) -> str:
381-
if any(_label.startswith(mname) for mname in module_names):
382-
return _label
383-
return f'{flat_module_list.main_module}.{_label}'
359+
def _qualify_label(_module_name: str, _label: str) -> str:
360+
return _label if _label.startswith(_module_name) else f'{_module_name}.{_label}'
384361

385-
all_claims = {c.label: c for m in flat_module_list.modules for c in m.claims}
386-
exclude_claim_labels = (
387-
[] if exclude_claim_labels is None else [_qualify_label(cl) for cl in exclude_claim_labels]
388-
)
389-
claim_labels = list(all_claims.keys()) if claim_labels is None else [_qualify_label(cl) for cl in claim_labels]
390-
unfound_labels: list[str] = [cl for cl in claim_labels + exclude_claim_labels if cl not in all_claims]
362+
all_claims = {
363+
f'{module.name}.{claim.label}': (claim, module.name)
364+
for module in flat_module_list.modules
365+
for claim in module.claims
366+
}
367+
368+
claim_labels = list(all_claims.keys()) if claim_labels is None else list(claim_labels)
369+
exclude_claim_labels = [] if exclude_claim_labels is None else list(exclude_claim_labels)
370+
371+
unfound_labels: list[str] = [
372+
cl
373+
for cl in list(claim_labels) + list(exclude_claim_labels)
374+
if cl not in all_claims and f'{flat_module_list.main_module}.{cl}' not in all_claims
375+
]
391376
if len(unfound_labels) > 0:
392377
raise ValueError(f'Claim labels not found: {unfound_labels}')
393378

394-
final_claim_labels: list[str] = []
379+
final_claims: dict[str, KClaim] = {}
380+
unfound_dependencies: list[str] = []
395381
while len(claim_labels) > 0:
396382
claim_label = claim_labels.pop(0)
397-
if claim_label in final_claim_labels:
383+
if claim_label in final_claims:
398384
continue
399-
elif claim_label in exclude_claim_labels:
400-
_LOGGER.warning(f'Including claim that is also excluded: {claim_label}')
401-
else:
402-
final_claim_labels.append(claim_label)
403-
if include_dependencies:
404-
claim_labels.extend(all_claims[claim_label].dependencies)
405-
406-
return [all_claims[cl] for cl in final_claim_labels]
385+
if claim_label not in all_claims:
386+
claim_label = f'{flat_module_list.main_module}.{claim_label}'
387+
_claim, _module_name = all_claims[claim_label]
388+
final_claims[claim_label] = _claim
389+
for _dependency_label in _claim.dependencies:
390+
if _dependency_label not in all_claims:
391+
if f'{_module_name}.{_dependency_label}' not in all_claims:
392+
unfound_dependencies.append(_dependency_label)
393+
continue
394+
_dependency_label = f'{_module_name}.{_dependency_label}'
395+
claim_labels.append(_dependency_label)
396+
397+
if len(unfound_dependencies) > 0:
398+
raise ValueError(f'Dependency claim labels not found: {unfound_dependencies}')
399+
400+
return list(final_claims.values())
407401

408402
@contextmanager
409403
def _tmp_claim_definition(

0 commit comments

Comments
 (0)