Skip to content

Commit 9b9ef01

Browse files
arthaudmeta-codesync[bot]
authored andcommitted
Introduce ModuleAnswersContext
Summary: # Context This stack will re-design how `pyrefly --report-pysa` is implemented. Right now, it is a post processing step that requires pyrefly to preserve internal states for all modules, which can lead to high memory usage. To avoid this, we will incorporate the pysa reporting steps within the type checking. See https://docs.google.com/document/d/1Bk8izFv4nQxbQmaog9OqcpcSxyzAW3Z-bn0JUV4wkg4/ for context # This diff In the next diffs, we will start building pysa information while type checking is running. To do this properly, we need to be integrated within the different resolution steps (ast, bindings, answers, solutions, etc.) of pyrefly. The obvious approach so to build pysa information once solutions for a module are built, since that's when we have all the information about a module. To do this properly, we should ensure that our pysa information building step avoids querying for cross-module information, since that creates cycles during the resolution. One solution is to avoid accessing the `transaction` when building this information. To do this, we introduce a `ModuleAnswersContext` struct from `ModuleContext`, containing the per-module data that doesn't require cross-module access (handle, module_id, module_info, stdlib, ast, bindings, answers). `ModuleContext` now wraps ModuleAnswersContext alongside `module_ids` and `transaction` which allow cross-module information access. This diff also inlines `get_class_field_from_current_class_only` to avoid `ad_hoc_solve` for current-module class field lookups, narrowing it to take `ModuleAnswersContext` instead of `ModuleContext`. Reviewed By: tianhan0 Differential Revision: D97328329 fbshipit-source-id: 02e70f6967f21a6d4cc92a671adf434759e63ced
1 parent 1d2c01d commit 9b9ef01

21 files changed

+907
-467
lines changed

