Bug fix for named nn.Sequential in pytorch parser #848

Merged · 10 commits · Aug 28, 2023
20 changes: 15 additions & 5 deletions hls4ml/converters/pytorch/convolution.py
@@ -1,4 +1,4 @@
from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler
from hls4ml.converters.pytorch_to_hls import pytorch_handler
from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format


@@ -9,11 +9,16 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c
layer = {}

layer['name'] = layer_name
layer['inputs'] = input_names
layer['class_name'] = 'Conv1D'
layer['data_format'] = 'channels_first' # Pytorch default (can't change)

layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight')
layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias')
layer['weight_data'] = class_object.weight.data.numpy()
if class_object.bias is not None:
layer['bias_data'] = class_object.bias.data.numpy()
else:
layer['bias_data'] = None

# Input info
(layer['in_width'], layer['n_chan']) = parse_data_format(
input_shapes[0], 'channels_first'
@@ -54,11 +59,16 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c
layer = {}

layer['name'] = layer_name
layer['inputs'] = input_names
layer['class_name'] = 'Conv2D'
layer['data_format'] = 'channels_first' # Pytorch default (can't change)

layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight')
layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias')
layer['weight_data'] = class_object.weight.data.numpy()
if class_object.bias is not None:
layer['bias_data'] = class_object.bias.data.numpy()
else:
layer['bias_data'] = None

# Input info
(layer['in_height'], layer['in_width'], layer['n_chan']) = parse_data_format(
input_shapes[0], 'channels_first'
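
Not part of the diff, but for reference: a minimal sketch of the new extraction path. Weights and biases are now read directly off the traced module object, so a bias-less layer cleanly yields `None` (layer parameters here are illustrative).

```python
import torch.nn as nn

# Read tensors straight off the module, as the updated handlers do.
conv = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=5, bias=False)

weight_data = conv.weight.data.numpy()  # shape (8, 3, 5): (filters, channels, width)
bias_data = conv.bias.data.numpy() if conv.bias is not None else None
print(weight_data.shape, bias_data)  # (8, 3, 5) None
```
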
24 changes: 16 additions & 8 deletions hls4ml/converters/pytorch/core.py
@@ -1,4 +1,4 @@
from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler
from hls4ml.converters.pytorch_to_hls import pytorch_handler


@pytorch_handler('Linear')
@@ -9,8 +9,14 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c

layer['class_name'] = 'Dense'
layer['name'] = layer_name
layer['inputs'] = input_names

layer['weight_data'] = class_object.weight.data.numpy()
if class_object.bias is not None:
layer['bias_data'] = class_object.bias.data.numpy()
else:
layer['bias_data'] = None

layer['weight_data'], layer['bias_data'] = get_weights_data(data_reader, layer['name'], ['weight', 'bias'])
if class_object is not None:
layer['n_in'] = class_object.in_features
layer['n_out'] = class_object.out_features
@@ -39,6 +45,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
layer['class_name'] = operation
layer['activation'] = layer['class_name']
layer['name'] = layer_name
layer['inputs'] = input_names

# if layer['class_name'] != 'Activation':
# layer['activation'] = layer['class_name']
@@ -50,7 +57,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
if layer['class_name'] == 'ELU':
layer['activ_param'] = class_object.alpha
if layer['class_name'] == 'PReLU':
layer['alpha_data'] = get_weights_data(data_reader, layer['name'], 'weight')
layer['alpha_data'] = class_object.weight.data.numpy()
if layer['class_name'] == 'Threshold':
layer['activ_param'] = class_object.threshold
layer['class_name'] = 'ThresholdedReLU'
@@ -92,25 +99,26 @@ def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node
layer['class_name'] = 'BatchNormalization'
layer['data_format'] = 'channels_first'
layer['name'] = layer_name
layer['inputs'] = input_names

# batchnorm parameters
if node.op == 'call_module':
layer['epsilon'] = class_object.eps
layer['use_gamma'] = layer['use_beta'] = class_object.affine

if layer['use_gamma']:
layer['gamma_data'] = get_weights_data(data_reader, layer['name'], 'weight')
layer['gamma_data'] = class_object.weight.data.numpy()
else:
layer['gamma_data'] = 1

if layer['use_beta']:
layer['beta_data'] = get_weights_data(data_reader, layer['name'], 'bias')
layer['beta_data'] = class_object.bias.data.numpy()
else:
layer['beta_data'] = 0

layer['mean_data'], layer['variance_data'] = get_weights_data(
data_reader, layer['name'], ['running_mean', 'running_var']
)
layer['mean_data'] = class_object.running_mean.data.numpy()
layer['variance_data'] = class_object.running_var.data.numpy()

in_size = 1
for dim in input_shapes[0][1:]:
in_size *= dim
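
The batchnorm handler follows the same pattern; a short illustrative sketch (not part of the diff) of pulling the affine parameters and running statistics straight from a standard `nn.BatchNorm1d`:

```python
import torch.nn as nn

# Affine parameters and running statistics read directly from the module.
bn = nn.BatchNorm1d(num_features=16)

gamma = bn.weight.data.numpy() if bn.affine else 1
beta = bn.bias.data.numpy() if bn.affine else 0
mean = bn.running_mean.data.numpy()
variance = bn.running_var.data.numpy()
print(gamma.shape, beta.shape, mean.shape, variance.shape)  # all (16,)
```
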
1 change: 1 addition & 0 deletions hls4ml/converters/pytorch/pooling.py
@@ -20,6 +20,7 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node,
layer['class_name'] = 'AveragePooling2D'

layer['name'] = layer_name
layer['inputs'] = input_names
layer['data_format'] = 'channels_first' # Pytorch default (can't change)
if node.op == 'call_module' and 'Avg' in operation:
if class_object.count_include_pad:
55 changes: 55 additions & 0 deletions hls4ml/converters/pytorch/reshape.py
@@ -12,6 +12,7 @@ def parse_reshape_layer(operation, layer_name, input_names, input_shapes, node,
layer = {}
layer['class_name'] = 'Reshape'
layer['name'] = layer_name
layer['inputs'] = input_names

layer['target_shape'] = [int(i) for i in node.args[1:]]
# View can have -1 as one of the dimensions,
@@ -29,6 +30,60 @@
return layer, output_shape


@pytorch_handler('squeeze')
def parse_squeeze_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert operation == 'squeeze'

layer = {}
layer['class_name'] = 'Reshape'
layer['name'] = layer_name
layer['inputs'] = input_names

if len(node.args) > 1 or len(node.kwargs) > 0: # 'dim' argument is specified
output_shape = [i for i in input_shapes[0]]
squeeze_dim = node.kwargs.get('dim', None)
if squeeze_dim is None:
squeeze_dim = node.args[1]
if isinstance(squeeze_dim, tuple):
for dim in squeeze_dim:
del output_shape[dim]
else:
del output_shape[squeeze_dim]
else:
output_shape = [i for i in input_shapes[0] if i != 1]

layer['target_shape'] = output_shape.copy()
if layer['target_shape'][0] is None:
del layer['target_shape'][0]

return layer, output_shape


@pytorch_handler('unsqueeze')
def parse_unsqueeze_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert operation == 'unsqueeze'

layer = {}
layer['class_name'] = 'Reshape'
layer['name'] = layer_name
layer['inputs'] = input_names

# Unlike 'squeeze', 'unsqueeze' requires the 'dim' argument to be specified
output_shape = [i for i in input_shapes[0]]
if len(node.args) > 1: # Specified as unsqueeze(x, n)
squeeze_dim = node.args[1]
else: # Specified as unsqueeze(x, dim=n)
squeeze_dim = node.kwargs['dim']
# insert() adds the new element before the given index, which is where unsqueeze places the new axis
output_shape.insert(squeeze_dim, 1)

layer['target_shape'] = output_shape.copy()
if layer['target_shape'][0] is None:
del layer['target_shape'][0]

return layer, output_shape


@pytorch_handler('Flatten')
def parse_flatten_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert operation == 'Flatten'
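
The new `squeeze`/`unsqueeze` handlers infer the output shape from the traced call's args and kwargs instead of executing the op. A quick cross-check of that shape logic against torch itself (shapes are illustrative):

```python
import torch

x = torch.zeros(1, 16, 1)

print(torch.squeeze(x, dim=2).shape)  # torch.Size([1, 16]): drops only axis 2
print(torch.squeeze(x).shape)         # torch.Size([16]): drops every size-1 axis
print(torch.unsqueeze(x, 1).shape)    # torch.Size([1, 1, 16, 1]): new axis at position 1
```
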
50 changes: 23 additions & 27 deletions hls4ml/converters/pytorch_to_hls.py
@@ -16,22 +16,12 @@ def __init__(self, config):
def get_weights_data(self, layer_name, var_name):
data = None

# Workaround for naming schme in nn.Sequential,
# have to remove the prefix we previously had to add to make sure the tensors are found
if 'layer_' in layer_name:
layer_name = layer_name.split('layer_')[-1]
tensorName = layer_name + '.' + var_name

# if a layer is reused in the model, torch.FX will append a "_n" for the n-th use
# have to snap that off to find the tensors
if layer_name.split('_')[-1].isdigit() and len(layer_name.split('_')) > 1:
layer_name = '_'.join(layer_name.split('_')[:-1])
if tensorName in self.state_dict:
data = self.state_dict[tensorName].numpy()

if layer_name + '.' + var_name in self.state_dict:
data = self.state_dict[layer_name + '.' + var_name].numpy()
return data

else:
return None
return data


class PyTorchFileReader(PyTorchModelReader): # Inherit get_weights_data method
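
With the naming workarounds gone, `get_weights_data` only succeeds when `layer_name` is already the fully qualified `state_dict` prefix. A sketch with a hypothetical named `nn.Sequential` shows why the dotted names matter:

```python
import torch.nn as nn

# Layers nested in a *named* nn.Sequential get dotted state_dict keys.
model = nn.Sequential()
model.add_module('dense_block', nn.Sequential(nn.Linear(4, 8), nn.ReLU()))

print(list(model.state_dict().keys()))  # ['dense_block.0.weight', 'dense_block.0.bias']

# The lookup only works with the fully qualified prefix, not a mangled
# 'layer_...' variant.
state_dict = model.state_dict()
print(state_dict['dense_block.0' + '.' + 'weight'].shape)  # torch.Size([8, 4])
```
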
@@ -98,7 +88,7 @@ def decorator(function):
'elu': 'ELU',
'prelu': 'PReLU',
'sigmoid': 'Sigmoid',
'layer_threshold': 'Threshold',
'_threshold': 'Threshold',
'softmax': 'Softmax',
'max_pool1d': 'MaxPool1d',
'max_pool2d': 'MaxPool2d',
@@ -134,6 +124,7 @@ def pytorch_to_hls(config):
input_shapes = [list(reader.input_shape)]
else:
input_shapes = list(reader.input_shape)
input_shapes = [list(shape) for shape in input_shapes]

model = reader.torch_model

@@ -149,6 +140,9 @@
# All supported layers
supported_layers = get_supported_pytorch_layers() + skip_layers

# Map inputs of skipped and split (activation) layers
inputs_map = {}

input_layers = []

# Output shape tracking
@@ -162,15 +156,16 @@
n_inputs = 0

for node in traced_model.graph.nodes:
# If part of a nn.Sequntial, the node name will start with an "_" which messes up the parsing
if node.name[0] == '_':
node.name = 'layer' + node.name

if node.op == 'call_module':
# modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x',
# where x is an integer numbering the elements of the Sequential
if '.' in node.target:
class_object = children[node.target.split('.')[0]][int(node.target.split('.')[1])]
fqn_path = node.target.split('.')
sub_children = dict(children[fqn_path[0]].named_children())
for name in fqn_path[1:-1]:
sub_children = dict(sub_children[name].named_children())
class_object = sub_children[fqn_path[-1]]
else:
class_object = children[node.target]
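
A standalone sketch of the traversal above (module names are hypothetical): torch.fx reports a module nested inside a named `nn.Sequential` with its dotted path as the `call_module` target, which is resolved one `named_children()` level at a time:

```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(4, 8), nn.ReLU())

    def forward(self, x):
        return self.block(x)

model = Net()
children = dict(model.named_children())

target = 'block.0'  # fx target of the Linear inside 'block'
fqn_path = target.split('.')
sub_children = dict(children[fqn_path[0]].named_children())
for name in fqn_path[1:-1]:
    sub_children = dict(sub_children[name].named_children())
class_object = sub_children[fqn_path[-1]]
print(type(class_object).__name__)  # Linear
```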

@@ -189,6 +184,10 @@
if pytorch_class == 'Sequential': # Ignore the mother module's class name
continue

# Assuming only one input
parent_input = [str(i) for i in node.args][0]
inputs_map[layer_name] = inputs_map.get(parent_input, parent_input)

output_shapes[layer_name] = input_shapes[0]

continue
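
A sketch of how `inputs_map` chains through consecutive skipped layers (node names are hypothetical), so that downstream nodes are rewired to a layer that actually exists in the parsed graph:

```python
# fx order: fc1 -> drop1 -> drop2 -> fc2, with both dropouts skipped.
inputs_map = {}

for layer_name, parent_input in [('drop1', 'fc1'), ('drop2', 'drop1')]:
    # Chain through parents that were themselves skipped, so the map
    # always resolves to a kept layer.
    inputs_map[layer_name] = inputs_map.get(parent_input, parent_input)

# When fc2 is parsed, its input 'drop2' resolves to 'fc1'.
input_names = [inputs_map.get(i, i) for i in ['drop2']]
print(input_names)  # ['fc1']
```
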
@@ -198,7 +197,7 @@
layer_counter += 1

# parse info from class object
input_names = [str(i) for i in node.args]
input_names = [inputs_map.get(str(i), str(i)) for i in node.args]
input_shapes = [output_shapes[str(i)] for i in node.args]

# for Conv layers
Expand Down Expand Up @@ -236,7 +235,7 @@ def pytorch_to_hls(config):
input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
layer_list.insert(n_inputs, input_layer)

output_shapes[input_layer['name']] = input_shapes[n_inputs]
output_shapes[input_layer['name']] = list(input_shapes[n_inputs])
input_layers.append(input_layer['name'])
n_inputs += 1

Expand Down Expand Up @@ -265,7 +264,7 @@ def pytorch_to_hls(config):

layer_counter += 1

input_names = [str(i) for i in node.all_input_nodes]
input_names = [inputs_map.get(str(i), str(i)) for i in node.all_input_nodes]
input_shapes = [list(output_shapes[str(i)]) for i in input_names]

# Process the layer
@@ -318,10 +317,7 @@

layer_counter += 1

if 'View' in operation:
input_names = [str(node.args[0])]
else:
input_names = [str(i) for i in node.args]
input_names = [inputs_map.get(str(i), str(i)) for i in node.all_input_nodes]

# Process the layer
input_shapes = [list(output_shapes[str(i)]) for i in input_names]
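
Putting it together, an end-to-end sketch of the case this PR fixes: a model whose layers live in a named `nn.Sequential`. The model is illustrative, and the converter call assumes the hls4ml Python API at the time of this PR, where the input shape is passed explicitly:

```python
import torch.nn as nn
import hls4ml

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # The named Sequential that used to break weight lookup before this fix.
        self.seq = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

    def forward(self, x):
        return self.seq(x)

model = Model().eval()
config = hls4ml.utils.config_from_pytorch_model(model)
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model, (None, 8), hls_config=config
)
```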