Skip to content

Commit f7a1b3c

Browse files
Add mypy configuration.
It will be some time before we can enforce this. I triaged some initial issues and sent llvm/llvm-project#66723 upstream to cover some misaligned things. We should triage again once that lands.
1 parent 7d1d0d0 commit f7a1b3c

File tree

4 files changed

+14
-9
lines changed

4 files changed

+14
-9
lines changed

mypy.ini

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
+[mypy]
+
+mypy_path = $MYPY_CONFIG_FILE_DIR/python
+packages = shark_turbine.aot,shark_turbine.dynamo,shark_turbine.support

python/shark_turbine/aot/builtins/jittable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
 
             # Already materialized.
             logger.debug("Resolved defined global for literal %r", mapping)
-            materialized_global: MaterializedGlobal = mapping.value
+            materialized_global: MaterializedGlobal = mapping.value  # type: ignore
 
             # Clone the global into the import module (so that our symbol refs are
             # legal). Note that the merger will ignore these since they already

python/shark_turbine/aot/support/ir_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
 MLIR_TYPE_ASM_TO_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_TO_MLIR_TYPE_ASM.items()}
 
 # When emitting constants, we have to create native IREE types.
-TORCH_DTYPE_TO_IREE_TYPE: Dict[str, Callable[[], IrType]] = {
+TORCH_DTYPE_TO_IREE_TYPE: Dict[torch.dtype, Callable[[], IrType]] = {
     torch.float16: lambda: F16Type.get(),
     torch.bfloat16: lambda: BF16Type.get(),
     torch.float32: lambda: F32Type.get(),
@@ -167,7 +167,8 @@ def create_tensor_global(
         if initialize:
             detached_tensor = t.detach().contiguous().cpu()
             array = np.array(detached_tensor)
-            contents = memoryview(array)
+            # We know that a Numpy array is a ReadableBuffer so ignore type error.
+            contents = memoryview(array)  # type: ignore
             # TODO: Add resource elements to Python API and use that.
             elements_attr = DenseElementsAttr.get(contents, type=tensor_type)
             attrs["initial_value"] = elements_attr
@@ -200,10 +201,10 @@ def __init__(
         self.func_op = func_op
         self.context = func_op.context
         self.ip = InsertionPoint(self.func_op.entry_block)
-        self.return_types = None
+        self.return_types: Optional[Sequence[IrType]] = None
         self.loc = self.func_op.location
 
-    def emit_return(self, *ir_values: Sequence[Value]):
+    def emit_return(self, *ir_values: Value):
         with self.loc, self.ip:
             func_d.ReturnOp(ir_values)
         # Check or rewrite the function return type.
@@ -200,10 +201,10 @@ def __init__(
200201
self.func_op = func_op
201202
self.context = func_op.context
202203
self.ip = InsertionPoint(self.func_op.entry_block)
203-
self.return_types = None
204+
self.return_types: Optional[Sequence[IrType]] = None
204205
self.loc = self.func_op.location
205206

206-
def emit_return(self, *ir_values: Sequence[Value]):
207+
def emit_return(self, *ir_values: Value):
207208
with self.loc, self.ip:
208209
func_d.ReturnOp(ir_values)
209210
# Check or rewrite the function return type.

python/shark_turbine/aot/support/procedural.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def current_ir_trace() -> IrTrace:
 class Intrinsic:
     """Objects which interact natively with the tracing system implement this."""
 
-    __slots__ = []
+    __slots__: List[str] = []
 
     def resolve_ir_values(self, proc_trace: "IrTrace") -> Sequence[Value]:
         raise NotImplementedError(
@@ -125,7 +125,7 @@ def __call__(self, *args, **kwargs):
 class AbstractIntrinsic:
     """Base class for descriptor types that can be converted to Python proxies."""
 
-    __slots__ = []
+    __slots__: List[str] = []
 
     def create_intrinsic(self, value: Value) -> Intrinsic:
         """Creates a proxy object that can flow through a procedural trace."""
@@ -192,7 +192,7 @@ def abstractify_single_value(value) -> AbstractTypedef:
     if isinstance(value, AbstractTypedef):
         return value
     if isinstance(value, Abstractifiable):
-        return value.get_abstract_typedef()
+        return value.abstractify()
     if isinstance(value, torch.Tensor):
         return AbstractTensor(*value.shape, dtype=value.dtype)
     raise TypeError(
@@ -494,7 +494,7 @@ def trace_py_func(self, py_f: Callable):
             self.emit_return()
         else:
             flat_return_py_values, schema = tree_flatten(return_py_value)
-            flat_return_ir_values = []
+            flat_return_ir_values: List[Value] = []
             for py_value in flat_return_py_values:
                 flat_return_ir_values.extend(convert_py_value_to_ir(self, py_value))
             self.func_op.attributes["torch.return_schema"] = StringAttr.get(

0 commit comments

Comments (0)