pyrefly/lib/report/pysa/ast_visitor.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -307,20 +307,24 @@ fn visit_statement<V: AstScopedVisitor>(
307307
Stmt::FunctionDef(function_def) => {
308308
let key = KeyDecoratedFunction(ShortIdentifier::new(&function_def.name));
309309
let function_scope = if let Some(idx) = module_context
310+
.answers_context
310311
.bindings
311312
.key_to_idx_hashed_opt(Hashed::new(&key))
312313
{
313314
let decorated_function = DecoratedFunction::from_bindings_answers(
314315
idx,
315-
&module_context.bindings,
316-
&module_context.answers,
316+
&module_context.answers_context.bindings,
317+
&module_context.answers_context.answers,
317318
);
318-
if should_export_decorated_function(&decorated_function, module_context) {
319+
if should_export_decorated_function(
320+
&decorated_function,
321+
&module_context.answers_context,
322+
) {
319323
Scope::ExportedFunction {
320324
function_id: FunctionId::Function {
321325
location: PysaLocation::from_text_range(
322326
function_def.identifier().range(),
323-
&module_context.module_info,
327+
&module_context.answers_context.module_info,
324328
),
325329
},
326330
location: function_def.identifier().range(),
@@ -424,10 +428,12 @@ fn visit_statement<V: AstScopedVisitor>(
424428
Stmt::ClassDef(class_def) => {
425429
let key = KeyClass(ShortIdentifier::new(&class_def.name));
426430
let class_scope = if let Some(idx) = module_context
431+
.answers_context
427432
.bindings
428433
.key_to_idx_hashed_opt(Hashed::new(&key))
429434
{
430435
let class = module_context
436+
.answers_context
431437
.answers
432438
.get_idx(idx)
433439
.unwrap()
@@ -532,11 +538,11 @@ pub fn visit_module_ast<V: AstScopedVisitor>(
532538
let mut scopes = Scopes {
533539
stack: vec![Scope::TopLevel],
534540
};
535-
visitor.enter_toplevel_scope(&module_context.ast, &scopes);
541+
visitor.enter_toplevel_scope(&module_context.answers_context.ast, &scopes);
536542
visitor.on_scope_update(&scopes);
537-
for stmt in &module_context.ast.body {
543+
for stmt in &module_context.answers_context.ast.body {
538544
visit_statement(stmt, visitor, &mut scopes, module_context);
539545
}
540-
visitor.exit_toplevel_scope(&module_context.ast, &scopes);
546+
visitor.exit_toplevel_scope(&module_context.answers_context.ast, &scopes);
541547
scopes
542548
}

pyrefly/lib/report/pysa/call_graph.rs

Lines changed: 144 additions & 75 deletions
Large diffs are not rendered by default.

pyrefly/lib/report/pysa/captured_variable.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,12 @@ impl<'a> DefinitionToFunctionMapVisitor<'a> {
121121
fn bind_name(&mut self, key: Key, scopes: &Scopes) {
122122
if let Some(idx) = self
123123
.module_context
124+
.answers_context
124125
.bindings
125126
.key_to_idx_hashed_opt(Hashed::new(&key))
126127
&& let Some(current_function) = scopes.current_exported_function(
127-
self.module_context.module_id,
128-
self.module_context.module_info.name(),
128+
self.module_context.answers_context.module_id,
129+
self.module_context.answers_context.module_info.name(),
129130
&SCOPE_EXPORTED_FUNCTION_FLAGS,
130131
)
131132
{
@@ -220,9 +221,10 @@ impl<'a> CapturedVariableVisitor<'a> {
220221
fn get_definition_from_usage(&self, key: Key) -> Option<FunctionRef> {
221222
let idx = self
222223
.module_context
224+
.answers_context
223225
.bindings
224226
.key_to_idx_hashed_opt(Hashed::new(&key))?;
225-
let binding = self.module_context.bindings.get(idx);
227+
let binding = self.module_context.answers_context.bindings.get(idx);
226228
match binding {
227229
Binding::Forward(definition_idx) | Binding::ForwardToFirstUse(definition_idx) => {
228230
self.get_definition_from_idx(
@@ -257,7 +259,7 @@ impl<'a> CapturedVariableVisitor<'a> {
257259
}
258260
depth += 1;
259261

260-
let binding = self.module_context.bindings.get(idx);
262+
let binding = self.module_context.answers_context.bindings.get(idx);
261263
match binding {
262264
Binding::Forward(idx)
263265
| Binding::ForwardToFirstUse(idx)
@@ -280,8 +282,8 @@ impl<'a> CapturedVariableVisitor<'a> {
280282
impl<'a> AstScopedVisitor for CapturedVariableVisitor<'a> {
281283
fn on_scope_update(&mut self, scopes: &Scopes) {
282284
self.current_exported_function = scopes.current_exported_function(
283-
self.module_context.module_id,
284-
self.module_context.module_info.name(),
285+
self.module_context.answers_context.module_id,
286+
self.module_context.answers_context.module_info.name(),
285287
&SCOPE_EXPORTED_FUNCTION_FLAGS,
286288
);
287289
}
@@ -387,7 +389,7 @@ pub fn export_captured_variables_for_module(
387389
context: &ModuleContext,
388390
) -> HashMap<FunctionRef, Vec<CapturedVariableRef<FunctionRef>>> {
389391
captured_variables
390-
.get_for_module(context.module_id)
392+
.get_for_module(context.answers_context.module_id)
391393
.unwrap()
392394
.clone()
393395
.0

pyrefly/lib/report/pysa/class.rs

Lines changed: 76 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ use crate::report::pysa::ModuleContext;
4444
use crate::report::pysa::call_graph::Target;
4545
use crate::report::pysa::call_graph::resolve_decorator_callees;
4646
use crate::report::pysa::collect::CollectNoDuplicateKeys;
47+
use crate::report::pysa::context::ModuleAnswersContext;
4748
use crate::report::pysa::function::FunctionBaseDefinition;
4849
use crate::report::pysa::function::FunctionRef;
4950
use crate::report::pysa::function::WholeProgramFunctionDefinitions;
@@ -232,7 +233,7 @@ impl ClassDefinition {
232233
}
233234
}
234235

235-
pub fn get_all_classes(context: &ModuleContext) -> impl Iterator<Item = Class> {
236+
pub fn get_all_classes(context: &ModuleAnswersContext) -> impl Iterator<Item = Class> {
236237
context
237238
.bindings
238239
.keys::<KeyClass>()
@@ -242,14 +243,29 @@ pub fn get_all_classes(context: &ModuleContext) -> impl Iterator<Item = Class> {
242243
pub fn get_class_field_from_current_class_only(
243244
class: &Class,
244245
field_name: &Name,
245-
context: &ModuleContext,
246+
context: &ModuleAnswersContext,
246247
) -> Option<Arc<ClassField>> {
247-
context
248-
.transaction
249-
.ad_hoc_solve(&context.handle, "pysa_class_field", |solver| {
250-
solver.get_field_from_current_class_only(class, field_name)
251-
})
252-
.unwrap()
248+
// This inlines the logic from `AnswersSolver::get_field_from_current_class_only`,
249+
// `get_non_synthesized_field_from_current_class_only`, and
250+
// `get_synthesized_field_from_current_class_only`.
251+
assert!(class.module() == &context.module_info);
252+
253+
// Non-synthesized field: check class fields list, then look up the answer.
254+
let class_fields = &context.bindings.metadata().get_class(class.index()).fields;
255+
if class_fields.contains(field_name) {
256+
let key = KeyClassField(class.index(), field_name.clone());
257+
if let Some(idx) = context.bindings.key_to_idx_hashed_opt(Hashed::new(&key))
258+
&& let Some(field) = context.answers.get_idx(idx)
259+
{
260+
return Some(field);
261+
}
262+
}
263+
264+
// Synthesized field (e.g., dataclass fields).
265+
let key = KeyClassSynthesizedFields(class.index());
266+
let idx = context.bindings.key_to_idx_hashed_opt(Hashed::new(&key))?;
267+
let synthesized_fields = context.answers.get_idx(idx)?;
268+
Some(synthesized_fields.get(field_name)?.inner.dupe())
253269
}
254270

255271
pub fn get_super_class_member(
@@ -260,16 +276,18 @@ pub fn get_super_class_member(
260276
) -> Option<WithDefiningClass<Arc<ClassField>>> {
261277
context
262278
.transaction
263-
.ad_hoc_solve(&context.handle, "pysa_super_class_member", |solver| {
264-
solver.get_super_class_member(class, start_lookup_cls, field_name)
265-
})
279+
.ad_hoc_solve(
280+
&context.answers_context.handle,
281+
"pysa_super_class_member",
282+
|solver| solver.get_super_class_member(class, start_lookup_cls, field_name),
283+
)
266284
.flatten()
267285
}
268286

269287
pub fn get_class_field_declaration<'a>(
270288
class: &Class,
271289
field_name: &Name,
272-
context: &'a ModuleContext,
290+
context: &'a ModuleAnswersContext,
273291
) -> Option<&'a BindingClassField> {
274292
assert_eq!(class.module(), &context.module_info);
275293
let key_class_field = KeyClassField(class.index(), field_name.clone());
@@ -280,7 +298,7 @@ pub fn get_class_field_declaration<'a>(
280298
.map(|idx| context.bindings.get(idx))
281299
}
282300

283-
pub fn get_class_mro(class: &Class, context: &ModuleContext) -> Arc<ClassMro> {
301+
pub fn get_class_mro(class: &Class, context: &ModuleAnswersContext) -> Arc<ClassMro> {
284302
assert_eq!(class.module(), &context.module_info);
285303
context
286304
.answers
@@ -290,7 +308,7 @@ pub fn get_class_mro(class: &Class, context: &ModuleContext) -> Arc<ClassMro> {
290308

291309
pub fn get_class_fields<'a>(
292310
class: &'a Class,
293-
context: &'a ModuleContext<'a>,
311+
context: &'a ModuleAnswersContext,
294312
) -> impl Iterator<Item = (Cow<'a, Name>, Arc<ClassField>)> {
295313
let class_fields = context
296314
.bindings
@@ -368,11 +386,11 @@ fn export_class_fields(
368386
context: &ModuleContext,
369387
ann_assign_map: &AnnAssignMap,
370388
) -> HashMap<Name, PysaClassField> {
371-
assert_eq!(class.module(), &context.module_info);
372-
get_class_fields(class, context)
389+
assert_eq!(class.module(), &context.answers_context.module_info);
390+
get_class_fields(class, &context.answers_context)
373391
.filter(|(_, field)| !is_callable_like(&field.ty()))
374392
.filter_map(|(name, field)| {
375-
let field_binding = get_class_field_declaration(class, &name, context);
393+
let field_binding = get_class_field_declaration(class, &name, &context.answers_context);
376394

377395
let explicit_annotation = match field_binding {
378396
Some(BindingClassField {
@@ -389,7 +407,7 @@ fn export_class_fields(
389407
}) => *annotation,
390408
_ => None,
391409
}
392-
.map(|idx| context.bindings.idx_to_key(idx))
410+
.map(|idx| context.answers_context.bindings.idx_to_key(idx))
393411
.and_then(|key_annotation| match key_annotation {
394412
// We want to export the annotation as it is in the source code.
395413
// We cannot use the answer for `key_annotation` (which wraps a `Type`),
@@ -398,11 +416,19 @@ fn export_class_fields(
398416
KeyAnnotation::Annotation(identifier) => ann_assign_map
399417
.get(identifier.range().start())
400418
.map(|annotation_range| {
401-
context.module_info.code_at(*annotation_range).to_owned()
419+
context
420+
.answers_context
421+
.module_info
422+
.code_at(*annotation_range)
423+
.to_owned()
402424
}),
403-
KeyAnnotation::AttrAnnotation(range) => {
404-
Some(context.module_info.code_at(*range).to_owned())
405-
}
425+
KeyAnnotation::AttrAnnotation(range) => Some(
426+
context
427+
.answers_context
428+
.module_info
429+
.code_at(*range)
430+
.to_owned(),
431+
),
406432
_ => None,
407433
});
408434

@@ -421,7 +447,10 @@ fn export_class_fields(
421447
PysaClassField {
422448
type_: PysaType::from_type(&field.ty(), context),
423449
explicit_annotation,
424-
location: Some(PysaLocation::from_text_range(*range, &context.module_info)),
450+
location: Some(PysaLocation::from_text_range(
451+
*range,
452+
&context.answers_context.module_info,
453+
)),
425454
declaration_kind: PysaClassFieldDeclaration::from(definition),
426455
},
427456
)),
@@ -444,8 +473,8 @@ fn find_definition_ast<'a>(
444473
class: &Class,
445474
context: &'a ModuleContext<'a>,
446475
) -> Option<&'a StmtClassDef> {
447-
assert_eq!(class.module(), &context.module_info);
448-
Ast::locate_node(&context.ast, class.qname().range().start())
476+
assert_eq!(class.module(), &context.answers_context.module_info);
477+
Ast::locate_node(&context.answers_context.ast, class.qname().range().start())
449478
.iter()
450479
.find_map(|node| match node {
451480
AnyNodeRef::StmtClassDef(stmt) if stmt.name.range == class.qname().range() => {
@@ -461,7 +490,7 @@ fn get_decorator_callees(
461490
function_base_definitions: &WholeProgramFunctionDefinitions<FunctionBaseDefinition>,
462491
context: &ModuleContext,
463492
) -> HashMap<PysaLocation, Vec<Target<FunctionRef>>> {
464-
assert_eq!(class.module(), &context.module_info);
493+
assert_eq!(class.module(), &context.answers_context.module_info);
465494
if let Some(class_def) = find_definition_ast(class, context) {
466495
resolve_decorator_callees(
467496
&class_def.decorator_list,
@@ -480,24 +509,35 @@ pub fn export_all_classes(
480509
context: &ModuleContext,
481510
) -> HashMap<PysaLocation, ClassDefinition> {
482511
let mut class_definitions = HashMap::new();
483-
let ann_assign_map = AnnAssignMap::build(&context.ast);
512+
let ann_assign_map = AnnAssignMap::build(&context.answers_context.ast);
484513

485-
for class_idx in context.bindings.keys::<KeyClass>() {
514+
for class_idx in context.answers_context.bindings.keys::<KeyClass>() {
486515
let class = context
516+
.answers_context
487517
.answers
488518
.get_idx(class_idx)
489519
.unwrap()
490520
.0
491521
.dupe()
492522
.unwrap();
493523
let class_index = class.index();
494-
let parent = get_scope_parent(&context.ast, &context.module_info, class.qname().range());
524+
let parent = get_scope_parent(
525+
&context.answers_context.ast,
526+
&context.answers_context.module_info,
527+
class.qname().range(),
528+
);
495529
let metadata = context
530+
.answers_context
496531
.answers
497-
.get_idx(context.bindings.key_to_idx(&KeyClassMetadata(class_index)))
532+
.get_idx(
533+
context
534+
.answers_context
535+
.bindings
536+
.key_to_idx(&KeyClassMetadata(class_index)),
537+
)
498538
.unwrap();
499539

500-
let is_synthesized = match context.bindings.get(class_idx) {
540+
let is_synthesized = match context.answers_context.bindings.get(class_idx) {
501541
BindingClass::FunctionalClassDef(_, _, _) => true,
502542
BindingClass::ClassDef(_) => false,
503543
};
@@ -510,7 +550,7 @@ pub fn export_all_classes(
510550
.map(|base_class| ClassRef::from_class(base_class, context.module_ids))
511551
.collect::<Vec<_>>();
512552

513-
let mro = match &*get_class_mro(&class, context) {
553+
let mro = match &*get_class_mro(&class, &context.answers_context) {
514554
ClassMro::Resolved(mro) => PysaClassMro::Resolved(
515555
mro.iter()
516556
.map(|class_type| {
@@ -545,7 +585,10 @@ pub fn export_all_classes(
545585
assert!(
546586
class_definitions
547587
.insert(
548-
PysaLocation::from_text_range(class.qname().range(), &context.module_info),
588+
PysaLocation::from_text_range(
589+
class.qname().range(),
590+
&context.answers_context.module_info
591+
),
549592
class_definition
550593
)
551594
.is_none(),

0 commit comments

Comments
 (0)