 import json
 import typing
 from dataclasses import dataclass, field
-from typing import List
+from typing import Any, Dict, List, Optional
 
 import executorch.exir.memory as memory
 import torch
@@ -52,7 +52,7 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
     allocations at that timestep.
     """
     nodes = graph.nodes
-    memory_timeline = [None] * len(nodes)
+    memory_timeline: List[Optional[MemoryTimeline]] = [None] * len(nodes)
     for _, node in enumerate(nodes):
         if node.op == "output":
             continue
@@ -73,9 +73,7 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
         fqn = _get_module_hierarchy(node)
         for j in range(start, end + 1):
             if memory_timeline[j] is None:
-                # pyre-ignore
                 memory_timeline[j] = MemoryTimeline()
-            # pyre-ignore
             memory_timeline[j].allocations.append(
                 Allocation(
                     node.name,
@@ -87,8 +85,7 @@ def create_tensor_allocation_info(graph: torch.fx.Graph) -> List[MemoryTimeline]
                     stack_trace,
                 )
             )
-    # pyre-ignore
-    return memory_timeline
+    return memory_timeline  # type: ignore[return-value]
 
 
 def _validate_memory_planning_is_done(exported_program: ExportedProgram):
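
A note on the typed version above: the slots start out as `None`, so the list has to be declared `List[Optional[MemoryTimeline]]`, while the function signature promises `List[MemoryTimeline]`; the `# type: ignore[return-value]` bridges that gap at the return site. A minimal sketch of an alternative that avoids the ignore entirely (the `finalize_timeline` helper and the stub dataclass are illustrative, not part of this file):

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class MemoryTimeline:
    # Stand-in for the dataclass in this file; the real one holds
    # a list of Allocation records.
    allocations: List[Any] = field(default_factory=list)


def finalize_timeline(
    timeline: List[Optional[MemoryTimeline]],
) -> List[MemoryTimeline]:
    # Filtering on `is not None` lets the type checker narrow the
    # element type, so no ignore comment is needed. Note the semantic
    # difference: this drops never-populated timesteps, whereas the
    # diff keeps the None placeholders and silences the checker.
    return [t for t in timeline if t is not None]
```
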
@@ -129,7 +126,7 @@ def generate_memory_trace(
 
     memory_timeline = create_tensor_allocation_info(exported_program.graph)
     root = {}
-    trace_events = []
+    trace_events: List[Dict[str, Any]] = []
     root["traceEvents"] = trace_events
 
     tid = 0
@@ -138,7 +135,7 @@ def generate_memory_trace(
         if memory_timeline_event is None:
             continue
         for allocation in memory_timeline_event.allocations:
-            e = {}
+            e: Dict[str, Any] = {}
             e["name"] = allocation.name
             e["cat"] = "memory_allocation"
             e["ph"] = "X"
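
For reference, the events built here follow the Chrome Trace Event Format: `"ph": "X"` marks a complete event with a start timestamp (`ts`) and duration (`dur`), and the top-level `traceEvents` key is what chrome://tracing and Perfetto expect. A hand-rolled sketch of the same shape (the values are made-up placeholders, not output of `generate_memory_trace`):

```python
import json

trace = {
    "traceEvents": [
        {
            "name": "alloc_0",           # allocation (node) name
            "cat": "memory_allocation",
            "ph": "X",                   # "complete" event: has ts + dur
            "ts": 0,                     # start of the tensor's lifetime
            "dur": 3,                    # how long the allocation lives
            "pid": 0,
            "tid": 0,
            "args": {"size_bytes": 1024},
        }
    ]
}

with open("memory_trace.json", "w") as f:
    json.dump(trace, f)
# Open the file in chrome://tracing or https://ui.perfetto.dev to view it.
```
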