@@ -44,6 +44,7 @@ use crate::report::pysa::ModuleContext;
4444use crate :: report:: pysa:: call_graph:: Target ;
4545use crate :: report:: pysa:: call_graph:: resolve_decorator_callees;
4646use crate :: report:: pysa:: collect:: CollectNoDuplicateKeys ;
47+ use crate :: report:: pysa:: context:: ModuleAnswersContext ;
4748use crate :: report:: pysa:: function:: FunctionBaseDefinition ;
4849use crate :: report:: pysa:: function:: FunctionRef ;
4950use crate :: report:: pysa:: function:: WholeProgramFunctionDefinitions ;
@@ -232,7 +233,7 @@ impl ClassDefinition {
232233 }
233234}
234235
235- pub fn get_all_classes ( context : & ModuleContext ) -> impl Iterator < Item = Class > {
236+ pub fn get_all_classes ( context : & ModuleAnswersContext ) -> impl Iterator < Item = Class > {
236237 context
237238 . bindings
238239 . keys :: < KeyClass > ( )
@@ -242,14 +243,29 @@ pub fn get_all_classes(context: &ModuleContext) -> impl Iterator<Item = Class> {
242243pub fn get_class_field_from_current_class_only (
243244 class : & Class ,
244245 field_name : & Name ,
245- context : & ModuleContext ,
246+ context : & ModuleAnswersContext ,
246247) -> Option < Arc < ClassField > > {
247- context
248- . transaction
249- . ad_hoc_solve ( & context. handle , "pysa_class_field" , |solver| {
250- solver. get_field_from_current_class_only ( class, field_name)
251- } )
252- . unwrap ( )
248+ // This inlines the logic from `AnswersSolver::get_field_from_current_class_only`,
249+ // `get_non_synthesized_field_from_current_class_only`, and
250+ // `get_synthesized_field_from_current_class_only`.
251+ assert ! ( class. module( ) == & context. module_info) ;
252+
253+ // Non-synthesized field: check class fields list, then look up the answer.
254+ let class_fields = & context. bindings . metadata ( ) . get_class ( class. index ( ) ) . fields ;
255+ if class_fields. contains ( field_name) {
256+ let key = KeyClassField ( class. index ( ) , field_name. clone ( ) ) ;
257+ if let Some ( idx) = context. bindings . key_to_idx_hashed_opt ( Hashed :: new ( & key) )
258+ && let Some ( field) = context. answers . get_idx ( idx)
259+ {
260+ return Some ( field) ;
261+ }
262+ }
263+
264+ // Synthesized field (e.g., dataclass fields).
265+ let key = KeyClassSynthesizedFields ( class. index ( ) ) ;
266+ let idx = context. bindings . key_to_idx_hashed_opt ( Hashed :: new ( & key) ) ?;
267+ let synthesized_fields = context. answers . get_idx ( idx) ?;
268+ Some ( synthesized_fields. get ( field_name) ?. inner . dupe ( ) )
253269}
254270
255271pub fn get_super_class_member (
@@ -260,16 +276,18 @@ pub fn get_super_class_member(
260276) -> Option < WithDefiningClass < Arc < ClassField > > > {
261277 context
262278 . transaction
263- . ad_hoc_solve ( & context. handle , "pysa_super_class_member" , |solver| {
264- solver. get_super_class_member ( class, start_lookup_cls, field_name)
265- } )
279+ . ad_hoc_solve (
280+ & context. answers_context . handle ,
281+ "pysa_super_class_member" ,
282+ |solver| solver. get_super_class_member ( class, start_lookup_cls, field_name) ,
283+ )
266284 . flatten ( )
267285}
268286
269287pub fn get_class_field_declaration < ' a > (
270288 class : & Class ,
271289 field_name : & Name ,
272- context : & ' a ModuleContext ,
290+ context : & ' a ModuleAnswersContext ,
273291) -> Option < & ' a BindingClassField > {
274292 assert_eq ! ( class. module( ) , & context. module_info) ;
275293 let key_class_field = KeyClassField ( class. index ( ) , field_name. clone ( ) ) ;
@@ -280,7 +298,7 @@ pub fn get_class_field_declaration<'a>(
280298 . map ( |idx| context. bindings . get ( idx) )
281299}
282300
283- pub fn get_class_mro ( class : & Class , context : & ModuleContext ) -> Arc < ClassMro > {
301+ pub fn get_class_mro ( class : & Class , context : & ModuleAnswersContext ) -> Arc < ClassMro > {
284302 assert_eq ! ( class. module( ) , & context. module_info) ;
285303 context
286304 . answers
@@ -290,7 +308,7 @@ pub fn get_class_mro(class: &Class, context: &ModuleContext) -> Arc<ClassMro> {
290308
291309pub fn get_class_fields < ' a > (
292310 class : & ' a Class ,
293- context : & ' a ModuleContext < ' a > ,
311+ context : & ' a ModuleAnswersContext ,
294312) -> impl Iterator < Item = ( Cow < ' a , Name > , Arc < ClassField > ) > {
295313 let class_fields = context
296314 . bindings
@@ -368,11 +386,11 @@ fn export_class_fields(
368386 context : & ModuleContext ,
369387 ann_assign_map : & AnnAssignMap ,
370388) -> HashMap < Name , PysaClassField > {
371- assert_eq ! ( class. module( ) , & context. module_info) ;
372- get_class_fields ( class, context)
389+ assert_eq ! ( class. module( ) , & context. answers_context . module_info) ;
390+ get_class_fields ( class, & context. answers_context )
373391 . filter ( |( _, field) | !is_callable_like ( & field. ty ( ) ) )
374392 . filter_map ( |( name, field) | {
375- let field_binding = get_class_field_declaration ( class, & name, context) ;
393+ let field_binding = get_class_field_declaration ( class, & name, & context. answers_context ) ;
376394
377395 let explicit_annotation = match field_binding {
378396 Some ( BindingClassField {
@@ -389,7 +407,7 @@ fn export_class_fields(
389407 } ) => * annotation,
390408 _ => None ,
391409 }
392- . map ( |idx| context. bindings . idx_to_key ( idx) )
410+ . map ( |idx| context. answers_context . bindings . idx_to_key ( idx) )
393411 . and_then ( |key_annotation| match key_annotation {
394412 // We want to export the annotation as it is in the source code.
395413 // We cannot use the answer for `key_annotation` (which wraps a `Type`),
@@ -398,11 +416,19 @@ fn export_class_fields(
398416 KeyAnnotation :: Annotation ( identifier) => ann_assign_map
399417 . get ( identifier. range ( ) . start ( ) )
400418 . map ( |annotation_range| {
401- context. module_info . code_at ( * annotation_range) . to_owned ( )
419+ context
420+ . answers_context
421+ . module_info
422+ . code_at ( * annotation_range)
423+ . to_owned ( )
402424 } ) ,
403- KeyAnnotation :: AttrAnnotation ( range) => {
404- Some ( context. module_info . code_at ( * range) . to_owned ( ) )
405- }
425+ KeyAnnotation :: AttrAnnotation ( range) => Some (
426+ context
427+ . answers_context
428+ . module_info
429+ . code_at ( * range)
430+ . to_owned ( ) ,
431+ ) ,
406432 _ => None ,
407433 } ) ;
408434
@@ -421,7 +447,10 @@ fn export_class_fields(
421447 PysaClassField {
422448 type_ : PysaType :: from_type ( & field. ty ( ) , context) ,
423449 explicit_annotation,
424- location : Some ( PysaLocation :: from_text_range ( * range, & context. module_info ) ) ,
450+ location : Some ( PysaLocation :: from_text_range (
451+ * range,
452+ & context. answers_context . module_info ,
453+ ) ) ,
425454 declaration_kind : PysaClassFieldDeclaration :: from ( definition) ,
426455 } ,
427456 ) ) ,
@@ -444,8 +473,8 @@ fn find_definition_ast<'a>(
444473 class : & Class ,
445474 context : & ' a ModuleContext < ' a > ,
446475) -> Option < & ' a StmtClassDef > {
447- assert_eq ! ( class. module( ) , & context. module_info) ;
448- Ast :: locate_node ( & context. ast , class. qname ( ) . range ( ) . start ( ) )
476+ assert_eq ! ( class. module( ) , & context. answers_context . module_info) ;
477+ Ast :: locate_node ( & context. answers_context . ast , class. qname ( ) . range ( ) . start ( ) )
449478 . iter ( )
450479 . find_map ( |node| match node {
451480 AnyNodeRef :: StmtClassDef ( stmt) if stmt. name . range == class. qname ( ) . range ( ) => {
@@ -461,7 +490,7 @@ fn get_decorator_callees(
461490 function_base_definitions : & WholeProgramFunctionDefinitions < FunctionBaseDefinition > ,
462491 context : & ModuleContext ,
463492) -> HashMap < PysaLocation , Vec < Target < FunctionRef > > > {
464- assert_eq ! ( class. module( ) , & context. module_info) ;
493+ assert_eq ! ( class. module( ) , & context. answers_context . module_info) ;
465494 if let Some ( class_def) = find_definition_ast ( class, context) {
466495 resolve_decorator_callees (
467496 & class_def. decorator_list ,
@@ -480,24 +509,35 @@ pub fn export_all_classes(
480509 context : & ModuleContext ,
481510) -> HashMap < PysaLocation , ClassDefinition > {
482511 let mut class_definitions = HashMap :: new ( ) ;
483- let ann_assign_map = AnnAssignMap :: build ( & context. ast ) ;
512+ let ann_assign_map = AnnAssignMap :: build ( & context. answers_context . ast ) ;
484513
485- for class_idx in context. bindings . keys :: < KeyClass > ( ) {
514+ for class_idx in context. answers_context . bindings . keys :: < KeyClass > ( ) {
486515 let class = context
516+ . answers_context
487517 . answers
488518 . get_idx ( class_idx)
489519 . unwrap ( )
490520 . 0
491521 . dupe ( )
492522 . unwrap ( ) ;
493523 let class_index = class. index ( ) ;
494- let parent = get_scope_parent ( & context. ast , & context. module_info , class. qname ( ) . range ( ) ) ;
524+ let parent = get_scope_parent (
525+ & context. answers_context . ast ,
526+ & context. answers_context . module_info ,
527+ class. qname ( ) . range ( ) ,
528+ ) ;
495529 let metadata = context
530+ . answers_context
496531 . answers
497- . get_idx ( context. bindings . key_to_idx ( & KeyClassMetadata ( class_index) ) )
532+ . get_idx (
533+ context
534+ . answers_context
535+ . bindings
536+ . key_to_idx ( & KeyClassMetadata ( class_index) ) ,
537+ )
498538 . unwrap ( ) ;
499539
500- let is_synthesized = match context. bindings . get ( class_idx) {
540+ let is_synthesized = match context. answers_context . bindings . get ( class_idx) {
501541 BindingClass :: FunctionalClassDef ( _, _, _) => true ,
502542 BindingClass :: ClassDef ( _) => false ,
503543 } ;
@@ -510,7 +550,7 @@ pub fn export_all_classes(
510550 . map ( |base_class| ClassRef :: from_class ( base_class, context. module_ids ) )
511551 . collect :: < Vec < _ > > ( ) ;
512552
513- let mro = match & * get_class_mro ( & class, context) {
553+ let mro = match & * get_class_mro ( & class, & context. answers_context ) {
514554 ClassMro :: Resolved ( mro) => PysaClassMro :: Resolved (
515555 mro. iter ( )
516556 . map ( |class_type| {
@@ -545,7 +585,10 @@ pub fn export_all_classes(
545585 assert ! (
546586 class_definitions
547587 . insert(
548- PysaLocation :: from_text_range( class. qname( ) . range( ) , & context. module_info) ,
588+ PysaLocation :: from_text_range(
589+ class. qname( ) . range( ) ,
590+ & context. answers_context. module_info
591+ ) ,
549592 class_definition
550593 )
551594 . is_none( ) ,
0 commit comments