@@ -287,6 +287,10 @@ class type
287287 raise ValueError (
288288 f"Cannot look up { base_class_wanted } . Cannot tell what it is."
289289 )
290+ if not isinstance (name , str ):
291+ raise ValueError (
292+ f"Cannot look up a { type (name )} in the registry. Got { name } ."
293+ )
290294 result = self ._mapping [base_class ].get (name )
291295 if result is None :
292296 raise ValueError (f"{ name } has not been registered." )
@@ -446,6 +450,11 @@ def create_pluggable(self, type_name, args):
446450 setattr (self , name , None )
447451 return
448452
453+ if not isinstance (type_name , str ):
454+ raise ValueError (
455+ f"A { type (type_name )} was received as the type of { name } ."
456+ + f" Perhaps this is from { name } { TYPE_SUFFIX } ?"
457+ )
449458 chosen_class = registry .get (type_ , type_name )
450459 if self ._known_implementations .get (type_name , chosen_class ) is not chosen_class :
451460 # If this warning is raised, it means that a new definition of
@@ -514,7 +523,10 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
514523 # because in practice get_default_args_field is used for
515524 # separate types than the outer type.
516525
517- out : DictConfig = OmegaConf .structured (C )
526+ try :
527+ out : DictConfig = OmegaConf .structured (C )
528+ except Exception as e :
529+ raise ValueError (f"OmegaConf.structured({ C } ) failed" ) from e
518530 exclude = getattr (C , "_processed_members" , ())
519531 with open_dict (out ):
520532 for field in exclude :
@@ -534,7 +546,11 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
534546 f"Cannot get args for { C } . Was enable_get_default_args forgotten?"
535547 )
536548
537- return OmegaConf .structured (dataclass )
549+ try :
550+ out : DictConfig = OmegaConf .structured (dataclass )
551+ except Exception as e :
552+ raise ValueError (f"OmegaConf.structured failed for { dataclass_name } " ) from e
553+ return out
538554
539555
540556def _dataclass_name_for_function (C : Any ) -> str :
@@ -546,22 +562,21 @@ def _dataclass_name_for_function(C: Any) -> str:
546562 return name
547563
548564
549- def enable_get_default_args (C : Any , * , overwrite : bool = True ) -> None :
565+ def _field_annotations_for_default_args (
566+ C : Any ,
567+ ) -> List [Tuple [str , Any , dataclasses .Field ]]:
550568 """
551569 If C is a function or a plain class with an __init__ function,
552- and you want get_default_args(C) to work, then add
553- `enable_get_default_args(C)` straight after the definition of C.
554- This makes a dataclass corresponding to the default arguments of C
555- and stores it in the same module as C.
570+ return the fields which `enable_get_default_args(C)` will need
571+ to make a dataclass with.
556572
557573 Args:
558574 C: a function, or a class with an __init__ function. Must
559575 have types for all its defaulted args.
560- overwrite: whether to allow calling this a second time on
561- the same function.
576+
577+ Returns:
578+ a list of fields for a dataclass.
562579 """
563- if not inspect .isfunction (C ) and not inspect .isclass (C ):
564- raise ValueError (f"Unexpected { C } " )
565580
566581 field_annotations = []
567582 for pname , defval in _params_iter (C ):
@@ -572,8 +587,8 @@ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
572587
573588 if defval .annotation == inspect ._empty :
574589 raise ValueError (
575- "All arguments of the input callable have to be typed. "
576- + f" Argument '{ pname } ' does not have a type annotation."
590+ "All arguments of the input to enable_get_default_args have to"
591+ f" be typed. Argument '{ pname } ' does not have a type annotation."
577592 )
578593
579594 _ , annotation = _resolve_optional (defval .annotation )
@@ -591,6 +606,28 @@ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
591606 field_ = dataclasses .field (default = default )
592607 field_annotations .append ((pname , defval .annotation , field_ ))
593608
609+ return field_annotations
610+
611+
612+ def enable_get_default_args (C : Any , * , overwrite : bool = True ) -> None :
613+ """
614+ If C is a function or a plain class with an __init__ function,
615+ and you want get_default_args(C) to work, then add
616+ `enable_get_default_args(C)` straight after the definition of C.
617+ This makes a dataclass corresponding to the default arguments of C
618+ and stores it in the same module as C.
619+
620+ Args:
621+ C: a function, or a class with an __init__ function. Must
622+ have types for all its defaulted args.
623+ overwrite: whether to allow calling this a second time on
624+ the same function.
625+ """
626+ if not inspect .isfunction (C ) and not inspect .isclass (C ):
627+ raise ValueError (f"Unexpected { C } " )
628+
629+ field_annotations = _field_annotations_for_default_args (C )
630+
594631 name = _dataclass_name_for_function (C )
595632 module = sys .modules [C .__module__ ]
596633 if hasattr (module , name ):
@@ -767,7 +804,7 @@ def create_x_impl(self, enabled, args):
767804
768805 Also adds the following class members, unannotated so that dataclass
769806 ignores them.
770- - _creation_functions: Tuple[str] of all the create_ functions,
807+ - _creation_functions: Tuple[str, ... ] of all the create_ functions,
771808 including those from base classes (not the create_x_impl ones).
772809 - _known_implementations: Dict[str, Type] containing the classes which
773810 have been found from the registry.
@@ -945,7 +982,7 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
945982 return underlying , _ProcessType .OPTIONAL_CONFIGURABLE
946983
947984 if not isinstance (type_ , type ):
948- # e.g. any other Union or Tuple
985+ # e.g. any other Union or Tuple. Or ClassVar.
949986 return
950987
951988 if issubclass (type_ , ReplaceableBase ) and ReplaceableBase in type_ .__bases__ :
0 commit comments