Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 64 additions & 22 deletions scripts/optimizer/processors/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,44 @@ def __next__(self) -> list:
class BatchGenerator:
def __init__(self):
self.tracked_elements = []
self.batches = []
self.batches = [1, 2, 4, 8, 16, 32]
self.batch_groups = []
self.pipeline = []
self.first_iteration = True

def init_pipeline(self, pipeline):
self.tracked_elements = []
self.batches = [1, 2, 4, 8, 16, 32]
self.batch_groups = []
self.pipeline = pipeline.copy()
self.first_iteration = True

instance_ids = {}

for idx, element in enumerate(self.pipeline):
if "gvadetect" in element or "gvaclassify" in element:
(_, parameters) = parse_element_parameters(element)
instance_id = parameters.get("model-instance-id")
group_idx = 0

# if element has an instance id, get the batch group index
if instance_id:
Comment thread
oonyshch marked this conversation as resolved.
group_idx = instance_ids.get(instance_id)

# if this instance id is new, create a new group index
if group_idx is None:
group_idx = len(self.batch_groups)
self.batch_groups.append(0)
instance_ids[instance_id] = group_idx

# if there's no instance id, treat element as its own group
else:
group_idx = len(self.batch_groups)
self.batch_groups.append(0)


self.tracked_elements.append({
"index": idx,
"batch_idx": 0,
"instance_id": "inf" + str(idx)
"group_idx": group_idx,
})

def __iter__(self):
Expand All @@ -154,16 +176,15 @@ def __iter__(self):
def __next__(self) -> list:
# Prepare the next combination of batches
end_of_variants = True
for element in self.tracked_elements:
for idx, cur_batch_idx in enumerate(self.batch_groups):
# Don't change anything on first iteration
if self.first_iteration:
self.first_iteration = False
end_of_variants = False
break

cur_batch_idx = element["batch_idx"]
next_batch_idx = (cur_batch_idx + 1) % len(self.batches)
element["batch_idx"] = next_batch_idx
self.batch_groups[idx] = next_batch_idx

# Walk through elements while they still
# have more batch options
Expand All @@ -176,9 +197,9 @@ def __next__(self) -> list:
if end_of_variants:
raise StopIteration

# log device combinations
batches = self.tracked_elements.copy()
batches = list(map(lambda e: self.batches[e["batch_idx"]], batches)) # transform batch indices into batches
# log batch combinations
batches = self.batch_groups.copy()
batches = list(map(lambda e: self.batches[e], batches)) # transform batch indices into batches
logger.info("Testing batch combination: %s", str(batches))

# Prepare pipeline output
Expand All @@ -189,7 +210,7 @@ def __next__(self) -> list:
(element_type, parameters) = parse_element_parameters(pipeline[idx])

# Get the batch for this element
batch = self.batches[element["batch_idx"]]
batch = self.batches[self.batch_groups[element["group_idx"]]]

# Apply current configuration
parameters["batch-size"] = str(batch)
Expand All @@ -201,22 +222,44 @@ def __next__(self) -> list:
class NireqGenerator:
def __init__(self):
self.tracked_elements = []
self.nireqs = []
self.nireqs = range(1, 9)
self.nireq_groups = []
self.pipeline = []
self.first_iteration = True

def init_pipeline(self, pipeline):
self.tracked_elements = []
self.nireqs = range(1, 9)
self.nireq_groups = []
self.pipeline = pipeline.copy()
self.first_iteration = True

instance_ids = {}

for idx, element in enumerate(self.pipeline):
if "gvadetect" in element or "gvaclassify" in element:
(_, parameters) = parse_element_parameters(element)
instance_id = parameters.get("model-instance-id")
group_idx = 0

# if element has an instance id, get the nireq group index
if instance_id:
Comment thread
oonyshch marked this conversation as resolved.
group_idx = instance_ids.get(instance_id)

# if this instance id is new, create a new group index
if group_idx is None:
group_idx = len(self.nireq_groups)
self.nireq_groups.append(0)
instance_ids[instance_id] = group_idx

# if there's no instance id, treat element as its own group
else:
group_idx = len(self.nireq_groups)
self.nireq_groups.append(0)


self.tracked_elements.append({
"index": idx,
"nireq_idx": 0,
"instance_id": "inf" + str(idx)
"group_idx": group_idx,
})

def __iter__(self):
Expand All @@ -225,16 +268,15 @@ def __iter__(self):
def __next__(self) -> list:
# Prepare the next combination of nireqs
end_of_variants = True
for element in self.tracked_elements:
for idx, cur_nireq_idx in enumerate(self.nireq_groups):
# Don't change anything on first iteration
if self.first_iteration:
self.first_iteration = False
end_of_variants = False
break

cur_nireq_idx = element["nireq_idx"]
next_nireq_idx = (cur_nireq_idx + 1) % len(self.nireqs)
element["nireq_idx"] = next_nireq_idx
self.nireq_groups[idx] = next_nireq_idx

# Walk through elements while they still
# have more nireq options
Expand All @@ -247,9 +289,9 @@ def __next__(self) -> list:
if end_of_variants:
raise StopIteration

# log device combinations
nireqs = self.tracked_elements.copy()
nireqs = list(map(lambda e: self.nireqs[e["nireq_idx"]], nireqs)) # transform nireq indices into nireqs
# log nireq combinations
nireqs = self.nireq_groups.copy()
nireqs = list(map(lambda e: self.nireqs[e], nireqs)) # transform nireq indices into nireqs
logger.info("Testing nireq combination: %s", str(nireqs))

# Prepare pipeline output
Expand All @@ -260,7 +302,7 @@ def __next__(self) -> list:
(element_type, parameters) = parse_element_parameters(pipeline[idx])

# Get the nireq for this element
nireq = self.nireqs[element["nireq_idx"]]
nireq = self.nireqs[self.nireq_groups[element["group_idx"]]]

# Apply current configuration
parameters["nireq"] = str(nireq)
Expand Down
Loading