File tree 1 file changed +2
-2
lines changed 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -388,7 +388,7 @@ class MyTensor(torch.Tensor):
388
388
389
389
raise NotImplementedError (f"{ cls .__name__ } dispatch: attempting to run unimplemented operator/function: { func } " )
390
390
391
- def _register_layout_cls (cls : Callable , layout_type_class : type ( LayoutType ) ):
391
+ def _register_layout_cls (cls : Callable , layout_type_class : Callable ):
392
392
"""Helper function for layout registrations, this is used to implement
393
393
register_layout_cls decorator for each tensor subclass, see aqt.py for example usage
394
394
@@ -414,7 +414,7 @@ def decorator(layout_class):
414
414
return layout_class
415
415
return decorator
416
416
417
- def _get_layout_tensor_constructor (cls : Callable , layout_type_class : type ( LayoutType ) ) -> Callable :
417
+ def _get_layout_tensor_constructor (cls : Callable , layout_type_class : Callable ) -> Callable :
418
418
"""Get Layout class constructor (LayoutClass.from_plain) for `cls` based on `layout_type_class`
419
419
`layout_type_class` means the class type of subclass of `LayoutType`, e.g. `PlainLayoutType`
420
420
You can’t perform that action at this time.
0 commit comments