PyTorch.Geometric to HLS4ML #379

Closed
wants to merge 25 commits into from

@abdelabd abdelabd commented Aug 17, 2021

Purpose

Add a pyg_to_hls converter to parse an Interaction Network PyTorch Geometric model, which is composed of alternating EdgeBlocks (which apply a neural network to "edge features" to compute messages), Aggregate layers (which aggregate, for each node, the features of the edges connected to it), and NodeBlocks (which apply a neural network to the aggregated edge features).

PyTorch Geometric models expect input data consisting of node attributes node_attr of size [n_nodes, n_node_features], edge attributes edge_attr of size [n_edges, n_edge_features], and an edge index edge_index of size [2, n_edges], which specifies the data's connectivity.

The currently supported network for charged particle tracking outputs edge weights of size [n_edges, 1], which classify the edges as valid (part of a good track segment) or not.
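
As a rough illustration of the data layout and the Aggregate step described above, the sketch below uses plain Python lists in place of tensors. The toy 3-node graph, the feature values, and the helper name aggregate_add are made up for illustration and are not taken from this PR. It assumes the PyTorch Geometric convention that, with flow source_to_target, row 0 of edge_index holds the source node of each edge and row 1 holds the target node.

```python
# Toy graph: 3 nodes, 4 edges (all values are illustrative, not from the PR).
node_attr = [[0.1, 0.2, 0.3],   # [n_nodes, n_node_features] = [3, 3]
             [0.4, 0.5, 0.6],
             [0.7, 0.8, 0.9]]
edge_attr = [[1.0, 0.0],        # [n_edges, n_edge_features] = [4, 2]
             [2.0, 1.0],
             [3.0, 2.0],
             [4.0, 3.0]]
edge_index = [[0, 1, 2, 0],     # row 0: source node of each edge
              [1, 2, 1, 2]]     # row 1: target node of each edge


def aggregate_add(edge_attr, edge_index, n_nodes):
    """Sum the features of all edges arriving at each target node."""
    n_features = len(edge_attr[0])
    out = [[0.0] * n_features for _ in range(n_nodes)]
    for e, target in enumerate(edge_index[1]):
        for f in range(n_features):
            out[target][f] += edge_attr[e][f]
    return out


aggregated = aggregate_add(edge_attr, edge_index, n_nodes=len(node_attr))
print(aggregated)
```

Node 0 has no incoming edges in this toy graph, so its aggregated row stays at zero.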

Slides

https://docs.google.com/presentation/d/1C3EuVq4FaaayNSBeaRBHyHnWcujLbXASxAR2cqQyLVo/edit?usp=sharing

Test script

Notebook to test the conversion: https://github.com/abdelabd/manual_GNN_conversion/blob/main/pyg_to_hls_walkthrough.ipynb
Script to test the conversion: https://github.com/abdelabd/manual_GNN_conversion/blob/main/convert_model.py

python convert_model.py conversion_config.yaml --n-graphs=100 --aggregation add --flow source_to_target --precision 'ap_fixed<16,8>' --max-nodes=28 --max-edges=37 --n-neurons=8

or more fully featured code: https://github.com/abdelabd/manual_GNN_conversion

python test_model.py test_config.yaml --n-graphs 100 --aggregation add --flow source_to_target --precision 'ap_fixed<16,8>' --max-nodes 28 --max-edges 37 --n-neurons 8 --synth

Changes

Frontend

  • added onto hls4ml.converters.__init__.py:
    • added from hls4ml.converters.pyg_to_hls import pyg_to_hls
    • added convert_from_pyg_model() function
    • registered the pyg block_handlers
  • added hls4ml/converters/pyg/__init__.py
    • nothing here
  • added hls4ml/converters/pyg/interaction_network_blocks.py
    • has the pyg block_handlers (analogous to layer_handlers) for NodeBlock, EdgeBlock, Aggregate
  • added hls4ml.converters.pyg_to_hls.py
    • has the class PygModelReader(PyTorchModelReader)
    • has the function pyg_to_hls()
    • has the pyg 'block_handlers'
  • added onto hls4ml.utils.config.py
    • added config_from_pyg_model() function
  • changed hls4ml.writer.vivado_writer.py
    • changed VivadoWriter.write_defines() method
      • if certain macros are shared between different layers of an HLSModel, then these macros are not written multiple times in the defines.h file.
    • changed VivadoWriter.write_yml() method
      • If yaml.dump() has trouble writing a torch model with submodules composed of torch models themselves, then it writes the location of the model's state dictionary instead.
  • added onto hls4ml.model.hls_layers
    • added class GraphBlock(Layer): parent class for EdgeBlock and NodeBlock. Allows hls4ml to handle an EdgeBlock or NodeBlock as if it were a single layer (like Dense, Convolution, etc.), rather than a submodule composed of several layers.
    • added class EdgeBlock(GraphBlock)
    • added class NodeBlock(GraphBlock)
    • added class Aggregate(Layer)
  • hls4ml.model.hls_model
    • changed HLSModel.get_weights_data()
      • handling for getting weights data if the layer in question doesn't belong directly to the torch model in question, but rather to a submodule of the torch model such as an EdgeBlock or NodeBlock

Backend

  • added onto nnet_utils/nnet_array.h
    • added struct matrix_config which specifies parameters for conversion functions
    • added vec_to_mat() function to convert from 1D unrolled arrays to 2D matrices
    • added mat_to_vec() function to convert from 2D matrices to 1D unrolled arrays
  • changed nnet_utils/nnet_dense.h and changed nnet_utils/nnet_dense_resource.h
    • added parameter "static const bool remove_pipeline_pragma = false;" to the dense_config
    • under the ReuseLoop in the dense_resource_rf_leq_nin(), dense_resource_rf_gt_nin_rem0(), and dense_resource_rf_gt_nin() functions, "#pragma HLS PIPELINE II=1 rewind" is now applied only if remove_pipeline_pragma is false.
      • This change allows the pipelining to happen at the GraphBlock layer level instead of the individual dense layer level
  • added nnet_utils/nnet_graph.h
    • has the backend code for the EdgeBlock, NodeBlock, and Aggregate layers.
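
For readers unfamiliar with the vec_to_mat() and mat_to_vec() helpers listed above, here is a minimal Python sketch of the same round trip. It assumes row-major ordering (element (i, j) lives at index i * n_cols + j); the actual HLS implementation in nnet_array.h is templated C++ and may differ in detail.

```python
def vec_to_mat(vec, n_rows, n_cols):
    """Reshape a flat list into n_rows x n_cols, assuming row-major order."""
    if len(vec) != n_rows * n_cols:
        raise ValueError("vector length does not match matrix shape")
    return [[vec[i * n_cols + j] for j in range(n_cols)] for i in range(n_rows)]


def mat_to_vec(mat):
    """Flatten a 2D list back into a 1D list in row-major order."""
    return [value for row in mat for value in row]


edge_attr_1d = [0, 1, 2, 3, 4, 5]          # 3 edges x 2 features, unrolled
edge_attr = vec_to_mat(edge_attr_1d, n_rows=3, n_cols=2)
round_trip = mat_to_vec(edge_attr)
```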

@abdelabd abdelabd changed the title Pyg to hls rebase PyTorch.Geometric to HLS4ML Aug 17, 2021
@jmduarte jmduarte requested review from yiiyama and Duchstf August 17, 2021 21:52
@jmduarte jmduarte mentioned this pull request Aug 20, 2021
Contributor

@yiiyama yiiyama left a comment

Talking about nnet_graph.h because that is the main contribution here:
(If I understood the intention correctly) I think the idea of partitioning the node and edge arrays completely to enable parallel random access read is good. I do wonder about scalability though; how large of a graph does the implementation support?
Also, I think there are arrays that don't need to be partitioned and copied, unless you really do intend to pipeline the entire functions of aggregate, edgeblock, and nodeblock. If the pipelining was only for the main loop (over edges or nodes), then any array that is accessed serially doesn't need partitioning over that dimension (you'd still need to break up the feature dimension though).

def convert_from_pyg_model(model, n_node, node_dim, n_edge, edge_dim,
                           forward_dictionary=None, activate_final=None,
                           output_dir='my-hls-test', project_name='myproject',
                           fpga_part='xcku115-flvb2104-2-i', clock_period=5, io_type='io_parallel', hls_config={}):
Contributor

Would be nice to add a docstring. Especially, non-specialists won't know what forward_dictionary is supposed to be.

Contributor

Also, maybe check_forward_dict doesn't need to be a function because it's anyway called once?

for key in forward_dictionary:
    if not hasattr(model, key):
        raise AttributeError(f'Model is missing module "{key}" that is present in the provided forward dictionary; ...

Contributor

Actually, I guess the issue forward_dictionary tries to solve is general - in pytorch_to_hls the "layers" are assumed to be in the order appearing in model.named_modules(). @vloncar @Duchstf Should the general pytorch_to_hls also consider accepting an OrderedDict ({module name: hls4ml layer class name}) that specifies the order that a pytorch model submodules should be converted into hls4ml layers?
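
To make the forward_dictionary idea concrete, here is a small sketch of an ordered {module name: hls4ml layer class name} mapping together with the attribute check quoted earlier in this thread. ToyModel and its attribute names (edge_net_1, node_net, edge_net_2) are hypothetical stand-ins for a real PyTorch model, and the error text is abbreviated; only the EdgeBlock and NodeBlock class names come from this PR.

```python
from collections import OrderedDict


class ToyModel:
    """Stand-in for a PyTorch model; attribute names are hypothetical."""
    def __init__(self):
        self.edge_net_1 = object()
        self.node_net = object()
        self.edge_net_2 = object()


def check_forward_dict(model, forward_dictionary):
    """Raise if the dictionary names a module the model does not have."""
    for key in forward_dictionary:
        if not hasattr(model, key):
            raise AttributeError(
                f'Model is missing module "{key}" that is present in the '
                'provided forward dictionary')


# Order of entries = order in which the submodules would be converted.
forward_dictionary = OrderedDict([
    ("edge_net_1", "EdgeBlock"),
    ("node_net", "NodeBlock"),
    ("edge_net_2", "EdgeBlock"),
])

check_forward_dict(ToyModel(), forward_dictionary)
conversion_order = list(forward_dictionary.items())
```

The explicit ordering is what a walk over model.named_modules() cannot provide when the forward() method applies submodules in a different order than they were registered.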

Author

Yeah we couldn't figure out a way around using the forward_dict, short of parsing the pytorch model's forward() method... Open to any suggestions though.

Member

Yea this is a known limitation for the pytorch converter so far. I was thinking of solving this by using TorchScript and utilizing the model's graph the same way we do with the ONNX converter, but never got around to working on it.

GNNs so far might not be jittable though, so I guess what @abdelabd is doing makes sense.

Author

Added docstring: abdelabd@d2831e2

Member

@jmduarte jmduarte Oct 4, 2021

Actually, I recently created a jittable version of this model: GageDeZoort/interaction_network_paper#6

Assuming the model state dictionary for the current model is saved in model.pt, you can load it and re-save it like this:

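import torch
# InteractionNetwork is the model class from the repository linked above

device = 'cpu'
model_state_dict = 'model.pt'
model = InteractionNetwork(hidden_size=8)
model = model.to(device)
model.load_state_dict(torch.load(model_state_dict, map_location=device))
# torch.jit.save expects a ScriptModule, so script the model first
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, model_state_dict.replace('.pt', 'cpu_jit.pt'))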

@Duchstf How would you then analyze the model's graph to get the order of the modules?

@abdelabd maybe you can give it a shot and see if the order is derivable?

Author

Okay, will do!

receiver_col = 0;
}

#pragma HLS PIPELINE II=CONFIG_T::reuse_factor
Contributor

Don't you want this pragma to be inside for(int i=0; i<CONFIG_T::n_edge; i++){? Maybe I'm misunderstanding your intention / how pipeline pragmas work..

EDIT: Ah you indeed wanted to pipeline the entire function - then I suggest moving the pragma to the top of the function, just for readability.

RE-EDIT: Well actually no, unrolling the n_edge loop doesn't sound feasible. I don't think you should be pipelining the function.

Author

So I tested a few different implementations of this using the basic configuration (--max-nodes=28 --max-edges=51 --precision=ap_fixed<16,8> --reuse=1):

  1. Removing the function pipeline pragma
  2. Removing all the unroll pragmas
  3. Removing the function pipeline pragma and all the unroll pragmas
  4. Pipelining the main edge-loop, removing the unroll pragmas from the main edge-loop, and removing the function pipeline pragma

These all yielded syntheses identical to the baseline (what we currently have)... which one should we go with?

Contributor

What I meant to say was that (I think) having the PIPELINE pragma here will cause the nodes loop to be fully unrolled, regardless of the reuse_factor. But you'd typically have a few tens to hundreds of nodes for a graph network problem to be useful, so unrolling the nodes loop will cause an explosion of resource usage.
But then maybe I'm missing some crucial point here. Should we have a quick chat so that I can give more useful comments?


#pragma HLS PIPELINE II=CONFIG_T::reuse_factor
for(int i=0; i<CONFIG_T::n_edge; i++){
#pragma HLS UNROLL
Contributor

All loops within a pipelined block are automatically unrolled, so you don't need this and the two other pragmas below.

// 1. edge_attr (input)
data_T edge_attr[CONFIG_T::n_edge][CONFIG_T::edge_dim];
#pragma HLS ARRAY_PARTITION variable=edge_attr complete dim=0
nnet::vec_to_mat<data_T, data_T, typename CONFIG_T::edge_attr_config>(edge_attr_1D, edge_attr);
Contributor

What's the reason for vec_to_mat and mat_to_vec operations? What speedup can we get by doing this instead of referencing 1D arrays through simple index calculations (e.g. index = i * CONFIG_T::edge_dim + j)? Booking new arrays and just copying values into them doesn't seem resource- and time-efficient to me..

EDIT: I understand now - you want to use ARRAY_PARTITION complete to work around the array access limitation. But that's going to use an enormous amount of registers..

Author

We did implement and compare a 1D version of nnet_graph, and saw that latency/usage was about equivalent. The 1D implementation required some extra indexing logic and multiplications, but the biggest resource drain was the necessary intermediate/clone products for routing the inputs and outputs of the dense-networks. Granted, the testing/comparison did not extend to very large graphs (just 28 nodes by 37 edges).
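
To put rough numbers on the scalability concern in this thread, the back-of-the-envelope sketch below counts the storage bits needed if every element of the node and edge input arrays becomes its own register under ARRAY_PARTITION complete. The feature dimensions (NODE_DIM, EDGE_DIM) and the 10x larger graph are hypothetical; only the 28-node, 37-edge size and the 16-bit ap_fixed<16,8> width come from this PR. It counts one copy of each input array only, not intermediate products or routing logic.

```python
def partitioned_bits(n_rows, n_cols, bit_width):
    """Bits of storage if each element of an n_rows x n_cols array is a register."""
    return n_rows * n_cols * bit_width


BIT_WIDTH = 16          # total width of ap_fixed<16,8>
NODE_DIM = 3            # hypothetical feature sizes
EDGE_DIM = 4

small = (partitioned_bits(28, NODE_DIM, BIT_WIDTH)
         + partitioned_bits(37, EDGE_DIM, BIT_WIDTH))
large = (partitioned_bits(280, NODE_DIM, BIT_WIDTH)
         + partitioned_bits(370, EDGE_DIM, BIT_WIDTH))
print(small, large)
```

Under these assumptions the register count grows linearly with graph size, and every intermediate copy of the arrays adds to it.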

yaml.dump(model.config.config, file)
try:
    yaml.dump(model.config.config, file)
except ValueError:
Contributor

When I tried your conversion script, I got a TypeError here instead.

Author

Interesting... would you mind posting the traceback?

Contributor

root@dd3c9b3b218f:/workspace/manual_GNN_conversion# python test_model.py test_config.yaml --n-graphs 100 --aggregation add --flow source_to_target --precision 'ap_fixed<16,8>' --max-nodes 28 --max-edges 37 --n-neurons 8
/opt/conda/lib/python3.8/site-packages/hls4ml/converters/__init__.py:31: UserWarning: WARNING: Tensorflow converter is not enabled!
  warnings.warn("WARNING: Tensorflow converter is not enabled!")
n_graphs: 2
writing test bench data for 1st graph
Writing HLS project
Traceback (most recent call last):
  File "test_model.py", line 239, in <module>
    main()
  File "test_model.py", line 144, in main
    torch_model, hls_model, torch_wrapper = load_models(config['trained_model_dir'], graph_dims, aggr=a, flow=f, n_neurons=nn, precision=args.precision, output_dir=args.output_dir, reuse=args.reuse)
  File "test_model.py", line 115, in load_models
    hls_model.compile()
  File "/opt/conda/lib/python3.8/site-packages/hls4ml/model/hls_model.py", line 521, in compile
    self.write()
  File "/opt/conda/lib/python3.8/site-packages/hls4ml/model/hls_model.py", line 518, in write
    self.config.writer.write_hls(self)
  File "/opt/conda/lib/python3.8/site-packages/hls4ml/writer/vivado_writer.py", line 692, in write_hls
    self.write_yml(model)
  File "/opt/conda/lib/python3.8/site-packages/hls4ml/writer/vivado_writer.py", line 664, in write_yml
    yaml.dump(model.config.config, file)
  File "/opt/conda/lib/python3.8/site-packages/yaml/__init__.py", line 290, in dump
    return dump_all([data], stream, Dumper=Dumper, **kwds)
  File "/opt/conda/lib/python3.8/site-packages/yaml/__init__.py", line 278, in dump_all
    dumper.represent(data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 27, in represent
    node = self.represent_data(data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 342, in represent_object
    return self.represent_mapping(
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 342, in represent_object
    return self.represent_mapping(
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 364, in represent_ordered_dict
    return self.represent_sequence(tag, [items])
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 199, in represent_list
    return self.represent_sequence('tag:yaml.org,2002:seq', data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 199, in represent_list
    return self.represent_sequence('tag:yaml.org,2002:seq', data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 356, in represent_object
    return self.represent_mapping(tag+function_name, value)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/opt/conda/lib/python3.8/site-packages/yaml/representer.py", line 317, in represent_object
    reduce = data.__reduce_ex__(2)
TypeError: __reduce_ex__() missing 1 required positional argument: 'proto'

Member

@yiiyama what version of PyTorch are you using? I get the same error as you with PyTorch 1.9.0, but no error with PyTorch 1.7.1 like in this environment: https://github.com/jmduarte/manual_GNN_conversion/blob/main/pyg_to_hls_env.yml

In any case, @abdelabd, we should solve this in a more elegant way that works for both of these PyTorch versions. Perhaps like the keras model representer approach above it.

Contributor

I'm using 1.9.0 too. Indeed I was going to suggest the same thing. Not that I know how to write this representer thing, but I'm sure there is a way to do it!

Author

Had trouble reproducing the error but I think this should fix it? abdelabd@7df2014

@jmduarte jmduarte requested review from yiiyama and Duchstf September 18, 2021 19:13
@jmduarte
Member

@yiiyama @Duchstf can you re-review? thanks!

@jmduarte jmduarte self-requested a review September 21, 2021 01:38
Contributor

@thesps thesps left a comment

I left some specific comments. General comment: we need to take some care that those "deep" changes that touch everything (e.g. in Layer, dense and elsewhere) are really needed and general.

And then I wonder about how flexible the GNN support is? For example I saw the places where there are some explicit choices for the number of layers in the blocks. Can that be generalised?

Could we get a small pytest? We can install torch_geometric on the CI image.

Comment on lines +410 to +423
    // send it through NN
    if(CONFIG_T::n_layers == 1){
        nnet::dense_mult_1lyr<data_T, res_T, CONFIG_T>(phi_input, node_update[i], core_node_w0, core_node_b0);
    }
    else if(CONFIG_T::n_layers == 2){
        nnet::dense_mult_2lyr<data_T, res_T, CONFIG_T>(phi_input, node_update[i], core_node_w0, core_node_b0, core_node_w1, core_node_b1);
    }
    else if(CONFIG_T::n_layers == 3){
        nnet::dense_mult_3lyr<data_T, res_T, CONFIG_T>(phi_input, node_update[i], core_node_w0, core_node_b0, core_node_w1, core_node_b1, core_node_w2, core_node_b2);
    }
    else { // CONFIG_T::n_layers == 4
        nnet::dense_mult_4lyr<data_T, res_T, CONFIG_T>(phi_input, node_update[i], core_node_w0, core_node_b0, core_node_w1, core_node_b1, core_node_w2, core_node_b2, core_node_w3, core_node_b3);
    }
}
Contributor

Can this be generalised? Can we construct the block architecture dynamically?

Member

@thesps do you mean something similar to @vloncar's code generation for CNNs: master...vloncar:cnn_parallel

Contributor

Can this be made as nnet::dense_mult<data_T, res_T, CONFIG_T, N_LAYERS>(phi_input, node_update, core_node_w[N_LAYERS], core_node_b[N_LAYERS])? Then you can make template specializations for nnet::dense_mult<data_T, res_T, CONFIG_T, 1>, nnet::dense_mult<data_T, res_T, CONFIG_T, 2> etc and create a general one for N_LAYERS > 4. (If you want to be really fancy, you can even force the compiler to emit a warning if more than 4 layers are used.) In these specializations, you'll be using hardcoded N_LAYER index in weights and bias arrays, so Vivado HLS may still be able to create optimal designs. The other approach that comes to mind is variadic functions, but when I played with them in 2017.2 the compiler produced awful designs.

Contributor

@thesps thesps Feb 7, 2022

I was thinking generally that these EdgeBlock and NodeBlock are NNs that can have any architecture. But this code hardcodes a few special cases. Those special cases might be the most common ones, but can't we generalise? e.g. is there any issue with using 5 layers, or an activation other than ReLU?

I'm not sure what's the best way forward, but it may make more sense if the EdgeBlock and NodeBlock are separate HLSModels (now ModelGraphs), with the full flexibility over NN architecture and configuration that that implies. In this PR those blocks are 'Layers' but they are actually full NN models - of the types we already support.

Edit: Can we have (or develop) a ModelGraph with a 'Layer' that is another ModelGraph?

Author

@vloncar I was also thinking this. The only thing is that the core_node_w and core_node_b inputs would have to be indexed in a complicated way, because each of the "layers" in the generalized MLP may not be of the same size. It should be possible to pass some parameter detailing each layer-size, and then cut up the overall input into the separate layers.

@thesps The NodeBlock does indeed just pass the input through the relevant NN, but there is a bit more logic to the EdgeBlock than just the NN (i.e. constructing the NN-input). It may be possible to construct the NN for each "Block" as its own external function, and then call it inside of the block. The main challenge is to make each block call its own, unique NN-function. I think we can pass each NN-function as a parameter to its respective block, but I'm not sure if Vivado would be able to synthesize this.

One way around this is to basically split each Block into two functions - one for the Block-specific logic, and one for the NN logic - and then call those two functions in succession. Either way, I think this would require adding a new method: VivadoWriter.write_NNs()

I imagine that all of these options will affect the usage/latency numbers, so I'm currently trying to test them all out and see which is the fastest/cheapest.
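
The layer-size idea from the reply to @vloncar above can be sketched in Python as follows. This is only an illustration of the indexing, not HLS code: the function names, the row-major (n_in x n_out) weight layout, and the choice to apply ReLU between layers but not after the last one are all assumptions made for the example.

```python
def split_by_layer(flat, layer_shapes):
    """Cut one flat weight list into per-layer chunks.

    layer_shapes is a list of (n_in, n_out) pairs; each layer's weights are
    assumed to be stored row-major as n_in x n_out.
    """
    chunks, offset = [], 0
    for n_in, n_out in layer_shapes:
        size = n_in * n_out
        chunks.append(flat[offset:offset + size])
        offset += size
    if offset != len(flat):
        raise ValueError("layer shapes do not account for every parameter")
    return chunks


def dense(x, weights, biases, n_in, n_out):
    """y[j] = sum_i x[i] * w[i][j] + b[j], with w stored row-major."""
    return [sum(x[i] * weights[i * n_out + j] for i in range(n_in)) + biases[j]
            for j in range(n_out)]


def mlp(x, flat_w, flat_b, layer_shapes):
    """Apply the dense layers in order, with ReLU between them."""
    w_chunks = split_by_layer(flat_w, layer_shapes)
    b_offset = 0
    for k, (n_in, n_out) in enumerate(layer_shapes):
        b = flat_b[b_offset:b_offset + n_out]
        b_offset += n_out
        x = dense(x, w_chunks[k], b, n_in, n_out)
        if k < len(layer_shapes) - 1:      # no activation after the last layer
            x = [max(0.0, v) for v in x]
    return x


layer_shapes = [(2, 3), (3, 1)]             # hypothetical layer sizes
flat_w = [1, 0, -1, 0, 1, 1, 1, 1, 1]       # 2*3 + 3*1 values
flat_b = [0, 0, 0, 0.5]
y = mlp([1.0, 2.0], flat_w, flat_b, layer_shapes)
```

Because the layer shapes are the only place where the depth appears, the same loop handles any number of layers, which is the generality the reviewers asked about; whether an equivalent templated loop synthesizes well in Vivado HLS is a separate question.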

Comment on lines 471 to 482
if module_name is not None:
    return self.reader.get_weights_data(layer_name, var_name, module_name)
else:
    return self.reader.get_weights_data(layer_name, var_name)
Contributor

Could this be handled at the level of the Reader instead?

Author

Changes here: abdelabd@ca49221

I made the changes to the PyTorchModelReader.get_weights_data() because this seemed like a light change, and the PygModelReader inherits this method. If you want, I can just implement this as a method unique to PygModelReader instead.
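
For context, a reader-level lookup with an optional module_name might look roughly like the sketch below. It is not the code from the linked commit: the function body, the example keys, and the use of plain lists in place of tensors are assumptions. It relies only on the fact that PyTorch state-dict keys join nested module names with dots.

```python
def get_weights_data(state_dict, layer_name, var_name, module_name=None):
    """Look up a parameter, optionally inside a named submodule.

    PyTorch state-dict keys join nested module names with dots, e.g.
    'edge_block.layers.0.weight'.
    """
    parts = [layer_name, var_name]
    if module_name is not None:
        parts.insert(0, module_name)
    return state_dict.get(".".join(parts))


# Hypothetical state dict; lists stand in for tensors.
state_dict = {
    "fc.weight": [[1.0, 2.0]],
    "edge_block.layers.0.weight": [[3.0, 4.0]],
}

top_level = get_weights_data(state_dict, "fc", "weight")
nested = get_weights_data(state_dict, "layers.0", "weight", module_name="edge_block")
missing = get_weights_data(state_dict, "layers.0", "bias", module_name="edge_block")
```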

Comment on lines +2035 to +2065
if linear_count < 4:
    for i in range(linear_count + 1, 5):
        linear_config_i = linear_config.split('\n')
        linear_config_i[0] = re.sub(f"dense_config{linear_count}", f"dense_config{i}", linear_config_i[0])
        linear_config_i = "\n".join(linear_config_i)
        configs[f"dense_config{i}"] = linear_config_i

if relu_count < 4:
    for i in range(relu_count + 1, 5):
        relu_config_i = relu_config.split('\n')
        relu_config_i[0] = re.sub(f"relu_config{relu_count}", f"relu_config{i}", relu_config_i[0])
        relu_config_i = '\n'.join(relu_config_i)
        configs[f"relu_config{i}"] = relu_config_i
Contributor

I made a similar comment on the nnet_graph.h, but can the architecture be generalised?

Comment on lines 563 to 570
try:
    dims = self.attributes['dim_names']
except KeyError:
    dims = ['N_INPUT_{}_{}'.format(i, self.index) for i in range(1, len(shape) + 1)]
Contributor

Can this be handled at a higher level? i.e. not having to access the shape differently for pyg vs the rest?

Author

@abdelabd abdelabd Feb 6, 2022

Turns out the special pyg-handling wasn't necessary in the first place: abdelabd@5c36aad


Contributor

Can this be cleaned up?

Comment on lines 666 to 669
import torch
model_path = model.config.get_output_dir() + "/torch_model_state_dict.pt"
torch.save(model.config.config["PytorchModel"].state_dict(), model_path)
model.config.config["PytorchModel"] = model_path
yaml.dump(model.config.config, file)
Contributor

I don't think this is safe. If the previous try failed, it doesn't mean the model is a torch model, in the general case.

Author

Now it will skip writing the hls4ml_config.yml file at all, and raise a UserWarning instead: abdelabd@7e7ebf1

As far as I can tell, the file isn't necessary for creating, compiling, or building the HLSModel

Contributor

The yaml file is very useful for reproducibility and it also allows you to split your conversion into several steps, that you may run at different times on different computers, so I'd say we shouldn't remove the feature. I fully agree that the current way of writing this information in yaml is not the best (the Keras one also has problems), but let's improve on that instead of dropping the feature.

Author

Okay, another easy fix would be to save the YAML file, but to first delete the "PytorchModel" argument from the model.config.config. I'm trying to avoid messing with the yaml.dump method, but if this fix still seems too messy then I guess that can't be avoided.

Author

Like this: abdelabd@bc93393
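
The "drop the model entry before dumping" approach described above can be sketched as below. This is not the code from the linked commit: config_without is a hypothetical helper, the config contents are illustrative, and object() stands in for a torch model instance. Building a filtered copy, rather than deleting the key in place, keeps the in-memory config usable after the file is written.

```python
def config_without(config, skip_keys=("PytorchModel",)):
    """Return a shallow copy of config that omits the given top-level keys."""
    return {key: value for key, value in config.items() if key not in skip_keys}


# Hypothetical config; object() stands in for a torch model instance.
config = {
    "OutputDir": "my-hls-test",
    "ProjectName": "myproject",
    "PytorchModel": object(),
}

dumpable = config_without(config)
# dumpable can now be handed to yaml.dump(); config itself is left untouched.
```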

Comment on lines +300 to +302
numbers = set('\n'.join(numbers).split('\n')) #we want a unique set of macro declarations, since some of the macros are shared between different NN-blocks
newline += '\n'.join([i for i in numbers if i!=''])
newline += '\n' #for formatting purposes
Contributor

What's this for?

Author

Some macros, such as N_EDGE or N_NODE, are shared between different hls4ml layers. This basically keeps those macros from being defined several times in the cpp header.
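
As a small illustration of the deduplication described above, the sketch below merges per-layer macro blocks and keeps each definition once. It is a variant of the quoted snippet rather than a copy: it uses dict.fromkeys instead of set() so the macros keep their first-seen order, and the helper name and macro values are hypothetical.

```python
def unique_macro_lines(blocks):
    """Merge per-layer '#define' blocks, dropping blanks and repeats.

    dict.fromkeys keeps the first occurrence of each line in its original
    position, unlike set(), whose iteration order is arbitrary.
    """
    lines = "\n".join(blocks).split("\n")
    return list(dict.fromkeys(line for line in lines if line != ""))


# Hypothetical macro blocks emitted by two graph layers.
blocks = [
    "#define N_NODE 28\n#define N_EDGE 37",
    "#define N_EDGE 37\n#define EDGE_DIM 4\n",
]
defines = "\n".join(unique_macro_lines(blocks)) + "\n"
```

A stable order keeps the generated defines.h identical from run to run, which makes diffs between conversions easier to read.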

abdelabd and others added 16 commits December 26, 2021 21:02
Add files via upload

Update hls_layers.py

Update hls_layers.py

Update hls_model.py

Update hls_model.py

Update vivado_writer.py

Add files via upload

Update vivado_template.py

Update vivado_writer.py

added vec_to_mat, mat_to_vec

Delete nnet_graph.h

Add files via upload

handling for 1D-vector inputs and outputs

Update vivado_writer.py

Update pyg_to_hls.py

Update hls_layers.py

Update vivado_template.py

Update pyg_to_hls.py

Update hls_layers.py

Update nnet_activation.h

Update nnet_dense_resource.h

wip

Add files via upload

Add files via upload

added testbench handling for concatenation

added weights tranpose, deleted hard-coded testbench

add nnet::matrix_config

cleaning up

add HLSModel_GNN._get_top_function(), HLSModel_GNN.predict(..)

top-level conversion function (starting ground)

update naming conventions:'Rn' --> 'node_attr', 'Re' --> 'edge_attr'. update order of inputs: 1.node_attr, 2.edge_attr, 3.edge_index

update naming conventions, edge_index shape (now, edge_index.shape=[N_EDGE, 2])

add max, mean aggregation. update naming conventions

generalize all the aggregation methods in a single function

cleaning up

added precision and testbenching

added precision handle

added handling for different flow direction

cleaning up from 'save intermediates'

slight improvement on max aggregation

fixed max-aggregation

minor updates

generalizing pyg_to_hls()

tidying up

improved handling for user-input precision

re-included extra #pragma HLS UNROLL, don't know if this is correct yet

implemented LUT-division

missed a semicolon

added handling for initial aggregation layer

added Aggregate layer, self._check_inputs()

added Aggregate layer

use existing test bench code

update naming conventions

tidying up, added #pragma HLS UNROLL to nnet_array::vec_to_mat/mat_to_vec

ditched single-edge aggregation functions

re-added single-edge-aggregation functions

re-added #pragma HLS ARRAY_PARTITION

speciy model inputs; parition

merge

cleaning up, update naming conventions

ditched single-edge-aggregation functions, improved sender-index/receiver-index handling

changed 'else if{' to 'else{ if{'

split different aggregation methods into separate functions

max not fully functional yet, just committing changes before switching branch

got max-aggregation LOGIC working, still testing build_prj.tcl

fixed up max-aggregation again
@jmduarte
Member

@thesps I've updated the CI image (https://gitlab.cern.ch/fastmachinelearning/hls4ml-testing/-/merge_requests/2) and added a test (jmduarte@cca97c3), following the GarNet example, which @abdelabd can add to this PR soon.

@JanFSchulte
Contributor

As this development has become stale and will at some point be superseded by different solutions, we are closing this PR.

7 participants