From 4304984a0702706e820992d021f521a0b2f9a378 Mon Sep 17 00:00:00 2001 From: abdelabd <51892588+abdelabd@users.noreply.github.com> Date: Sun, 13 Jun 2021 13:47:02 -0400 Subject: [PATCH 01/13] pyg_to_hls Add files via upload Update hls_layers.py Update hls_layers.py Update hls_model.py Update hls_model.py Update vivado_writer.py Add files via upload Update vivado_template.py Update vivado_writer.py added vec_to_mat, mat_to_vec Delete nnet_graph.h Add files via upload handling for 1D-vector inputs and outputs Update vivado_writer.py Update pyg_to_hls.py Update hls_layers.py Update vivado_template.py Update pyg_to_hls.py Update hls_layers.py Update nnet_activation.h Update nnet_dense_resource.h wip Add files via upload Add files via upload added testbench handling for concatenation added weights tranpose, deleted hard-coded testbench add nnet::matrix_config cleaning up add HLSModel_GNN._get_top_function(), HLSModel_GNN.predict(..) top-level conversion function (starting ground) update naming conventions:'Rn' --> 'node_attr', 'Re' --> 'edge_attr'. update order of inputs: 1.node_attr, 2.edge_attr, 3.edge_index update naming conventions, edge_index shape (now, edge_index.shape=[N_EDGE, 2]) add max, mean aggregation. update naming conventions generalize all the aggregation methods in a single function cleaning up added precision and testbenching added precision handle added handling for different flow direction cleaning up from 'save intermediates' slight improvement on max aggregation fixed max-aggregation minor updates generalizing pyg_to_hls() tidying up improved handling for user-input precision re-included extra #pragma HLS UNROLL, don't know if this is correct yet implemented LUT-division missed a semicolon added handling for initial aggregation layer added Aggregate layer, self._check_inputs() added Aggregate layer use existing test bench code update naming conventions tidying up, added #pragma HLS UNROLL to nnet_array::vec_to_mat/mat_to_vec ditched single-edge aggregation functions re-added single-edge-aggregation functions re-added #pragma HLS ARRAY_PARTITION speciy model inputs; parition merge cleaning up, update naming conventions ditched single-edge-aggregation functions, improved sender-index/receiver-index handling changed 'else if{' to 'else{ if{' split different aggregation methods into separate functions max not fully functional yet, just committing changes before switching branch got max-aggregation LOGIC working, still testing build_prj.tcl fixed up max-aggregation again --- hls4ml/converters/__init__.py | 29 + hls4ml/converters/pyg_to_hls.py | 191 ++++++ hls4ml/model/hls_layers.py | 592 +++++++++++++++++- hls4ml/model/hls_model.py | 7 +- .../templates/vivado/nnet_utils/nnet_array.h | 30 + .../templates/vivado/nnet_utils/nnet_dense.h | 1 + .../vivado/nnet_utils/nnet_dense_resource.h | 12 +- .../templates/vivado/nnet_utils/nnet_graph.h | 439 +++++++++++++ hls4ml/templates/vivado_template.py | 66 ++ hls4ml/utils/config.py | 14 +- hls4ml/writer/vivado_writer.py | 13 +- 11 files changed, 1385 insertions(+), 9 deletions(-) create mode 100644 hls4ml/converters/pyg_to_hls.py create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_graph.h diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 0b132a209a..ac066baef7 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -11,6 +11,7 @@ #----------Make converters available if the libraries can be imported----------# try: from hls4ml.converters.pytorch_to_hls import pytorch_to_hls, 
get_supported_pytorch_layers, register_pytorch_layer_handler + from hls4ml.converters.pyg_to_hls import pyg_to_hls __pytorch_enabled__ = True except ImportError: warnings.warn("WARNING: Pytorch converter is not enabled!") @@ -267,6 +268,34 @@ def convert_from_pytorch_model(model, input_shape, output_dir='my-hls-test', pro return pytorch_to_hls(config) +def convert_from_pyg_model(model, n_node, node_dim, n_edge, edge_dim, + forward_dictionary=None, activate_final=None, + output_dir='my-hls-test', project_name='myproject', + fpga_part='xcku115-flvb2104-2-i', clock_period=5, io_type='io_parallel', hls_config={}): + + config = create_vivado_config( + output_dir=output_dir, + project_name=project_name, + fpga_part=fpga_part, + clock_period=clock_period, + io_type=io_type + ) + + config['PytorchModel'] = model + config['InputShape'] = { + 'NodeAttr': [n_node, node_dim], + 'EdgeAttr': [n_edge, edge_dim], + 'EdgeIndex': [n_edge, 2] + } + config['ForwardDictionary'] = forward_dictionary + config['ActivateFinal'] = activate_final + + model_config = hls_config.get('Model', None) + config['HLSConfig']['Model'] = _check_model_config(model_config) + + _check_hls_config(config, hls_config) + + return pyg_to_hls(config) def convert_from_onnx_model(model, output_dir='my-hls-test', project_name='myproject', fpga_part='xcku115-flvb2104-2-i', clock_period=5, io_type='io_parallel', hls_config={}): diff --git a/hls4ml/converters/pyg_to_hls.py b/hls4ml/converters/pyg_to_hls.py new file mode 100644 index 0000000000..34019cd1c3 --- /dev/null +++ b/hls4ml/converters/pyg_to_hls.py @@ -0,0 +1,191 @@ +from __future__ import print_function +import torch + +from hls4ml.converters.pytorch_to_hls import PyTorchModelReader +from hls4ml.model.hls_model import HLSModel +from hls4ml.templates import get_backend + + +class PygModelReader(PyTorchModelReader): + def __init__(self, config): + super().__init__(config) + self.n_node = config['InputShape']['NodeAttr'][0] + self.n_edge = config['InputShape']['EdgeAttr'][0] + self.node_dim = config['InputShape']['NodeAttr'][1] + self.edge_dim = config['InputShape']['EdgeAttr'][1] + + def get_weights_data(self, layer_name, var_name, module_name=None): + data = None + + # Parameter mapping from pytorch to keras + torch_paramap = { + # Conv + 'kernel': 'weight', + # Batchnorm + 'gamma': 'weight', + 'beta': 'bias', + 'moving_mean': 'running_mean', + 'moving_variance': 'running_var'} + + if var_name not in list(torch_paramap.keys()) + ['weight', 'bias']: + raise Exception('Pytorch parameter not yet supported!') + + if module_name is not None: + if var_name in list(torch_paramap.keys()): + var_name = torch_paramap[var_name] + + try: + data = self.state_dict[module_name + '.' + layer_name + '.' + var_name].numpy().transpose() + except KeyError: + data = self.state_dict[module_name + '.layers.' + layer_name + '.' + var_name].numpy().transpose() + + else: + if var_name in list(torch_paramap.keys()): + var_name = torch_paramap[var_name] + + data = self.state_dict[layer_name + '.' + var_name].numpy().transpose() # Look at transpose when systhesis produce lousy results. Might need to remove it. 
+ + return data + +def pyg_to_hls(config): + + forward_dict = config['ForwardDictionary'] + activate_final = config['ActivateFinal'] + + # get precisions + backend = get_backend(config.get('Backend', 'Vivado')) + fp_type = backend.convert_precision_string(config['HLSConfig']['Model']['Precision']) + int_type = backend.convert_precision_string(config['HLSConfig']['Model']['IndexPrecision']) + + # make reader + reader = PygModelReader(config) + n_node = reader.n_node + n_edge = reader.n_edge + node_dim = reader.node_dim + edge_dim = reader.edge_dim + + + # initiate layer list with inputs: node_attr, edge_attr, edge_index + layer_list = [] + input_shapes = reader.input_shape + NodeAttr_layer = { + 'name': 'node_attr', + 'class_name': 'InputLayer', + 'input_shape': input_shapes['NodeAttr'], + 'inputs': 'input', + 'dim_names': ['N_NODE', 'NODE_DIM'], + 'precision': fp_type + } + layer_list.append(NodeAttr_layer) + EdgeAttr_layer = { + 'name': 'edge_attr', + 'class_name': 'InputLayer', + 'input_shape': input_shapes['EdgeAttr'], + 'inputs': 'input', + 'dim_names': ['N_EDGE', 'EDGE_DIM'], + 'precision': fp_type + } + layer_list.append(EdgeAttr_layer) + EdgeIndex_layer = { + 'name': 'edge_index', + 'class_name': 'InputLayer', + 'input_shape': input_shapes['EdgeIndex'], + 'inputs': 'input', + 'dim_names': ['N_EDGE', 'TWO'], + 'precision': int_type + } + layer_list.append(EdgeIndex_layer) + last_node_update = "node_attr" + last_edge_update = "edge_attr" + + # If the first block is a NodeBlock, we need a layer to construct the initial edge_aggregates + if forward_dict[list(forward_dict.keys())[0]] == "NodeBlock": + aggr_layer = {"name": "aggr1", + "class_name": "Aggregate", + "n_node": n_node, + "n_edge": n_edge, + "node_dim": node_dim, + "edge_dim": edge_dim, + "precision": fp_type, + "out_dim": edge_dim, + "inputs": ["edge_attr", "edge_index"], + "outputs": ["edge_attr_aggr"]} + layer_list.append(aggr_layer) + last_edge_aggr_update = "edge_attr_aggr" + else: last_edge_aggr_update = None + + # complete the layer list + for i, (key, val) in enumerate(forward_dict.items()): + layer_dict = { + "name": key, + "class_name": val, + "n_node": n_node, + "n_edge": n_edge, + "node_dim": node_dim, + "edge_dim": edge_dim, + "precision": fp_type + } + + # get n_layers, out_dim + model = config['PytorchModel'] + torch_block = getattr(model, key) + try: + torch_layers = torch_block.layers._modules + except AttributeError: + torch_layers = torch_block._modules + + lcount = 0 + for lname, l in torch_layers.items(): + if isinstance(l, torch.nn.modules.linear.Linear): + lcount += 1 + last_layer = l + layer_dict["n_layers"] = lcount + layer_dict["out_dim"] = last_layer.out_features + + # get inputs, outputs + if val == "NodeBlock": + index = len(layer_list) + 1 + layer_dict["inputs"] = [last_node_update, last_edge_aggr_update] + layer_dict["outputs"] = [f"layer{index}_out"] + last_node_update = f"layer{index}_out" + layer_list.append(layer_dict) + elif val == "EdgeBlock": + index = len(layer_list) + 1 + layer_dict["inputs"] = [last_node_update, last_edge_update, "edge_index"] + layer_dict["outputs"] = [f"layer{index}_out"] + last_edge_update = f"layer{index}_out" + layer_list.append(layer_dict) + + # if val==EdgeBlock and this is not the final graph-block, follow it with an aggregation layer + if (val == "EdgeBlock") and (i < len(forward_dict) - 1): + index = len(layer_list) + 1 + layer_dict = {"name": f"aggr{index}", + "class_name": "Aggregate", + "n_node": n_node, + "n_edge": n_edge, + "node_dim": node_dim, + 
"edge_dim": edge_dim, + "precision": fp_type, + "out_dim": edge_dim, + "inputs": [last_edge_update, "edge_index"], + "outputs": [f"layer{index}_out"]} + last_edge_aggr_update = f"layer{index}_out" + layer_list.append(layer_dict) + + if activate_final is not None: + act_dict = { + 'name': 'final_act', + 'class_name': 'Activation', + 'inputs': [f"layer{len(layer_list)}_out"], + 'activation': activate_final, + 'precision': fp_type + } + layer_list.append(act_dict) + out = ["final_act"] + else: + out = [layer_list[-1]['outputs'][0]] + + hls_model = HLSModel(config, reader, layer_list, inputs=['node_attr', 'edge_attr', 'edge_index']) + hls_model.outputs = out + return hls_model + diff --git a/hls4ml/model/hls_layers.py b/hls4ml/model/hls_layers.py index 3ad2edff6b..8190826712 100644 --- a/hls4ml/model/hls_layers.py +++ b/hls4ml/model/hls_layers.py @@ -560,7 +560,10 @@ def initialize(self): shape = self.attributes['input_shape'] if shape[0] is None: shape = shape[1:] - dims = ['N_INPUT_{}_{}'.format(i, self.index) for i in range(1, len(shape) + 1)] + try: + dims = self.attributes['dim_names'] + except KeyError: + dims = ['N_INPUT_{}_{}'.format(i, self.index) for i in range(1, len(shape) + 1)] if self.index == 1: default_type_name = 'input_t' else: @@ -642,6 +645,10 @@ def config_cpp(self): params['nonzeros'] = self.get_weights('weight').nonzeros params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('weight').type.precision) params['strategy'] = self.get_attr('strategy') + if self.get_attr('remove_pipeline_pragma') is not None: + params['remove_pipeline_pragma'] = self.get_attr('remove_pipeline_pragma') + else: + params['remove_pipeline_pragma'] = "false" return self._config_template.format(**params) @@ -1831,6 +1838,586 @@ def _get_transforms_config(self, params): params['sublayer_configs'] = '\n'.join(sublayer_configs) +class GraphBlock(Layer): #parent class for EdgeBlock, NodeBlock + def add_weights(self, quantizer=None, compression=False): + linear_count = 0 + + for name, module in self.submodules.items(): + if module.__class__.__name__ == 'Linear': + data = self.model.get_weights_data(name, 'kernel', self.name).transpose() + var_name = f"{self.name}_w{linear_count}" + self.add_weights_variable(name=var_name, var_name=var_name, data=data, quantizer=quantizer, + compression=compression) + linear_count += 1 + + # DUMMIES + if linear_count <= 3: + for i in range(linear_count, 4): + self.add_weights_variable(name=f"{self.name}_w{i}", var_name=f"{self.name}_w{i}", data=data, + quantizer=quantizer, compression=compression) + + def add_bias(self, quantizer=None): + precision = None + type_name = None + linear_count = 0 + + for name, module in self.submodules.items(): + if module.__class__.__name__ == 'Linear': + data = self.model.get_weights_data(name, 'bias', self.name) + var_name = f"{self.name}_b{linear_count}" + self.add_weights_variable(name=var_name, var_name=var_name, type_name=type_name, precision=precision, + data=data, quantizer=quantizer) + linear_count += 1 + + # DUMMIES + if linear_count <= 3: + for i in range(linear_count, 4): + self.add_weights_variable(name=f"{self.name}_b{i}", var_name=f"{self.name}_b{i}", type_name=type_name, + precision=precision, data=data, quantizer=quantizer) + + def get_dense_params(self, dense_layer, linear_count): # hard-coded for now + params = {} + params['type'] = 'dense' + params['index'] = linear_count + params['n_in'] = dense_layer.in_features + params['n_out'] = 
dense_layer.out_features + params['iotype'] = 'io_parallel' + params['reuse'] = 1 + params['nzeros'] = 0 + params['remove_pipeline_pragma'] = 'true' + + params['accum_t'] = f'layer{self.index}_t' + params['bias_t'] = f'layer{self.index}_t' + params['weight_t'] = f'layer{self.index}_t' + + return params + + def get_relu_params(self, relu_count, last_n_out): + params = {} + params['type'] = 'relu' + params['index'] = relu_count + params['n_in'] = last_n_out + params['table_size'] = 1024 + params['iotype'] = 'io_parallel' + return params + + def config_layer(self, layer_type, layer_params): + all_lines = self.model.config.backend.get_config_template(layer_type).split('\n') + all_lines[0] = re.sub('struct config{index}', 'struct {type}_config{index}', all_lines[0]) + param_lines = [] + out = [] + + for param in layer_params: + p_lines = [i for i in all_lines if "{%s}" % param in i] + if len(p_lines) == 1 and p_lines[0] not in param_lines: + param_lines.append(p_lines[0]) + elif len(p_lines) < 1: + print(f"param {param} not found in {layer_type} config template") + else: + pass + + for line in all_lines: + if line in param_lines: + out.append(line) + else: + param_search = line.find('{') + if param_search == -1: + if 'template' not in line: + out.append(line) + + out = '\n'.join(out) + out = out.format(**layer_params) + return out + + def _config_sublayers(self): + linear_count = 0 + relu_count = 0 + configs = OrderedDict() + + for name, module in self.submodules.items(): + if module.__class__.__name__==self.model.reader.torch_model.__class__.__name__: + continue + + if module.__class__.__name__ == "Linear": + linear_count += 1 + linear_params = self.get_dense_params(module, linear_count) + linear_config = self.config_layer('Dense', linear_params) + configs[f"dense_config{linear_count}"] = linear_config + last_n_out = linear_params['n_out'] + + elif module.__class__.__name__ == "ReLU": + relu_count += 1 + relu_params = self.get_relu_params(relu_count, last_n_out) + relu_config = self.config_layer('Activation', relu_params) + configs[f"relu_config{relu_count}"] = relu_config + last_n_out = relu_params['n_in'] + + # DUMMIES + if linear_count < 4: + for i in range(linear_count + 1, 5): + linear_config_i = linear_config.split('\n') + linear_config_i[0] = re.sub(f"dense_config{linear_count}", f"dense_config{i}", linear_config_i[0]) + linear_config_i = "\n".join(linear_config_i) + configs[f"dense_config{i}"] = linear_config_i + + if relu_count < 4: + for i in range(relu_count + 1, 5): + relu_config_i = relu_config.split('\n') + relu_config_i[0] = re.sub(f"relu_config{relu_count}", f"relu_config{i}", relu_config_i[0]) + relu_config_i = '\n'.join(relu_config_i) + configs[f"relu_config{i}"] = relu_config_i + + return configs + +class EdgeBlock(GraphBlock): + def initialize(self): + self.n_node = self.attributes['n_node'] + self.n_edge = self.attributes['n_edge'] + self.node_dim = self.attributes['node_dim'] + self.edge_dim = self.attributes['edge_dim'] + self.out_dim = self.attributes['out_dim'] + self._check_inputs() + + self.n_edge_cppname, self.edge_dim_cppname = self.model.get_layer_output_variable('edge_attr').dim_names + self.n_node_cppname, self.node_dim_cppname = self.model.get_layer_output_variable('node_attr').dim_names + self.out_dim_cppname = f"LAYER{self.index}_OUT_DIM" + + self.torch_module = getattr(self.model.reader.torch_model, self.name) + submodules = OrderedDict() + try: + for name, module in self.torch_module.layers.named_modules(): + submodules[name] = module + except 
AttributeError: + for name, module in self.torch_module.named_modules(): + submodules[name] = module + self.submodules = submodules + + # edge predictions + out_shape = [self.n_edge, self.out_dim] + out_dims = [self.n_edge_cppname, self.out_dim_cppname] + out_name = f"layer{self.index}_out" + self.add_output_variable(shape=out_shape, dim_names=out_dims, out_name=out_name, var_name=out_name, precision=self.attributes.get('precision', None), pragma='partition') + + self.add_weights(quantizer=self.get_attr('weight_quantizer'), + compression=self.model.config.get_compression(self)) + self.add_bias(quantizer=self.get_attr('weight_quantizer')) + + # Reshape the input/output variables + #for input_name in self.inputs: + # input_array = self.get_input_variable(input_name) + # partition_factor = input_array.shape[0] + # if input_name in self.model.inputs: + # input_array.pragma = ('reshape', 'block', partition_factor) + # else: + # input_array.pragma = ('partition', 'block', partition_factor) + + #for output_name in self.outputs: + # output_array = self.get_output_variable(output_name) + # partition_factor = output_array.shape[0] + # output_array.pragma = ('partition', 'block', partition_factor) + + def function_cpp(self): + params = {} + params['config'] = 'config{}'.format(self.index) + params['input_t'] = self.model.get_layer_output_variable('edge_attr').type.name + params['index_t'] = self.model.get_layer_output_variable('edge_index').type.name + params['output_t'] = self.get_output_variable().type.name + params['node_attr'] = self.attributes['inputs'][0] + params['edge_attr'] = self.attributes['inputs'][1] + params['edge_index'] = self.attributes['inputs'][2] + params['out'] = f"layer{self.index}_out" + + params['w0'] = self.get_weights(f"{self.name}_w0").name + params['b0'] = self.get_weights(f"{self.name}_b0").name + params['w1'] = self.get_weights(f"{self.name}_w1").name + params['b1'] = self.get_weights(f"{self.name}_b1").name + params['w2'] = self.get_weights(f"{self.name}_w2").name + params['b2'] = self.get_weights(f"{self.name}_b2").name + params['w3'] = self.get_weights(f"{self.name}_w3").name + params['b3'] = self.get_weights(f"{self.name}_b3").name + + out = self._function_template.format(**params) + return [out] + + def config_cpp(self): + top_params = self.get_EdgeBlock_params() + top_config = self._config_template.format(**top_params) + top_config = top_config.split('\n')[:-1] + top_config = '\n'.join(top_config) + + sublayer_configs = self._config_sublayers() + sublayer_configs.update(self._config_misc()) + for layer, config in sublayer_configs.items(): + config = [' ' + i for i in config.split('\n')] + config = '\n'.join(config) + + top_config += '\n\n' + top_config += config + + top_config += '\n};' + return top_config + + def get_EdgeBlock_params(self): # hard-coded for now + params = {} + params['index'] = self.index + params['bias_t'] = f'layer{self.index}_t' + params['weight_t'] = f'layer{self.index}_t' + params['table_t'] = f'layer{self.index}_t' + params['n_node'] = self.n_node_cppname + params['n_edge'] = self.n_edge_cppname + params['node_dim'] = self.node_dim_cppname + params['edge_dim'] = self.edge_dim_cppname + params['out_dim'] = self.out_dim + params['n_layers'] = self.attributes["n_layers"] + params['io_type'] = 'io_parallel' + params['reuse'] = self.reuse_factor + params['n_zeros'] = 0 + + flow_map = { + "source_to_target": 0, + "target_to_source": 1 + } + params["flow"] = flow_map[self.attributes.get("flow", self.model.reader.torch_model.flow)] + + return params 
+ + def _config_misc(self): + configs = OrderedDict() + + # matrix configs + matrix_config_template = """struct {matrix_name}_config: nnet::matrix_config{{ + static const unsigned n_rows = {n_rows}; + static const unsigned n_cols = {n_cols}; + }};""" + + configs['node_attr_config'] = matrix_config_template.format(matrix_name="node_attr", + n_rows=self.n_node_cppname, + n_cols=self.node_dim_cppname) + + configs['edge_attr_config'] = matrix_config_template.format(matrix_name="edge_attr", + n_rows=self.n_edge_cppname, + n_cols=self.edge_dim_cppname) + + configs['edge_index_config'] = matrix_config_template.format(matrix_name="edge_index", + n_rows=self.n_edge_cppname, + n_cols="TWO") + + configs['edge_update_config'] = matrix_config_template.format(matrix_name="edge_update", + n_rows=self.n_edge_cppname, + n_cols=f"LAYER{self.index}_OUT_DIM") + + # concatenation configs + concat_config_template = self.model.config.backend.get_config_template('Concatenate') + concat_config_template = re.sub('config{index}', 'merge_config{index}', concat_config_template) + + merge_config1_params = { + 'index': 1, + 'n_elem1_0': self.node_dim_cppname, + 'n_elem1_1': 1, + 'n_elem1_2': 0, + 'n_elem2_0': self.node_dim_cppname, + 'n_elem2_1': 1, + 'n_elem2_2': 0, + 'axis': 0 + } + merge_config1 = concat_config_template.format(**merge_config1_params) + configs['merge_config1'] = merge_config1 + + merge_config2_params = { + 'index': 2, + 'n_elem1_0': f"2*{self.node_dim_cppname}", + 'n_elem1_1': 1, + 'n_elem1_2': 0, + 'n_elem2_0': self.edge_dim_cppname, + 'n_elem2_1': 1, + 'n_elem2_2': 0, + 'axis': 0 + } + merge_config2 = concat_config_template.format(**merge_config2_params) + configs['merge_config2'] = merge_config2 + + return configs + + def _check_inputs(self): + #expected inputs: node_attr, edge_attr, edge_index + assert (len(self.inputs) == 3) + + node_attr = self.model.get_layer_output_variable(self.inputs[0]) + assert(node_attr.shape==[self.n_node, self.node_dim]) + + edge_attr = self.model.get_layer_output_variable(self.inputs[1]) + assert(edge_attr.shape==[self.n_edge, self.edge_dim]) + + edge_index = self.model.get_layer_output_variable(self.inputs[2]) + assert(edge_index.shape==[self.n_edge, 2]) + + #expected outputs: edge_update, edge_update_aggr + assert (len(self.outputs) == 1) + +class NodeBlock(GraphBlock): + def initialize(self): + self.n_node = self.attributes['n_node'] + self.n_edge = self.attributes['n_edge'] + self.node_dim = self.attributes['node_dim'] + self.edge_dim = self.attributes['edge_dim'] + self.out_dim = self.attributes['out_dim'] + self._check_inputs() + + self.n_edge_cppname, self.edge_dim_cppname = self.model.get_layer_output_variable('edge_attr').dim_names + self.n_node_cppname, self.node_dim_cppname = self.model.get_layer_output_variable('node_attr').dim_names + self.out_dim_cppname = f"LAYER{self.index}_OUT_DIM" + + self.torch_module = getattr(self.model.reader.torch_model, self.name) + submodules = OrderedDict() + try: + for name, module in self.torch_module.layers.named_modules(): + submodules[name] = module + except AttributeError: + for name, module in self.torch_module.named_modules(): + submodules[name] = module + self.submodules = submodules + + # node predictions + out_shape = [self.n_node, self.out_dim] + out_dims = [self.n_node_cppname, self.out_dim_cppname] + out_name = f"layer{self.index}_out" + self.add_output_variable(shape=out_shape, dim_names=out_dims, out_name=out_name, var_name=out_name, precision=self.attributes.get('precision', None), pragma='partition') + + 
self.add_weights(quantizer=self.get_attr('weight_quantizer'), + compression=self.model.config.get_compression(self)) + self.add_bias(quantizer=self.get_attr('weight_quantizer')) + + # Reshape the input/output variables + #for input_name in self.inputs: + # input_array = self.get_input_variable(input_name) + # partition_factor = input_array.shape[0] + # if input_name in self.model.inputs: + # input_array.pragma = ('reshape', 'block', partition_factor) + # else: + # input_array.pragma = ('partition', 'block', partition_factor) + + #for output_name in self.outputs: + # output_array = self.get_output_variable(output_name) + # partition_factor = output_array.shape[0] + # output_array.pragma = ('partition', 'block', partition_factor) + + def function_cpp(self): + params = {} + params['config'] = 'config{}'.format(self.index) + params['input_t'] = self.model.get_layer_output_variable('node_attr').type.name + params['output_t'] = self.get_output_variable().type.name + params['node_attr'] = self.attributes["inputs"][0] + params['edge_attr_aggr'] = self.attributes["inputs"][1] + params['out'] = f"layer{self.index}_out" + + params['w0'] = self.get_weights(f"{self.name}_w0").name + params['b0'] = self.get_weights(f"{self.name}_b0").name + params['w1'] = self.get_weights(f"{self.name}_w1").name + params['b1'] = self.get_weights(f"{self.name}_b1").name + params['w2'] = self.get_weights(f"{self.name}_w2").name + params['b2'] = self.get_weights(f"{self.name}_b2").name + params['w3'] = self.get_weights(f"{self.name}_w3").name + params['b3'] = self.get_weights(f"{self.name}_b3").name + + out = self._function_template.format(**params) + return [out] + + def config_cpp(self): + top_params = self.get_NodeBlock_params() + top_config = self._config_template.format(**top_params) + top_config = top_config.split('\n')[:-1] + top_config = '\n'.join(top_config) + + sublayer_configs = self._config_sublayers() + sublayer_configs.update(self._config_misc()) + for layer, config in sublayer_configs.items(): + config = [' ' + i for i in config.split('\n')] + config = '\n'.join(config) + + top_config += '\n\n' + top_config += config + + top_config += '\n};' + return top_config + + def get_NodeBlock_params(self): # hard-coded for now + params = {} + params['index'] = self.index + params['bias_t'] = f'layer{self.index}_t' + params['weight_t'] = f'layer{self.index}_t' + params['table_t'] = f'layer{self.index}_t' + params['n_node'] = self.n_node_cppname + params['n_edge'] = self.n_edge_cppname + params['node_dim'] = self.node_dim_cppname + params['edge_dim'] = self.edge_dim_cppname + params['out_dim'] = self.out_dim + params['n_layers'] = self.attributes["n_layers"] + params['io_type'] = 'io_parallel' + params['reuse'] = self.reuse_factor + params['n_zeros'] = 0 + return params + + def _config_misc(self): + configs = OrderedDict() + + # matrix configs + matrix_config_template = """struct {matrix_name}_config: nnet::matrix_config{{ + static const unsigned n_rows = {n_rows}; + static const unsigned n_cols = {n_cols}; + }};""" + + configs['node_attr_config'] = matrix_config_template.format(matrix_name="node_attr", + n_rows=self.n_node_cppname, + n_cols=self.node_dim_cppname) + + configs['edge_attr_aggr_config'] = matrix_config_template.format(matrix_name="edge_attr_aggr", + n_rows=self.n_node_cppname, + n_cols=f"LAYER{self.index - 1}_OUT_DIM") + + configs['node_update_config'] = matrix_config_template.format(matrix_name="node_update", + n_rows=self.n_node_cppname, + n_cols=f"LAYER{self.index}_OUT_DIM") + + # concatenation configs 
+ concat_config_template = self.model.config.backend.get_config_template('Concatenate') + concat_config_template = re.sub('config{index}', 'merge_config{index}', concat_config_template) + merge_config1_params = { + 'index': 1, + 'n_elem1_0': self.node_dim_cppname, + 'n_elem1_1': 1, + 'n_elem1_2': 0, + 'n_elem2_0': self.edge_dim_cppname, + 'n_elem2_1': 1, + 'n_elem2_2': 0, + 'axis': 0 + } + merge_config1 = concat_config_template.format(**merge_config1_params) + configs['merge_config1'] = merge_config1 + + return configs + + def _check_inputs(self): + #expected inputs: node_attr, edge_attr_aggr + assert(len(self.inputs)==2) + + node_attr = self.model.get_layer_output_variable(self.inputs[0]) + assert(node_attr.shape==[self.n_node, self.node_dim]) + + edge_attr_aggr = self.model.get_layer_output_variable(self.inputs[1]) + assert(edge_attr_aggr.shape==[self.n_node, self.edge_dim]) + + #expected outputs: node_update + assert(len(self.outputs)==1) + +class Aggregate(Layer): + def initialize(self): + self.n_node = self.attributes['n_node'] + self.n_edge = self.attributes['n_edge'] + self.node_dim = self.attributes['node_dim'] + self.edge_dim = self.attributes['edge_dim'] + self.out_dim = self.attributes['out_dim'] + self._check_inputs() + + self.n_edge_cppname, self.edge_dim_cppname = self.model.get_layer_output_variable('edge_attr').dim_names + self.n_node_cppname, self.node_dim_cppname = self.model.get_layer_output_variable('node_attr').dim_names + + aggr_name = f"layer{self.index}_out" + aggr_shape = [self.n_node, self.out_dim] + aggr_dims = ['N_NODE', f'LAYER{self.index}_OUT_DIM'] + self.add_output_variable(shape=aggr_shape, dim_names=aggr_dims, out_name=aggr_name, var_name=aggr_name, + precision=self.attributes.get('precision', None)) + + def function_cpp(self): + params = {} + params['config'] = 'aggregation_config{}'.format(self.index) + params['input_t'] = self.model.get_layer_output_variable('edge_attr').type.name + params['index_t'] = self.model.get_layer_output_variable('edge_index').type.name + params['output_t'] = self.get_output_variable().type.name + + params['edge_attr'] = self.attributes["inputs"][0] + params['edge_index'] = self.attributes["inputs"][1] + params['out'] = f"layer{self.index}_out" + return [self._function_template.format(**params)] + + def config_cpp(self): + params = self.get_Aggregate_params() + + top_config = self._config_template.format(**params) + top_config = top_config.split('\n')[:-1] + top_config = '\n'.join(top_config) + + sub_configs = self._config_misc() + for layer, config in sub_configs.items(): + config = [' ' + i for i in config.split('\n')] + config = '\n'.join(config) + + top_config += '\n\n' + top_config += config + + top_config += '\n};' + return top_config + + def get_Aggregate_params(self): + params = {} + params["index"] = self.index + params['n_node'] = self.attributes['n_node'] + params['node_dim'] = self.attributes['node_dim'] + params['n_edge'] = self.attributes['n_edge'] + params['edge_dim'] = self.attributes['edge_dim'] + params['table_t'] = f'layer{self.index}_t' + params['reuse'] = self.reuse_factor + + flow_map = {"source_to_target": 0, "target_to_source": 1} + params['flow'] = flow_map[self.model.reader.torch_model.flow] + + aggr_map = {"add": 0, "mean": 1, "max": 2} + params['aggr'] = aggr_map[self.model.reader.torch_model.aggr] + + params['io_type'] = 'io_parallel' + + return params + + def _config_misc(self): + # matrix configs + configs = {} + matrix_config_template = """struct {matrix_name}_config: nnet::matrix_config{{ + 
static const unsigned n_rows = {n_rows}; + static const unsigned n_cols = {n_cols}; + }};""" + + configs['edge_attr_config'] = matrix_config_template.format(matrix_name="edge_attr", + n_rows=self.n_edge_cppname, + n_cols=self.edge_dim_cppname) + + configs['edge_index_config'] = matrix_config_template.format(matrix_name="edge_index", + n_rows=self.n_edge_cppname, + n_cols="TWO") + + configs['edge_attr_aggr_config'] = matrix_config_template.format(matrix_name="edge_attr_aggr", + n_rows=self.n_node_cppname, + n_cols=f"LAYER{self.index}_OUT_DIM") + + aggr_params = self.get_Aggregate_params() + nested_duplicate = self._config_template.format(**aggr_params).split('\n') + nested_duplicate[0] = "struct nested_duplicate: nnet::aggregate_config{" + nested_duplicate = '\n'.join(nested_duplicate) + configs['nested_duplicate'] = nested_duplicate + + return configs + + def _check_inputs(self): + #expected inputs: edge_attr, edge_index + assert(len(self.inputs)==2) + + edge_attr = self.model.get_layer_output_variable(self.inputs[0]) + assert(edge_attr.shape==[self.n_edge, self.edge_dim]) + + edge_index = self.model.get_layer_output_variable(self.inputs[1]) + assert(edge_index.shape==[self.n_edge, 2]) + + #expected outputs: edge_attr_aggr + assert(len(self.outputs)==1) + layer_map = { 'Input' : Input, 'InputLayer' : Input, @@ -1875,6 +2462,9 @@ def _get_transforms_config(self, params): 'Transpose' : Transpose, 'GarNet' : GarNet, 'GarNetStack' : GarNetStack, + 'EdgeBlock' : EdgeBlock, + 'NodeBlock' : NodeBlock, + 'Aggregate' : Aggregate, # TensorFlow-specific layers: 'BiasAdd' : BiasAdd, } diff --git a/hls4ml/model/hls_model.py b/hls4ml/model/hls_model.py index 40b2e41727..b9e0e89ccd 100644 --- a/hls4ml/model/hls_model.py +++ b/hls4ml/model/hls_model.py @@ -467,8 +467,11 @@ def replace_node(self, old_node, new_node): self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items()) - def get_weights_data(self, layer_name, var_name): - return self.reader.get_weights_data(layer_name, var_name) + def get_weights_data(self, layer_name, var_name, module_name=None): + if module_name is not None: + return self.reader.get_weights_data(layer_name, var_name, module_name) + else: + return self.reader.get_weights_data(layer_name, var_name) def next_layer(self): self.index += 1 diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_array.h b/hls4ml/templates/vivado/nnet_utils/nnet_array.h index e39c534ac1..64546e712e 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_array.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_array.h @@ -51,6 +51,36 @@ void transpose_3d( } } +struct matrix_config{ + static const unsigned n_rows = 10; + static const unsigned n_cols = 10; +}; +template +void vec_to_mat( //faster (I think) + data_T vec[CONFIG_T::n_rows*CONFIG_T::n_cols], + res_T mat[CONFIG_T::n_rows][CONFIG_T::n_cols] +) { + for (int r=0; r < CONFIG_T::n_rows; r++){ + for (int c=0; c < CONFIG_T::n_cols; c++){ + #pragma HLS UNROLL + mat[r][c] = vec[r*CONFIG_T::n_cols+c]; + } + } +} + +template +void mat_to_vec( //faster (I think) + data_T mat[CONFIG_T::n_rows][CONFIG_T::n_cols], + res_T vec[CONFIG_T::n_rows*CONFIG_T::n_cols] +) { + for (int r=0; r < CONFIG_T::n_rows; r++){ + for (int c=0; c diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h index 756a627434..2298a1b172 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h @@ 
-62,7 +62,9 @@ void dense_resource_rf_leq_nin( ReuseLoop: for (int ir = 0; ir < rufactor; ir++) { - #pragma HLS PIPELINE II=1 rewind + if (!CONFIG_T::remove_pipeline_pragma) { + #pragma HLS PIPELINE II=1 rewind + } int w_index = ir; int in_index = ir; @@ -149,7 +151,9 @@ void dense_resource_rf_gt_nin_rem0( ReuseLoop: for (int ir = 0; ir < rufactor; ir++) { - #pragma HLS PIPELINE II=1 rewind + if (!CONFIG_T::remove_pipeline_pragma) { + #pragma HLS PIPELINE II=1 rewind + } w_index = ir; out_index = outidx[ir]/*outstep*/; @@ -213,7 +217,9 @@ void dense_resource_rf_gt_nin( ReuseLoop: for (int ir = 0; ir < rufactor; ir++) { - #pragma HLS PIPELINE II=1 rewind + if (!CONFIG_T::remove_pipeline_pragma) { + #pragma HLS PIPELINE II=1 rewind + } typename CONFIG_T::accum_t tmpmult[block_factor]; #pragma HLS ARRAY_PARTITION variable=tmpmult complete diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_graph.h b/hls4ml/templates/vivado/nnet_utils/nnet_graph.h new file mode 100644 index 0000000000..c062443b21 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_graph.h @@ -0,0 +1,439 @@ +#ifndef NNET_GRAPH_H_ +#define NNET_GRAPH_H_ + +#include "nnet_common.h" +#include "nnet_merge.h" +#include "nnet_dense.h" +#include "nnet_dense_resource.h" +#include "nnet_activation.h" +#include "nnet_array.h" +#include + +namespace nnet { + enum flow {source_to_target=0, target_to_source=1}; + enum aggr {aggr_sum=0, aggr_mean=1, aggr_max=2}; + + struct graph_config + { + // Internal data type definitions + typedef float bias_t; + typedef float weight_t; + typedef float table_t; + + // Layer Sizes + static const unsigned n_node = 10; + static const unsigned n_edge = 20; + static const unsigned n_features = 3; + static const unsigned e_features = 4; + static const unsigned n_out = 4; + static const unsigned n_layers = 3; + + // message-passing parameters + static const unsigned aggr = aggr_sum; + static const unsigned flow = source_to_target; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned reuse_factor = 1; + static const unsigned n_zeros = 0; + + static const bool no_aggr = false; //if no_aggr==true, then skip aggregation steps + }; + + struct aggregate_config + { + typedef float table_t; + static const unsigned n_node = 10; + static const unsigned n_edge = 20; + static const unsigned edge_dim = 4; + static const unsigned aggr = aggr_sum; + static const unsigned flow = source_to_target; + }; + + // division-LUT for mean-aggregation + inline float division(float input){ + return 1.0/input; + } + template + void init_div_table(typename CONFIG_T::table_t table_out[N_TABLE]){ + int j = 0; + typename CONFIG_T::table_t k = 1; + table_out[j] = k; + for(int i=1; i + void edge_divide(data_T edge_sum_i, index_T n_edges_i, res_T &edge_mean_i){ + // initialize LUT + #ifdef __HLS_SYN__ + bool initialized=false; + typename CONFIG_T::table_t div_table[CONFIG_T::n_edge]; + #else + static bool initialized=false; + static typename CONFIG_T::table_t div_table[CONFIG_T::n_edge]; + #endif + + if(!initialized){ + nnet::init_div_table(div_table); + initialized=true; + } + + if(CONFIG_T::io_type==io_parallel){ + #pragma HLS PIPELINE + } + + data_T reciprocal; + reciprocal = div_table[n_edges_i]; + edge_mean_i = edge_sum_i*reciprocal; + } + + template + void dense_mult_1lyr( + data_T data[CONFIG_T::dense_config1::n_in], + res_T res[CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config1::weight_t weights0[CONFIG_T::dense_config1::n_in*CONFIG_T::dense_config1::n_out], + 
typename CONFIG_T::dense_config1::bias_t biases0[CONFIG_T::dense_config1::n_out]) + { + nnet::dense_resource(data, res, weights0, biases0); + } + + template + void dense_mult_2lyr( + data_T data[CONFIG_T::dense_config1::n_in], + res_T res[CONFIG_T::dense_config2::n_out], + typename CONFIG_T::dense_config1::weight_t weights0[CONFIG_T::dense_config1::n_in*CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config1::bias_t biases0[CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config2::weight_t weights1[CONFIG_T::dense_config2::n_in*CONFIG_T::dense_config2::n_out], + typename CONFIG_T::dense_config2::bias_t biases1[CONFIG_T::dense_config2::n_out]) + { + data_T data0_logits[CONFIG_T::dense_config1::n_out]; + #pragma HLS ARRAY_PARTITION variable=data0_logits complete dim=0 + nnet::dense_resource(data, data0_logits, weights0, biases0); + data_T data0[CONFIG_T::dense_config1::n_out]; + #pragma HLS ARRAY_PARTITION variable=data0 complete dim=0 + nnet::relu(data0_logits, data0); + + nnet::dense_resource(data0, res, weights1, biases1); + } + + template + void dense_mult_3lyr( + data_T data[CONFIG_T::dense_config1::n_in], + res_T res[CONFIG_T::dense_config3::n_out], + typename CONFIG_T::dense_config1::weight_t weights0[CONFIG_T::dense_config1::n_in*CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config1::bias_t biases0[CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config2::weight_t weights1[CONFIG_T::dense_config2::n_in*CONFIG_T::dense_config2::n_out], + typename CONFIG_T::dense_config2::bias_t biases1[CONFIG_T::dense_config2::n_out], + typename CONFIG_T::dense_config3::weight_t weights2[CONFIG_T::dense_config3::n_in*CONFIG_T::dense_config3::n_out], + typename CONFIG_T::dense_config3::bias_t biases2[CONFIG_T::dense_config3::n_out]) + { + data_T data0_logits[CONFIG_T::dense_config1::n_out]; + #pragma HLS ARRAY_PARTITION variable=data0_logits complete dim=0 + nnet::dense_resource(data, data0_logits, weights0, biases0); + data_T data0[CONFIG_T::dense_config1::n_out]; + #pragma HLS ARRAY_PARTITION variable=data0 complete dim=0 + nnet::relu(data0_logits, data0); + + data_T data1_logits[CONFIG_T::dense_config2::n_out]; + #pragma HLS ARRAY_PARTITION variable=data1_logits complete dim=0 + nnet::dense_resource(data0, data1_logits, weights1, biases1); + data_T data1[CONFIG_T::dense_config2::n_out]; + #pragma HLS ARRAY_PARTITION variable=data1 complete dim=0 + nnet::relu(data1_logits, data1); + + nnet::dense_resource(data1, res, weights2, biases2); + } + + template + void dense_mult_4lyr( + data_T data[CONFIG_T::dense_config1::n_in], + res_T res[CONFIG_T::dense_config4::n_out], + typename CONFIG_T::dense_config1::weight_t weights0[CONFIG_T::dense_config1::n_in*CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config1::bias_t biases0[CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config2::weight_t weights1[CONFIG_T::dense_config2::n_in*CONFIG_T::dense_config2::n_out], + typename CONFIG_T::dense_config2::bias_t biases1[CONFIG_T::dense_config2::n_out], + typename CONFIG_T::dense_config3::weight_t weights2[CONFIG_T::dense_config3::n_in*CONFIG_T::dense_config3::n_out], + typename CONFIG_T::dense_config3::bias_t biases2[CONFIG_T::dense_config3::n_out], + typename CONFIG_T::dense_config4::weight_t weights3[CONFIG_T::dense_config4::n_in*CONFIG_T::dense_config4::n_out], + typename CONFIG_T::dense_config4::bias_t biases3[CONFIG_T::dense_config4::n_out]) + { + data_T data0_logits[CONFIG_T::dense_config1::n_out]; + #pragma HLS ARRAY_PARTITION 
variable=data0_logits complete dim=0 + nnet::dense_resource(data, data0_logits, weights0, biases0); + data_T data0[CONFIG_T::dense_config1::n_out]; + #pragma HLS ARRAY_PARTITION variable=data0 complete dim=0 + nnet::relu(data0_logits, data0); + + data_T data1_logits[CONFIG_T::dense_config2::n_out]; + #pragma HLS ARRAY_PARTITION variable=data1_logits complete dim=0 + nnet::dense_resource(data0, data1_logits, weights1, biases1); + data_T data1[CONFIG_T::dense_config2::n_out]; + #pragma HLS ARRAY_PARTITION variable=data1 complete dim=0 + nnet::relu(data1_logits, data1); + + data_T data2_logits[CONFIG_T::dense_config3::n_out]; + #pragma HLS ARRAY_PARTITION variable=data2_logits complete dim=0 + nnet::dense_resource(data1, data2_logits, weights2, biases2); + data_T data2[CONFIG_T::dense_config3::n_out]; + #pragma HLS ARRAY_PARTITION variable=data2 complete dim=0 + nnet::relu(data2_logits, data2); + + nnet::dense_resource(data2, res, weights3, biases3); + } + + template + void aggregate( + data_T edge_attr_1D[CONFIG_T::n_edge*CONFIG_T::edge_dim], + index_T edge_index_1D[CONFIG_T::n_edge*2], + res_T edge_attr_aggr_1D[CONFIG_T::n_node*CONFIG_T::edge_dim]) + { + //initialize arrays + // 1. edge_attr (input) + data_T edge_attr[CONFIG_T::n_edge][CONFIG_T::edge_dim]; + #pragma HLS ARRAY_PARTITION variable=edge_attr complete dim=0 + nnet::vec_to_mat(edge_attr_1D, edge_attr); + + // 2. edge_index (input) + index_T edge_index[CONFIG_T::n_edge][2]; + #pragma HLS ARRAY_PARTITION variable=edge_index complete dim=0 + nnet::vec_to_mat(edge_index_1D, edge_index); + + //3. num_edge_per_node (intermediate), 4. edge_aggr_mask (intermediate) + index_T num_edge_per_node[CONFIG_T::n_node]; + #pragma HLS ARRAY_PARTITION variable=num_edge_per_node complete dim=0 + ap_uint<1> edge_aggr_mask[CONFIG_T::n_node]; + #pragma HLS ARRAY_PARTITION variable=edge_aggr_mask complete dim=0 + for(int i=0; i edge_attr_aggr[r][j] ? 
edge_attr[i][j] : edge_attr_aggr[r][j]; + } + } + } + + // sum --> mean + if(CONFIG_T::aggr == aggr_mean){ + for(int i=0; i < CONFIG_T::n_node; i++){ + for (int j=0; j(edge_attr_aggr[i][j], num_edge_per_node[i], edge_mean_j); + edge_attr_aggr[i][j] = edge_mean_j; + } + } + } + + // None --> max + if(CONFIG_T::aggr == aggr_max){ //note: the edge_update_aggr array has been initialized but IS NOT ZEROS + for(int i=0; i < CONFIG_T::n_node; i++){ + for(int j=0; j output vec + nnet::mat_to_vec(edge_attr_aggr, edge_attr_aggr_1D); + } + + template + void edgeblock( + data_T node_attr_1D[CONFIG_T::n_node*CONFIG_T::node_dim], + data_T edge_attr_1D[CONFIG_T::n_edge*CONFIG_T::edge_dim], + index_T edge_index_1D[CONFIG_T::n_edge*2], + res_T edge_update_1D[CONFIG_T::n_edge*CONFIG_T::out_dim], + typename CONFIG_T::dense_config1::weight_t core_edge_w0[CONFIG_T::dense_config1::n_in*CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config1::bias_t core_edge_b0[CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config2::weight_t core_edge_w1[CONFIG_T::dense_config2::n_in*CONFIG_T::dense_config2::n_out], + typename CONFIG_T::dense_config2::bias_t core_edge_b1[CONFIG_T::dense_config2::n_out], + typename CONFIG_T::dense_config3::weight_t core_edge_w2[CONFIG_T::dense_config3::n_in*CONFIG_T::dense_config3::n_out], + typename CONFIG_T::dense_config3::bias_t core_edge_b2[CONFIG_T::dense_config3::n_out], + typename CONFIG_T::dense_config4::weight_t core_edge_w3[CONFIG_T::dense_config4::n_in*CONFIG_T::dense_config4::n_out], + typename CONFIG_T::dense_config4::bias_t core_edge_b3[CONFIG_T::dense_config4::n_out]) + { + //initialize arrays + // 1. node_attr (input) + data_T node_attr[CONFIG_T::n_node][CONFIG_T::node_dim]; + #pragma HLS ARRAY_PARTITION variable=node_attr complete dim=0 + nnet::vec_to_mat(node_attr_1D, node_attr); + + // 2. edge_attr (input) + data_T edge_attr[CONFIG_T::n_edge][CONFIG_T::edge_dim]; + #pragma HLS ARRAY_PARTITION variable=edge_attr complete dim=0 + nnet::vec_to_mat(edge_attr_1D, edge_attr); + + // 3. edge_index (input) + index_T edge_index[CONFIG_T::n_edge][2]; + #pragma HLS ARRAY_PARTITION variable=edge_index complete dim=0 + nnet::vec_to_mat(edge_index_1D, edge_index); + + // 4. 
edge_update (output) + res_T edge_update[CONFIG_T::n_edge][CONFIG_T::out_dim]; + #pragma HLS ARRAY_PARTITION variable=edge_update complete dim=0 + + int sender_col; + int receiver_col; + if(CONFIG_T::flow == source_to_target){ + sender_col = 0; + receiver_col = 1; + } + else{ + sender_col = 1; + receiver_col = 0; + } + + #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + edge_loop: for(int i = 0; i < CONFIG_T::n_edge; i++) { //for each edge + #pragma HLS UNROLL + + // get sender, receiver indices + index_T s = edge_index[i][sender_col]; + index_T r = edge_index[i][receiver_col]; + + // construct NN input: + data_T node_concat[2*CONFIG_T::node_dim]; + #pragma HLS ARRAY_PARTITION variable=node_concat complete dim=0 + nnet::concatenate1d(node_attr[r], node_attr[s], node_concat); + data_T phi_input[CONFIG_T::edge_dim + 2*CONFIG_T::node_dim]; + #pragma HLS ARRAY_PARTITION variable=phi_input complete dim=0 + nnet::concatenate1d(node_concat, edge_attr[i], phi_input); + + // send it through NN + if(CONFIG_T::n_layers == 1){ + nnet::dense_mult_1lyr(phi_input, edge_update[i], core_edge_w0, core_edge_b0); + } + else if(CONFIG_T::n_layers == 2){ + nnet::dense_mult_2lyr(phi_input, edge_update[i], core_edge_w0, core_edge_b0, core_edge_w1, core_edge_b1); + } + else if(CONFIG_T::n_layers == 3){ + nnet::dense_mult_3lyr(phi_input, edge_update[i], core_edge_w0, core_edge_b0, core_edge_w1, core_edge_b1, core_edge_w2, core_edge_b2); + } + else if(CONFIG_T::n_layers == 4){ + nnet::dense_mult_4lyr(phi_input, edge_update[i], core_edge_w0, core_edge_b0, core_edge_w1, core_edge_b1, core_edge_w2, core_edge_b2, core_edge_w3, core_edge_b3); + } + } + + //output arrays --> output vectors + // 1. edge_update_1D + nnet::mat_to_vec(edge_update, edge_update_1D); + } + + template + void nodeblock( + data_T node_attr_1D[CONFIG_T::n_node*CONFIG_T::node_dim], + data_T edge_attr_aggr_1D[CONFIG_T::n_node*CONFIG_T::edge_dim], + res_T node_update_1D[CONFIG_T::n_node*CONFIG_T::out_dim], + typename CONFIG_T::dense_config1::weight_t core_node_w0[CONFIG_T::dense_config1::n_in*CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config1::bias_t core_node_b0[CONFIG_T::dense_config1::n_out], + typename CONFIG_T::dense_config2::weight_t core_node_w1[CONFIG_T::dense_config2::n_in*CONFIG_T::dense_config2::n_out], + typename CONFIG_T::dense_config2::bias_t core_node_b1[CONFIG_T::dense_config2::n_out], + typename CONFIG_T::dense_config3::weight_t core_node_w2[CONFIG_T::dense_config3::n_in*CONFIG_T::dense_config3::n_out], + typename CONFIG_T::dense_config3::bias_t core_node_b2[CONFIG_T::dense_config3::n_out], + typename CONFIG_T::dense_config4::weight_t core_node_w3[CONFIG_T::dense_config4::n_in*CONFIG_T::dense_config4::n_out], + typename CONFIG_T::dense_config4::bias_t core_node_b3[CONFIG_T::dense_config4::n_out]) + { + //initialize arrays + //1. node_attr (input) + data_T node_attr[CONFIG_T::n_node][CONFIG_T::node_dim]; + #pragma HLS ARRAY_PARTITION variable=node_attr complete dim=0 + nnet::vec_to_mat(node_attr_1D, node_attr); + + //2. edge_attr_aggr (input) + data_T edge_attr_aggr[CONFIG_T::n_node][CONFIG_T::edge_dim]; + #pragma HLS ARRAY_PARTITION variable=edge_attr_aggr complete dim=0 + nnet::vec_to_mat(edge_attr_aggr_1D, edge_attr_aggr); + + // 3. 
node_update (output) + res_T node_update[CONFIG_T::n_node][CONFIG_T::out_dim]; + #pragma HLS ARRAY_PARTITION variable=node_update complete dim=0 + + #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + node_loop: for(int i = 0; i < CONFIG_T::n_node; i++){ //for each node + #pragma HLS UNROLL + + // construct NN input: + data_T phi_input[CONFIG_T::edge_dim + CONFIG_T::node_dim]; + #pragma HLS ARRAY_PARTITION variable=phi_input complete dim=0 + nnet::concatenate1d(node_attr[i], edge_attr_aggr[i], phi_input); + + // send it through NN + if(CONFIG_T::n_layers == 1){ + nnet::dense_mult_1lyr(phi_input, node_update[i], core_node_w0, core_node_b0); + } + else if(CONFIG_T::n_layers == 2){ + nnet::dense_mult_2lyr(phi_input, node_update[i], core_node_w0, core_node_b0, core_node_w1, core_node_b1); + } + else if(CONFIG_T::n_layers == 3){ + nnet::dense_mult_3lyr(phi_input, node_update[i], core_node_w0, core_node_b0, core_node_w1, core_node_b1, core_node_w2, core_node_b2); + } + else { // CONFIG_T::n_layers == 4 + nnet::dense_mult_4lyr(phi_input, node_update[i], core_node_w0, core_node_b0, core_node_w1, core_node_b1, core_node_w2, core_node_b2, core_node_w3, core_node_b3); + } + } + + // output array --> output vector + nnet::mat_to_vec(node_update, node_update_1D); + + } + +} + +#endif diff --git a/hls4ml/templates/vivado_template.py b/hls4ml/templates/vivado_template.py index 53cf3536ff..f91b951618 100644 --- a/hls4ml/templates/vivado_template.py +++ b/hls4ml/templates/vivado_template.py @@ -19,6 +19,7 @@ typedef {bias_t} bias_t; typedef {weight_t} weight_t; typedef {index_t} index_t; + static const bool remove_pipeline_pragma = {remove_pipeline_pragma}; template using product = nnet::product::{product_type}; }};\n""" @@ -345,6 +346,50 @@ garnet_stack_config_template = (garnet_stack_base_config_template, garnet_stack_sublayer_config_template) +edgeblock_config_template = """struct config{index}: nnet::graph_config{{ + typedef {bias_t} bias_t; + typedef {weight_t} weight_t; + typedef {table_t} table_t; + static const unsigned n_node = {n_node}; + static const unsigned n_edge = {n_edge}; + static const unsigned node_dim = {node_dim}; + static const unsigned edge_dim = {edge_dim}; + static const unsigned out_dim = {out_dim}; + static const unsigned n_layers = {n_layers}; + static const unsigned flow = {flow}; + static const unsigned io_type = nnet::{io_type}; + static const unsigned reuse_factor = {reuse}; + static const unsigned n_zeros = {n_zeros}; + static const bool io_stream = false; +}};""" + +nodeblock_config_template = """struct config{index}: nnet::graph_config{{ + typedef {bias_t} bias_t; + typedef {weight_t} weight_t; + typedef {table_t} table_t; + static const unsigned n_node = {n_node}; + static const unsigned n_edge = {n_edge}; + static const unsigned node_dim = {node_dim}; + static const unsigned edge_dim = {edge_dim}; + static const unsigned out_dim = {out_dim}; + static const unsigned n_layers = {n_layers}; + static const unsigned io_type = nnet::{io_type}; + static const unsigned reuse_factor = {reuse}; + static const unsigned n_zeros = {n_zeros}; + static const bool io_stream = false; +}};""" + +aggregate_config_template = """struct aggregation_config{index}: nnet::aggregate_config{{ + typedef {table_t} table_t; + static const unsigned n_node = {n_node}; + static const unsigned n_edge = {n_edge}; + static const unsigned edge_dim = {edge_dim}; + static const unsigned aggr = {aggr}; + static const unsigned flow = {flow}; + static const unsigned io_type = nnet::{io_type}; + static const 
unsigned reuse_factor = {reuse}; + static const bool io_stream = false; +}};""" dense_function_template = 'nnet::dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' @@ -367,6 +412,9 @@ transpose_function_template = 'nnet::transpose_{dim}<{input_t}, {config}>({input}, {output});' garnet_function_template = 'nnet::garnet{impl}<{input_t}, {integer_input_t}, {output_t}, {config}>({input}, {nvtx}, {output});' garnet_stack_function_template = 'nnet::garnet_stack<{input_t}, {integer_input_t}, {output_t}, {config}>({input}, {nvtx}, {output});' +edgeblock_function_template = 'nnet::edgeblock<{input_t}, {index_t}, {output_t}, {config}>({node_attr}, {edge_attr}, {edge_index}, {out}, {w0}, {b0}, {w1}, {b1}, {w2}, {b2}, {w3}, {b3});' +nodeblock_function_template = 'nnet::nodeblock<{input_t}, {output_t}, {config}>({node_attr}, {edge_attr_aggr}, {out}, {w0}, {b0}, {w1}, {b1}, {w2}, {b2}, {w3}, {b3});' +aggregate_function_template = 'nnet::aggregate<{input_t}, {index_t}, {output_t}, {config}>({edge_attr}, {edge_index}, {out});' dense_include_list = ['nnet_utils/nnet_dense.h', 'nnet_utils/nnet_dense_compressed.h', 'nnet_utils/nnet_dense_stream.h'] batchnorm_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h'] @@ -381,6 +429,21 @@ resize_include_list = ['nnet_utils/nnet_image.h', 'nnet_utils/nnet_image_stream.h'] transpose_include_list = ['nnet_utils/nnet_array.h'] garnet_include_list = ['nnet_utils/nnet_garnet.h'] +edgeblock_include_list = ['nnet_utils/nnet_common.h', + 'nnet_utils/nnet_dense.h', + 'nnet_utils/nnet_dense_resource.h', + 'nnet_utils/nnet_activation.h', + 'nnet_utils/nnet_graph.h', + 'nnet_utils/nnet_merge.h', + 'nnet_utils/nnet_array.h'] +nodeblock_include_list = ['nnet_utils/nnet_common.h', + 'nnet_utils/nnet_dense.h', + 'nnet_utils/nnet_dense_resource.h', + 'nnet_utils/nnet_activation.h', + 'nnet_utils/nnet_graph.h', + 'nnet_utils/nnet_merge.h', + 'nnet_utils/nnet_array.h'] +aggregate_include_list = ['nnet_utils/nnet_graph.h'] class VivadoBackend(Backend): def __init__(self): @@ -411,6 +474,9 @@ def __init__(self): self.register_templates('Transpose' , transpose_function_template, transpose_config_template, transpose_include_list) self.register_templates('GarNet' , garnet_function_template, garnet_config_template, garnet_include_list) self.register_templates('GarNetStack' , garnet_stack_function_template,garnet_stack_config_template, garnet_include_list) + self.register_templates('EdgeBlock' , edgeblock_function_template, edgeblock_config_template, edgeblock_include_list) + self.register_templates('NodeBlock' , nodeblock_function_template, nodeblock_config_template, nodeblock_include_list) + self.register_templates('Aggregate' , aggregate_function_template, aggregate_config_template, aggregate_include_list) def get_valid_reuse_factors(self, layer): n_in = 0 diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index de3d0a02e5..79e8a7b763 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -290,6 +290,18 @@ def config_from_pytorch_model(model, granularity='model', default_precision='ap_ return config +def config_from_pyg_model(model, granularity='model', default_precision='ap_fixed<16,6>', default_index_precision='ap_uint<16>', default_reuse_factor=1): + config = {} + + model_config = {} + model_config['Precision'] = default_precision + model_config['IndexPrecision'] = default_index_precision + model_config['ReuseFactor'] = default_reuse_factor + model_config['Strategy'] = 'Latency' + + config['Model'] = model_config 
+ + return config def config_from_onnx_model(model, granularity='model', default_precision='ap_fixed<16,6>', default_reuse_factor=1): """Generate configuration dictionary from an ONNX model. @@ -335,4 +347,4 @@ def config_from_onnx_model(model, granularity='model', default_precision='ap_fix config['Model'] = model_config - return config \ No newline at end of file + return config diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 8b58488650..636857c63d 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -297,7 +297,9 @@ def write_defines(self, model): if '//hls-fpga-machine-learning insert numbers' in line: newline = line numbers = OrderedDict.fromkeys([layer.get_numbers_cpp() for layer in model.get_layers()]) - newline += ''.join(numbers) + numbers = set('\n'.join(numbers).split('\n')) #we want a unique set of macro declarations, since some of the macros are shared between different NN-blocks + newline += '\n'.join([i for i in numbers if i!='']) + newline += '\n' #for formatting purposes elif '//hls-fpga-machine-learning insert layer-precision' in line: newline = line @@ -658,7 +660,14 @@ def keras_model_representer(dumper, keras_model): pass with open(model.config.get_output_dir() + '/' + config_filename, 'w') as file: - yaml.dump(model.config.config, file) + try: + yaml.dump(model.config.config, file) + except ValueError: + import torch + model_path = model.config.get_output_dir() + "/torch_model_state_dict.pt" + torch.save(model.config.config["PytorchModel"].state_dict(), model_path) + model.config.config["PytorchModel"] = model_path + yaml.dump(model.config.config, file) def write_tar(self, model): ################### From d9db9a9028df92cd503c6c961906eed8b3202378 Mon Sep 17 00:00:00 2001 From: abdelabd Date: Tue, 17 Aug 2021 15:09:18 -0400 Subject: [PATCH 02/13] added (started) pyg block_handlers --- hls4ml/converters/__init__.py | 6 +- hls4ml/converters/pyg_to_hls.py | 96 +++++++++-------------------- hls4ml/converters/pytorch_to_hls.py | 2 +- 3 files changed, 33 insertions(+), 71 deletions(-) diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index ac066baef7..a3c7126241 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -11,7 +11,7 @@ #----------Make converters available if the libraries can be imported----------# try: from hls4ml.converters.pytorch_to_hls import pytorch_to_hls, get_supported_pytorch_layers, register_pytorch_layer_handler - from hls4ml.converters.pyg_to_hls import pyg_to_hls + from hls4ml.converters.pyg_to_hls import pyg_to_hls, get_supported_pyg_blocks, register_pyg_block_handler __pytorch_enabled__ = True except ImportError: warnings.warn("WARNING: Pytorch converter is not enabled!") @@ -32,7 +32,7 @@ __tensorflow_enabled__ = False #----------Layer handling register----------# -model_types = ['keras', 'pytorch', 'onnx'] +model_types = ['keras', 'pytorch', 'onnx', 'pyg'] for model_type in model_types: for module in os.listdir(os.path.dirname(__file__) + '/{}'.format(model_type)): @@ -53,6 +53,8 @@ register_pytorch_layer_handler(layer, func) elif model_type == 'onnx': register_onnx_layer_handler(layer, func) + elif model_type == 'pyg': + register_pyg_block_handler(layer, func) except ImportError: continue diff --git a/hls4ml/converters/pyg_to_hls.py b/hls4ml/converters/pyg_to_hls.py index 34019cd1c3..62354fde68 100644 --- a/hls4ml/converters/pyg_to_hls.py +++ b/hls4ml/converters/pyg_to_hls.py @@ -1,5 +1,4 @@ from __future__ import print_function 
-import torch from hls4ml.converters.pytorch_to_hls import PyTorchModelReader from hls4ml.model.hls_model import HLSModel @@ -47,8 +46,25 @@ def get_weights_data(self, layer_name, var_name, module_name=None): return data -def pyg_to_hls(config): +# EdgeBlock/NodeBlock/Aggregate handlers +block_handlers = {} + +def register_pyg_block_handler(block_name, handler_func): + if block_name in block_handlers: + raise Exception('Block {} already registered'.format(block_name)) + else: + block_handlers[block_name] = handler_func + +def get_supported_pyg_blocks(): + return list(block_handlers.keys()) +def pyg_handler(*args): + def decorator(function): + function.handles = [arg for arg in args] + return function + return decorator + +def pyg_to_hls(config): forward_dict = config['ForwardDictionary'] activate_final = config['ActivateFinal'] @@ -64,7 +80,6 @@ def pyg_to_hls(config): node_dim = reader.node_dim edge_dim = reader.edge_dim - # initiate layer list with inputs: node_attr, edge_attr, edge_index layer_list = [] input_shapes = reader.input_shape @@ -95,81 +110,26 @@ def pyg_to_hls(config): 'precision': int_type } layer_list.append(EdgeIndex_layer) - last_node_update = "node_attr" - last_edge_update = "edge_attr" + update_dict = {"last_node_update": "node_attr", "last_edge_update": "edge_attr", "last_edge_aggr_update": None} - # If the first block is a NodeBlock, we need a layer to construct the initial edge_aggregates + # If the first block is a NodeBlock, we need an initial Aggregate block to construct the initial edge_aggregates if forward_dict[list(forward_dict.keys())[0]] == "NodeBlock": - aggr_layer = {"name": "aggr1", - "class_name": "Aggregate", - "n_node": n_node, - "n_edge": n_edge, - "node_dim": node_dim, - "edge_dim": edge_dim, - "precision": fp_type, - "out_dim": edge_dim, - "inputs": ["edge_attr", "edge_index"], - "outputs": ["edge_attr_aggr"]} - layer_list.append(aggr_layer) - last_edge_aggr_update = "edge_attr_aggr" - else: last_edge_aggr_update = None + index = len(layer_list)+1 + layer_dict, update_dict = block_handlers["Aggregate"](index, fp_type, update_dict, n_node, + n_edge, node_dim, edge_dim) + layer_list.append(layer_dict) # complete the layer list for i, (key, val) in enumerate(forward_dict.items()): - layer_dict = { - "name": key, - "class_name": val, - "n_node": n_node, - "n_edge": n_edge, - "node_dim": node_dim, - "edge_dim": edge_dim, - "precision": fp_type - } - - # get n_layers, out_dim - model = config['PytorchModel'] - torch_block = getattr(model, key) - try: - torch_layers = torch_block.layers._modules - except AttributeError: - torch_layers = torch_block._modules - - lcount = 0 - for lname, l in torch_layers.items(): - if isinstance(l, torch.nn.modules.linear.Linear): - lcount += 1 - last_layer = l - layer_dict["n_layers"] = lcount - layer_dict["out_dim"] = last_layer.out_features - # get inputs, outputs - if val == "NodeBlock": - index = len(layer_list) + 1 - layer_dict["inputs"] = [last_node_update, last_edge_aggr_update] - layer_dict["outputs"] = [f"layer{index}_out"] - last_node_update = f"layer{index}_out" - layer_list.append(layer_dict) - elif val == "EdgeBlock": - index = len(layer_list) + 1 - layer_dict["inputs"] = [last_node_update, last_edge_update, "edge_index"] - layer_dict["outputs"] = [f"layer{index}_out"] - last_edge_update = f"layer{index}_out" - layer_list.append(layer_dict) + index = len(layer_list)+1 + layer_dict, update_dict = block_handlers[val](key, config, update_dict, index, n_node, n_edge, node_dim, edge_dim) + 
layer_list.append(layer_dict) # if val==EdgeBlock and this is not the final graph-block, follow it with an aggregation layer if (val == "EdgeBlock") and (i < len(forward_dict) - 1): index = len(layer_list) + 1 - layer_dict = {"name": f"aggr{index}", - "class_name": "Aggregate", - "n_node": n_node, - "n_edge": n_edge, - "node_dim": node_dim, - "edge_dim": edge_dim, - "precision": fp_type, - "out_dim": edge_dim, - "inputs": [last_edge_update, "edge_index"], - "outputs": [f"layer{index}_out"]} - last_edge_aggr_update = f"layer{index}_out" + layer_dict, update_dict = block_handlers["Aggregate"](index, update_dict, n_node, n_edge, node_dim, edge_dim) layer_list.append(layer_dict) if activate_final is not None: diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index a3daebb359..02552aefc6 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -119,7 +119,7 @@ def pytorch_to_hls(config): ----- Only sequential pytorch models are supported for now. """ - + #This is a list of dictionaries to hold all the layer info we need to generate HLS layer_list = [] From 52ffd2bb312863cfebd733c32223ba29e8d2d291 Mon Sep 17 00:00:00 2001 From: abdelabd Date: Tue, 17 Aug 2021 15:22:55 -0400 Subject: [PATCH 03/13] added pyg graph_handlers --- .../pyg/interaction_network_blocks.py | 66 +++++++++++++++++++ hls4ml/converters/pyg_to_hls.py | 28 ++++---- 2 files changed, 80 insertions(+), 14 deletions(-) create mode 100644 hls4ml/converters/pyg/interaction_network_blocks.py diff --git a/hls4ml/converters/pyg/interaction_network_blocks.py b/hls4ml/converters/pyg/interaction_network_blocks.py new file mode 100644 index 0000000000..e2aa14c538 --- /dev/null +++ b/hls4ml/converters/pyg/interaction_network_blocks.py @@ -0,0 +1,66 @@ +import numpy as np +from hls4ml.converters.pyg_to_hls import pyg_handler + +def parse_GraphBlock(block_name, config, n_node, n_edge, node_dim, edge_dim): + layer_dict = { + "name": block_name, + "n_node": n_node, + "n_edge": n_edge, + "node_dim": node_dim, + "edge_dim": edge_dim, + } + + # get n_layers, out_dim + model = config['PytorchModel'] + torch_block = getattr(model, block_name) + try: + torch_layers = torch_block.layers._modules + except AttributeError: + torch_layers = torch_block._modules + + lcount = 0 + for lname, l in torch_layers.items(): + if l.__class__.__name__=="Linear": + lcount += 1 + last_layer = l + layer_dict["n_layers"] = lcount + layer_dict["out_dim"] = last_layer.out_features + return layer_dict + +@pyg_handler('NodeBlock') +def parse_NodeBlock(block_name, config, update_dict, index, n_node, n_edge, node_dim, edge_dim): + layer_dict = parse_GraphBlock(block_name, config, n_node, n_edge, node_dim, edge_dim) + layer_dict["class_name"] = "NodeBlock" + layer_dict["inputs"] = [update_dict["last_node_update"], update_dict["last_edge_aggr_update"]] + layer_dict["outputs"] = [f"layer{index}_out"] + update_dict["last_node_update"] = f"layer{index}_out" + return layer_dict, update_dict + +@pyg_handler('EdgeBlock') +def parse_EdgeBlock(block_name, config, update_dict, index, n_node, n_edge, node_dim, edge_dim): + layer_dict = parse_GraphBlock(block_name, config, n_node, n_edge, node_dim, edge_dim) + layer_dict["class_name"] = "EdgeBlock" + layer_dict["inputs"] = [update_dict["last_node_update"], update_dict["last_edge_update"], "edge_index"] + layer_dict["outputs"] = [f"layer{index}_out"] + update_dict["last_edge_update"] = f"layer{index}_out" + return layer_dict, update_dict + +@pyg_handler('Aggregate') 
+def parse_Aggregate(block_name, config, update_dict, index, n_node, n_edge, node_dim, edge_dim): + layer_dict = {"name": f"aggr{index}", + "class_name": "Aggregate", + "n_node": n_node, + "n_edge": n_edge, + "node_dim": node_dim, + "edge_dim": edge_dim, + "out_dim": edge_dim, + "inputs": [update_dict["last_edge_update"], "edge_index"], + "outputs": [f"layer{index}_out"]} + update_dict["last_edge_aggr_update"] = f"layer{index}_out" + return layer_dict, update_dict + +IN_handlers = { + "NodeBlock": parse_NodeBlock, + "EdgeBlock": parse_EdgeBlock, + "Aggregate": parse_Aggregate +} \ No newline at end of file diff --git a/hls4ml/converters/pyg_to_hls.py b/hls4ml/converters/pyg_to_hls.py index 62354fde68..7f057bdee2 100644 --- a/hls4ml/converters/pyg_to_hls.py +++ b/hls4ml/converters/pyg_to_hls.py @@ -1,10 +1,10 @@ from __future__ import print_function +from collections import OrderedDict from hls4ml.converters.pytorch_to_hls import PyTorchModelReader from hls4ml.model.hls_model import HLSModel from hls4ml.templates import get_backend - class PygModelReader(PyTorchModelReader): def __init__(self, config): super().__init__(config) @@ -112,26 +112,26 @@ def pyg_to_hls(config): layer_list.append(EdgeIndex_layer) update_dict = {"last_node_update": "node_attr", "last_edge_update": "edge_attr", "last_edge_aggr_update": None} - # If the first block is a NodeBlock, we need an initial Aggregate block to construct the initial edge_aggregates - if forward_dict[list(forward_dict.keys())[0]] == "NodeBlock": - index = len(layer_list)+1 - layer_dict, update_dict = block_handlers["Aggregate"](index, fp_type, update_dict, n_node, - n_edge, node_dim, edge_dim) - layer_list.append(layer_dict) + # insert an aggregation step before each NodeBlock + aggr_count = 0 + forward_dict_new = OrderedDict() + for key, val in forward_dict.items(): + if val=="NodeBlock": + aggr_count += 1 + aggr_key = f"aggr{aggr_count}" + aggr_val = "Aggregate" + forward_dict_new[aggr_key] = aggr_val + forward_dict_new[key] = val + print(f"forward_dict: {forward_dict}") + print(f"forward_dict_new: {forward_dict_new}") # complete the layer list - for i, (key, val) in enumerate(forward_dict.items()): + for i, (key, val) in enumerate(forward_dict_new.items()): # get inputs, outputs index = len(layer_list)+1 layer_dict, update_dict = block_handlers[val](key, config, update_dict, index, n_node, n_edge, node_dim, edge_dim) layer_list.append(layer_dict) - # if val==EdgeBlock and this is not the final graph-block, follow it with an aggregation layer - if (val == "EdgeBlock") and (i < len(forward_dict) - 1): - index = len(layer_list) + 1 - layer_dict, update_dict = block_handlers["Aggregate"](index, update_dict, n_node, n_edge, node_dim, edge_dim) - layer_list.append(layer_dict) - if activate_final is not None: act_dict = { 'name': 'final_act', From de38c2b64af898227522b9124d479e38e79d6e82 Mon Sep 17 00:00:00 2001 From: abdelabd Date: Tue, 17 Aug 2021 15:27:30 -0400 Subject: [PATCH 04/13] added 'check_forward_dict' --- hls4ml/converters/__init__.py | 5 ++++- hls4ml/converters/pyg/__init__.py | 0 hls4ml/converters/pyg_to_hls.py | 4 +--- 3 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 hls4ml/converters/pyg/__init__.py diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index a3c7126241..70f1398caa 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -270,11 +270,14 @@ def convert_from_pytorch_model(model, input_shape, output_dir='my-hls-test', pro return pytorch_to_hls(config) +def 
check_forward_dict(model, forward_dictionary): + for key in forward_dictionary: + assert(hasattr(model, key)) def convert_from_pyg_model(model, n_node, node_dim, n_edge, edge_dim, forward_dictionary=None, activate_final=None, output_dir='my-hls-test', project_name='myproject', fpga_part='xcku115-flvb2104-2-i', clock_period=5, io_type='io_parallel', hls_config={}): - + check_forward_dict(model, forward_dictionary) config = create_vivado_config( output_dir=output_dir, project_name=project_name, diff --git a/hls4ml/converters/pyg/__init__.py b/hls4ml/converters/pyg/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hls4ml/converters/pyg_to_hls.py b/hls4ml/converters/pyg_to_hls.py index 7f057bdee2..3da028a117 100644 --- a/hls4ml/converters/pyg_to_hls.py +++ b/hls4ml/converters/pyg_to_hls.py @@ -116,14 +116,12 @@ def pyg_to_hls(config): aggr_count = 0 forward_dict_new = OrderedDict() for key, val in forward_dict.items(): - if val=="NodeBlock": + if val == "NodeBlock": aggr_count += 1 aggr_key = f"aggr{aggr_count}" aggr_val = "Aggregate" forward_dict_new[aggr_key] = aggr_val forward_dict_new[key] = val - print(f"forward_dict: {forward_dict}") - print(f"forward_dict_new: {forward_dict_new}") # complete the layer list for i, (key, val) in enumerate(forward_dict_new.items()): From ffad2028f98e801f264833f8ddab82d601c09f25 Mon Sep 17 00:00:00 2001 From: abdelabd Date: Tue, 17 Aug 2021 16:35:48 -0400 Subject: [PATCH 05/13] updated error handling --- hls4ml/converters/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 70f1398caa..52fd63fe2f 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -272,7 +272,11 @@ def convert_from_pytorch_model(model, input_shape, output_dir='my-hls-test', pro def check_forward_dict(model, forward_dictionary): for key in forward_dictionary: - assert(hasattr(model, key)) + try: + block = getattr(model, key) + except AttributeError: + raise AttributeError(f'Model is missing module "{key}" that is present in the provided forward dictionary; Check compatability') + def convert_from_pyg_model(model, n_node, node_dim, n_edge, edge_dim, forward_dictionary=None, activate_final=None, output_dir='my-hls-test', project_name='myproject', From 31e7b1239c3900c9f03937ded78afa229f08c9f8 Mon Sep 17 00:00:00 2001 From: abdelabd Date: Tue, 24 Aug 2021 11:53:39 -0400 Subject: [PATCH 06/13] updated naming convention: Aggregate-->EdgeAggregate, aggregate-->edge_aggregate --- hls4ml/converters/pyg/interaction_network_blocks.py | 8 ++++---- hls4ml/converters/pyg_to_hls.py | 2 +- hls4ml/model/hls_layers.py | 6 +++--- hls4ml/templates/vivado/nnet_utils/nnet_graph.h | 4 ++-- hls4ml/templates/vivado_template.py | 8 ++++---- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/hls4ml/converters/pyg/interaction_network_blocks.py b/hls4ml/converters/pyg/interaction_network_blocks.py index e2aa14c538..c45d3af5b1 100644 --- a/hls4ml/converters/pyg/interaction_network_blocks.py +++ b/hls4ml/converters/pyg/interaction_network_blocks.py @@ -45,10 +45,10 @@ def parse_EdgeBlock(block_name, config, update_dict, index, n_node, n_edge, node update_dict["last_edge_update"] = f"layer{index}_out" return layer_dict, update_dict -@pyg_handler('Aggregate') -def parse_Aggregate(block_name, config, update_dict, index, n_node, n_edge, node_dim, edge_dim): +@pyg_handler('EdgeAggregate') +def parse_EdgeAggregate(block_name, config, update_dict, index, n_node, 
n_edge, node_dim, edge_dim): layer_dict = {"name": f"aggr{index}", - "class_name": "Aggregate", + "class_name": "EdgeAggregate", "n_node": n_node, "n_edge": n_edge, "node_dim": node_dim, @@ -62,5 +62,5 @@ def parse_Aggregate(block_name, config, update_dict, index, n_node, n_edge, node IN_handlers = { "NodeBlock": parse_NodeBlock, "EdgeBlock": parse_EdgeBlock, - "Aggregate": parse_Aggregate + "EdgeAggregate": parse_EdgeAggregate } \ No newline at end of file diff --git a/hls4ml/converters/pyg_to_hls.py b/hls4ml/converters/pyg_to_hls.py index 3da028a117..99f5da7fa1 100644 --- a/hls4ml/converters/pyg_to_hls.py +++ b/hls4ml/converters/pyg_to_hls.py @@ -119,7 +119,7 @@ def pyg_to_hls(config): if val == "NodeBlock": aggr_count += 1 aggr_key = f"aggr{aggr_count}" - aggr_val = "Aggregate" + aggr_val = "EdgeAggregate" forward_dict_new[aggr_key] = aggr_val forward_dict_new[key] = val diff --git a/hls4ml/model/hls_layers.py b/hls4ml/model/hls_layers.py index 8190826712..b8d120550a 100644 --- a/hls4ml/model/hls_layers.py +++ b/hls4ml/model/hls_layers.py @@ -2309,7 +2309,7 @@ def _check_inputs(self): #expected outputs: node_update assert(len(self.outputs)==1) -class Aggregate(Layer): +class EdgeAggregate(Layer): def initialize(self): self.n_node = self.attributes['n_node'] self.n_edge = self.attributes['n_edge'] @@ -2399,7 +2399,7 @@ def _config_misc(self): aggr_params = self.get_Aggregate_params() nested_duplicate = self._config_template.format(**aggr_params).split('\n') - nested_duplicate[0] = "struct nested_duplicate: nnet::aggregate_config{" + nested_duplicate[0] = "struct nested_duplicate: nnet::edge_aggregate_config{" nested_duplicate = '\n'.join(nested_duplicate) configs['nested_duplicate'] = nested_duplicate @@ -2464,7 +2464,7 @@ def _check_inputs(self): 'GarNetStack' : GarNetStack, 'EdgeBlock' : EdgeBlock, 'NodeBlock' : NodeBlock, - 'Aggregate' : Aggregate, + 'EdgeAggregate' : EdgeAggregate, # TensorFlow-specific layers: 'BiasAdd' : BiasAdd, } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_graph.h b/hls4ml/templates/vivado/nnet_utils/nnet_graph.h index c062443b21..f618485b99 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_graph.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_graph.h @@ -40,7 +40,7 @@ namespace nnet { static const bool no_aggr = false; //if no_aggr==true, then skip aggregation steps }; - struct aggregate_config + struct edge_aggregate_config { typedef float table_t; static const unsigned n_node = 10; @@ -185,7 +185,7 @@ namespace nnet { } template - void aggregate( + void edge_aggregate( data_T edge_attr_1D[CONFIG_T::n_edge*CONFIG_T::edge_dim], index_T edge_index_1D[CONFIG_T::n_edge*2], res_T edge_attr_aggr_1D[CONFIG_T::n_node*CONFIG_T::edge_dim]) diff --git a/hls4ml/templates/vivado_template.py b/hls4ml/templates/vivado_template.py index f91b951618..a076aaa51e 100644 --- a/hls4ml/templates/vivado_template.py +++ b/hls4ml/templates/vivado_template.py @@ -379,7 +379,7 @@ static const bool io_stream = false; }};""" -aggregate_config_template = """struct aggregation_config{index}: nnet::aggregate_config{{ +edge_aggregate_config_template = """struct aggregation_config{index}: nnet::edge_aggregate_config{{ typedef {table_t} table_t; static const unsigned n_node = {n_node}; static const unsigned n_edge = {n_edge}; @@ -414,7 +414,7 @@ garnet_stack_function_template = 'nnet::garnet_stack<{input_t}, {integer_input_t}, {output_t}, {config}>({input}, {nvtx}, {output});' edgeblock_function_template = 'nnet::edgeblock<{input_t}, {index_t}, {output_t}, {config}>({node_attr}, 
{edge_attr}, {edge_index}, {out}, {w0}, {b0}, {w1}, {b1}, {w2}, {b2}, {w3}, {b3});' nodeblock_function_template = 'nnet::nodeblock<{input_t}, {output_t}, {config}>({node_attr}, {edge_attr_aggr}, {out}, {w0}, {b0}, {w1}, {b1}, {w2}, {b2}, {w3}, {b3});' -aggregate_function_template = 'nnet::aggregate<{input_t}, {index_t}, {output_t}, {config}>({edge_attr}, {edge_index}, {out});' +edge_aggregate_function_template = 'nnet::edge_aggregate<{input_t}, {index_t}, {output_t}, {config}>({edge_attr}, {edge_index}, {out});' dense_include_list = ['nnet_utils/nnet_dense.h', 'nnet_utils/nnet_dense_compressed.h', 'nnet_utils/nnet_dense_stream.h'] batchnorm_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h'] @@ -443,7 +443,7 @@ 'nnet_utils/nnet_graph.h', 'nnet_utils/nnet_merge.h', 'nnet_utils/nnet_array.h'] -aggregate_include_list = ['nnet_utils/nnet_graph.h'] +edge_aggregate_include_list = ['nnet_utils/nnet_graph.h'] class VivadoBackend(Backend): def __init__(self): @@ -476,7 +476,7 @@ def __init__(self): self.register_templates('GarNetStack' , garnet_stack_function_template,garnet_stack_config_template, garnet_include_list) self.register_templates('EdgeBlock' , edgeblock_function_template, edgeblock_config_template, edgeblock_include_list) self.register_templates('NodeBlock' , nodeblock_function_template, nodeblock_config_template, nodeblock_include_list) - self.register_templates('Aggregate' , aggregate_function_template, aggregate_config_template, aggregate_include_list) + self.register_templates('EdgeAggregate' , edge_aggregate_function_template, edge_aggregate_config_template, edge_aggregate_include_list) def get_valid_reuse_factors(self, layer): n_in = 0 From fc8c9e29c3ce0214b30afbc44e72c6b6df30dcc2 Mon Sep 17 00:00:00 2001 From: abdelabd Date: Tue, 24 Aug 2021 15:49:18 -0400 Subject: [PATCH 07/13] wrong array name in #pragma HLS ARRAY_PARTITION --- hls4ml/templates/vivado/nnet_utils/nnet_graph.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_graph.h b/hls4ml/templates/vivado/nnet_utils/nnet_graph.h index f618485b99..3dd6fb0030 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_graph.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_graph.h @@ -216,7 +216,7 @@ namespace nnet { //4. 
edge_attr_aggr (output) res_T edge_attr_aggr[CONFIG_T::n_node][CONFIG_T::edge_dim]; - #pragma HLS ARRAY_PARTITION variable=edge_update_aggr complete dim=0 + #pragma HLS ARRAY_PARTITION variable=edge_attr_aggr complete dim=0 if((CONFIG_T::aggr==aggr_sum)||(CONFIG_T::aggr==aggr_mean)){ for(int i=0; i < CONFIG_T::n_node; i++){ for(int j=0; j max - if(CONFIG_T::aggr == aggr_max){ //note: the edge_update_aggr array has been initialized but IS NOT ZEROS + if(CONFIG_T::aggr == aggr_max){ //note: the edge_attr_aggr array has been initialized but IS NOT ZEROS for(int i=0; i < CONFIG_T::n_node; i++){ for(int j=0; j Date: Wed, 1 Sep 2021 17:15:37 -0400 Subject: [PATCH 08/13] main difference between this branch and 'pyg_to_hls_rebase' --- .../templates/vivado/nnet_utils/nnet_graph.h | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_graph.h b/hls4ml/templates/vivado/nnet_utils/nnet_graph.h index 3dd6fb0030..daffcb94f7 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_graph.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_graph.h @@ -325,10 +325,7 @@ namespace nnet { #pragma HLS ARRAY_PARTITION variable=edge_index complete dim=0 nnet::vec_to_mat(edge_index_1D, edge_index); - // 4. edge_update (output) - res_T edge_update[CONFIG_T::n_edge][CONFIG_T::out_dim]; - #pragma HLS ARRAY_PARTITION variable=edge_update complete dim=0 - + // 4. phi_input (intermediate) int sender_col; int receiver_col; if(CONFIG_T::flow == source_to_target){ @@ -339,35 +336,44 @@ namespace nnet { sender_col = 1; receiver_col = 0; } + data_T phi_input[CONFIG_T::n_edge][2*CONFIG_T::node_dim+CONFIG_T::n_edge]; + #pragma HLS ARRAY_PARTITION variable=phi_input complete dim=0 + for(int i=0; i + for(int j=0; j - data_T node_concat[2*CONFIG_T::node_dim]; - #pragma HLS ARRAY_PARTITION variable=node_concat complete dim=0 - nnet::concatenate1d(node_attr[r], node_attr[s], node_concat); - data_T phi_input[CONFIG_T::edge_dim + 2*CONFIG_T::node_dim]; - #pragma HLS ARRAY_PARTITION variable=phi_input complete dim=0 - nnet::concatenate1d(node_concat, edge_attr[i], phi_input); - - // send it through NN + // send phi_input[i] through NN to edge_update[i] if(CONFIG_T::n_layers == 1){ - nnet::dense_mult_1lyr(phi_input, edge_update[i], core_edge_w0, core_edge_b0); + nnet::dense_mult_1lyr(phi_input[i], edge_update[i], core_edge_w0, core_edge_b0); } else if(CONFIG_T::n_layers == 2){ - nnet::dense_mult_2lyr(phi_input, edge_update[i], core_edge_w0, core_edge_b0, core_edge_w1, core_edge_b1); + nnet::dense_mult_2lyr(phi_input[i], edge_update[i], core_edge_w0, core_edge_b0, core_edge_w1, core_edge_b1); } else if(CONFIG_T::n_layers == 3){ - nnet::dense_mult_3lyr(phi_input, edge_update[i], core_edge_w0, core_edge_b0, core_edge_w1, core_edge_b1, core_edge_w2, core_edge_b2); + nnet::dense_mult_3lyr(phi_input[i], edge_update[i], core_edge_w0, core_edge_b0, core_edge_w1, core_edge_b1, core_edge_w2, core_edge_b2); } else if(CONFIG_T::n_layers == 4){ - nnet::dense_mult_4lyr(phi_input, edge_update[i], core_edge_w0, core_edge_b0, core_edge_w1, core_edge_b1, core_edge_w2, core_edge_b2, core_edge_w3, core_edge_b3); + nnet::dense_mult_4lyr(phi_input[i], edge_update[i], core_edge_w0, core_edge_b0, core_edge_w1, core_edge_b1, core_edge_w2, core_edge_b2, core_edge_w3, core_edge_b3); } } From 9bf40b91ede673b06188fb85f402487811ac2156 Mon Sep 17 00:00:00 2001 From: abdelabd Date: Wed, 1 Sep 2021 17:40:45 -0400 Subject: [PATCH 09/13] dict no longer necessary (used in development) --- 
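Note: with the @pyg_handler decorator and the converter-package scan introduced earlier in this series, the hand-maintained IN_handlers mapping is redundant, which is why this commit drops it. As a rough sketch of what the decorator mechanism allows (the 'ResidualBlock' name, the module file, and the chosen layer_dict fields are hypothetical illustrations, not part of this patch set — a matching HLS layer class and templates would also be needed), a new block handler would simply be a decorated parser function placed in a module under hls4ml/converters/pyg/:

    # hypothetical module, e.g. hls4ml/converters/pyg/my_blocks.py, picked up by the scan
    from hls4ml.converters.pyg_to_hls import pyg_handler

    @pyg_handler('ResidualBlock')   # exposes function.handles = ['ResidualBlock']
    def parse_ResidualBlock(block_name, config, update_dict, index,
                            n_node, n_edge, node_dim, edge_dim):
        # mirror the layer_dict fields used by parse_NodeBlock above
        layer_dict = {"name": block_name,
                      "class_name": "ResidualBlock",
                      "n_node": n_node, "n_edge": n_edge,
                      "node_dim": node_dim, "edge_dim": edge_dim,
                      "inputs": [update_dict["last_node_update"]],
                      "outputs": [f"layer{index}_out"]}
        update_dict["last_node_update"] = f"layer{index}_out"
        return layer_dict, update_dict
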
hls4ml/converters/pyg/interaction_network_blocks.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/hls4ml/converters/pyg/interaction_network_blocks.py b/hls4ml/converters/pyg/interaction_network_blocks.py index c45d3af5b1..38f88375da 100644 --- a/hls4ml/converters/pyg/interaction_network_blocks.py +++ b/hls4ml/converters/pyg/interaction_network_blocks.py @@ -57,10 +57,4 @@ def parse_EdgeAggregate(block_name, config, update_dict, index, n_node, n_edge, "inputs": [update_dict["last_edge_update"], "edge_index"], "outputs": [f"layer{index}_out"]} update_dict["last_edge_aggr_update"] = f"layer{index}_out" - return layer_dict, update_dict - -IN_handlers = { - "NodeBlock": parse_NodeBlock, - "EdgeBlock": parse_EdgeBlock, - "EdgeAggregate": parse_EdgeAggregate -} \ No newline at end of file + return layer_dict, update_dict \ No newline at end of file From 18bb3b2cbd7b6bd5db9f2f888d21d5e855a232ac Mon Sep 17 00:00:00 2001 From: abdelabd Date: Wed, 1 Sep 2021 17:40:45 -0400 Subject: [PATCH 10/13] dict no longer necessary (used in development) --- hls4ml/converters/pyg/interaction_network_blocks.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/hls4ml/converters/pyg/interaction_network_blocks.py b/hls4ml/converters/pyg/interaction_network_blocks.py index c45d3af5b1..38f88375da 100644 --- a/hls4ml/converters/pyg/interaction_network_blocks.py +++ b/hls4ml/converters/pyg/interaction_network_blocks.py @@ -57,10 +57,4 @@ def parse_EdgeAggregate(block_name, config, update_dict, index, n_node, n_edge, "inputs": [update_dict["last_edge_update"], "edge_index"], "outputs": [f"layer{index}_out"]} update_dict["last_edge_aggr_update"] = f"layer{index}_out" - return layer_dict, update_dict - -IN_handlers = { - "NodeBlock": parse_NodeBlock, - "EdgeBlock": parse_EdgeBlock, - "EdgeAggregate": parse_EdgeAggregate -} \ No newline at end of file + return layer_dict, update_dict \ No newline at end of file From b8f723abe90f7a8b40b737d2731502fe1d77b2b1 Mon Sep 17 00:00:00 2001 From: abdelabd Date: Fri, 3 Sep 2021 15:16:01 -0400 Subject: [PATCH 11/13] no longer initialize edge_index, just use edge_index_1D directly --- .../templates/vivado/nnet_utils/nnet_graph.h | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_graph.h b/hls4ml/templates/vivado/nnet_utils/nnet_graph.h index daffcb94f7..3aa0a41231 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_graph.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_graph.h @@ -196,12 +196,7 @@ namespace nnet { #pragma HLS ARRAY_PARTITION variable=edge_attr complete dim=0 nnet::vec_to_mat(edge_attr_1D, edge_attr); - // 2. edge_index (input) - index_T edge_index[CONFIG_T::n_edge][2]; - #pragma HLS ARRAY_PARTITION variable=edge_index complete dim=0 - nnet::vec_to_mat(edge_index_1D, edge_index); - - //3. num_edge_per_node (intermediate), 4. edge_aggr_mask (intermediate) + //2. num_edge_per_node (intermediate), 3. edge_aggr_mask (intermediate) index_T num_edge_per_node[CONFIG_T::n_node]; #pragma HLS ARRAY_PARTITION variable=num_edge_per_node complete dim=0 ap_uint<1> edge_aggr_mask[CONFIG_T::n_node]; @@ -250,7 +245,7 @@ namespace nnet { #pragma HLS PIPELINE II=CONFIG_T::reuse_factor for(int i=0; i(edge_attr_1D, edge_attr); - // 3. edge_index (input) - index_T edge_index[CONFIG_T::n_edge][2]; - #pragma HLS ARRAY_PARTITION variable=edge_index complete dim=0 - nnet::vec_to_mat(edge_index_1D, edge_index); - - // 4. phi_input (intermediate) + // 3. 
phi_input (intermediate) int sender_col; int receiver_col; if(CONFIG_T::flow == source_to_target){ @@ -339,8 +329,8 @@ namespace nnet { data_T phi_input[CONFIG_T::n_edge][2*CONFIG_T::node_dim+CONFIG_T::n_edge]; #pragma HLS ARRAY_PARTITION variable=phi_input complete dim=0 for(int i=0; i for(int j=0; j Date: Sat, 4 Sep 2021 01:13:12 -0400 Subject: [PATCH 12/13] added docstring to convert_from_pyg_model() --- hls4ml/converters/__init__.py | 95 +++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 52fd63fe2f..cbdd133671 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -282,6 +282,101 @@ def convert_from_pyg_model(model, n_node, node_dim, n_edge, edge_dim, output_dir='my-hls-test', project_name='myproject', fpga_part='xcku115-flvb2104-2-i', clock_period=5, io_type='io_parallel', hls_config={}): check_forward_dict(model, forward_dictionary) + """ + + Convert a Pytorch.Geometric model to an hls model. + + Parameters + ---------- + model : Pytorch.geometric model object. + Model to be converted to hls model object. + n_node, n_edge: int, int + These parameters define the size of the graphs that your hls GNN + accepts as input. Inputs must be truncated or zero-padded to this + size before feeding them to your model. This is necessary because + each layer of the hls/hardware implementation has a fixed size + and cannot be resized. + node_dim, edge_dim: int, int + node_dim defines the length of the vector used to represent each + node in the graph-input. For example, if each node is represented + as a 1x3 vector, node_dim=3. + Likewise, edge_dim defines the length of the vector used to + represent each edge in the graph-input. + + forward_dictionary: OrderedDict object of the form {string: string} + Use this dictionary to define the order in which your model's + forward() method calls on the model's submodules. The keys + of the dictionary should be the names of your model's submodules, and the + value stored in each key should indicate whether that submodule is an + 'EdgeBlock' (i.e. it predicts messages/edge-updates) or whether its a + 'NodeBlock' (i.e. it predicts node-updates). + + For example, consider this InteractionNetwork (https://github.com/GageDeZoort/interaction_network_paper/blob/pytorch_geometric/models/interaction_network.py), + whose forward() method calls on its submodules in the following order: + 1. An EdgeBlock named 'R1' + 2. A NodeBlock named 'O' + 3. An EdgeBlock named 'R2' + + One would define its forward dictionary as such: + >>> forward_dictionary = OrderedDict() + >>> forward_dictionary['R1'] = 'EdgeBlock' + >>> forward_dictionary['O'] = 'NodeBlock' + >>> forward_dictionary['R2'] = 'EdgeBlock' + + It is really important to define the submodules in the same order with which the + forward() method calls on them. hls4ml has no other way of inferring this order. + + activate_final: string, optional + If the activation of the final output is not already a layer in the corresponding + submodule, name the type of the activation function here. In the preceding example, + one would pass the value 'sigmoid', because the final output of the model + is the sigmoid-activated output of 'R2' (the last submodule called by the + forward() method). In other words, the model returns torch.sigmoid(self.R2(m2)). 
+ Other accepted values for this parameter include: + ['linear', 'relu', 'elu', 'selu', 'prelu', 'leaky_relu', 'softmax', 'tanh', 'softplus', + 'softsign', 'hard_sigmoid','thresholded_relu', 'binary_tanh', 'ternary_tanh'] + output_dir : string, optional + Output directory to write hls codes. + project_name : string, optional + hls project name. + fpga_part : string, optional + The particular FPGA part number that you are considering. + clock_period : int, optional + The clock period, in ns, at which your algorithm runs. + io_type : string, optional + Your options are 'io_parallel' or 'io_serial' where this really + defines if you are pipelining your algorithm or not. + hls_config : dict, optional + Additional configuration dictionary for hls model. + + Returns + ------- + hls_model : hls4ml model object. + + See Also + -------- + hls4ml.convert_from_pytorch_model, hls4ml.convert_from_keras_model, + hls4ml.convert_from_onnx_model + + Example + -------- + >>> import hls4ml + >>> config = hls4ml.utils.config_from_pyg_model(model, granularity='model') + >>> + >>> forward_dictionary = OrderedDict() + >>> forward_dictionary['R1'] = 'EdgeBlock' + >>> forward_dictionary['O'] = 'NodeBlock' + >>> forward_dictionary['R2'] = 'EdgeBlock' + >>> n_node, node_dim = 112, 3 + >>> n_edge, edge_dim = 148, 4 + >>> hls_model = hls4ml.converters.convert_from_pyg_model(model, n_node, node_dim, + n_edge, edge_dim, + forward_dictionary, + activate_final='sigmoid' + hls_config=config) + + """ + config = create_vivado_config( output_dir=output_dir, project_name=project_name, From 9850455cc36e0f4c1248d3baf1029c234a2ec2a8 Mon Sep 17 00:00:00 2001 From: abdelabd Date: Sat, 4 Sep 2021 01:27:38 -0400 Subject: [PATCH 13/13] aesthetics --- hls4ml/converters/__init__.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index cbdd133671..a1e40ce1aa 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -277,8 +277,8 @@ def check_forward_dict(model, forward_dictionary): except AttributeError: raise AttributeError(f'Model is missing module "{key}" that is present in the provided forward dictionary; Check compatability') -def convert_from_pyg_model(model, n_node, node_dim, n_edge, edge_dim, - forward_dictionary=None, activate_final=None, +def convert_from_pyg_model(model, forward_dictionary, n_node, node_dim, + n_edge, edge_dim, activate_final=None, output_dir='my-hls-test', project_name='myproject', fpga_part='xcku115-flvb2104-2-i', clock_period=5, io_type='io_parallel', hls_config={}): check_forward_dict(model, forward_dictionary) @@ -367,11 +367,9 @@ def convert_from_pyg_model(model, n_node, node_dim, n_edge, edge_dim, >>> forward_dictionary['R1'] = 'EdgeBlock' >>> forward_dictionary['O'] = 'NodeBlock' >>> forward_dictionary['R2'] = 'EdgeBlock' - >>> n_node, node_dim = 112, 3 - >>> n_edge, edge_dim = 148, 4 - >>> hls_model = hls4ml.converters.convert_from_pyg_model(model, n_node, node_dim, - n_edge, edge_dim, - forward_dictionary, + >>> graph_dimensions = {"n_node": 112, "node_dim": 3, "n_edge": 148, "edge_dim": 4} + >>> hls_model = hls4ml.converters.convert_from_pyg_model(model, forward_dictionary, + **graph_dimensions, activate_final='sigmoid' hls_config=config)
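
Usage note: putting the pieces of this series together, a minimal end-to-end sketch of the conversion flow looks like the following. The InteractionNetwork model object, the chosen graph dimensions, and the compile/predict calls at the end are illustrative assumptions rather than part of this patch set; config_from_pyg_model and convert_from_pyg_model are the interfaces added here.

    from collections import OrderedDict
    import hls4ml

    # a trained pytorch_geometric model with submodules R1, O, R2 (placeholder object)
    model = InteractionNetwork()

    # order in which model.forward() calls its submodules
    forward_dictionary = OrderedDict()
    forward_dictionary['R1'] = 'EdgeBlock'
    forward_dictionary['O'] = 'NodeBlock'
    forward_dictionary['R2'] = 'EdgeBlock'

    # model-level precision / reuse configuration (defaults shown explicitly)
    config = hls4ml.utils.config_from_pyg_model(model, granularity='model',
                                                default_precision='ap_fixed<16,6>',
                                                default_index_precision='ap_uint<16>',
                                                default_reuse_factor=1)

    # fixed graph size the firmware is built for; inputs must be padded or truncated to it
    graph_dimensions = {"n_node": 112, "node_dim": 3, "n_edge": 148, "edge_dim": 4}
    hls_model = hls4ml.converters.convert_from_pyg_model(model, forward_dictionary,
                                                         **graph_dimensions,
                                                         activate_final='sigmoid',
                                                         output_dir='my-hls-test',
                                                         hls_config=config)

    # assuming the usual HLSModel interface carries over to the graph blocks, the
    # generated project can then be compiled and checked against the torch model
    hls_model.compile()
    # y_hls = hls_model.predict([node_attr_padded, edge_attr_padded, edge_index_padded])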