-
Notifications
You must be signed in to change notification settings - Fork 462
Quartus Embedding Layer #548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate | ||
from hls4ml.model.layers import Embedding | ||
|
||
|
||
embed_config_template = """struct config{index} : nnet::embed_config {{ | ||
static const unsigned n_in = {n_in}; | ||
static const unsigned n_out = {n_out}; | ||
static const unsigned vocab_size = {vocab_size}; | ||
static const unsigned io_type = nnet::{iotype}; | ||
static const unsigned reuse_factor = {reuse}; | ||
typedef {embeddings_t.name} embeddings_t; | ||
}};\n""" | ||
|
||
embed_function_template = 'nnet::embedding<{input_t}, {output_t}, {config}>({input}, {output}, {e});' | ||
|
||
embed_include_list = ['nnet_utils/nnet_embed.h', 'nnet_utils/nnet_embed_stream.h'] | ||
|
||
class EmbeddingConfigTemplate(LayerConfigTemplate): | ||
def __init__(self): | ||
super().__init__(Embedding) | ||
self.template = embed_config_template | ||
|
||
def format(self, node): | ||
params = self._default_config_params(node) | ||
return self.template.format(**params) | ||
|
||
class EmbeddingFunctionTemplate(FunctionCallTemplate): | ||
def __init__(self): | ||
super().__init__(Embedding, include_header=embed_include_list) | ||
self.template = embed_function_template | ||
|
||
def format(self, node): | ||
params = self._default_function_params(node) | ||
params['e'] = node.get_weights('embeddings').name | ||
|
||
return self.template.format(**params) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#ifndef NNET_EMBED_H_ | ||
#define NNET_EMBED_H_ | ||
|
||
#include "nnet_common.h" | ||
#include "nnet_helpers.h" | ||
|
||
namespace nnet { | ||
|
||
struct embed_config { | ||
// Internal data type definitions | ||
typedef float embeddings_t; | ||
|
||
// (Default layer sizes, overwritten form the backend | ||
static const unsigned n_in = 10; | ||
static const unsigned n_out = 16; | ||
static const unsigned vocab_size = 50; | ||
|
||
// Resource reuse info | ||
static const unsigned io_type = io_parallel; | ||
static const unsigned reuse_factor = 1; | ||
}; | ||
|
||
template<class data_T, class res_T, typename CONFIG_T> | ||
void embedding( | ||
data_T data[CONFIG_T::n_in], | ||
res_T res[CONFIG_T::n_in * CONFIG_T::n_out], | ||
const typename CONFIG_T::embeddings_t embeddings[CONFIG_T::vocab_size * CONFIG_T::n_out]) { | ||
|
||
/* | ||
* Can store embeddings[] in a register, but a large multiiplexer | ||
* is created due to a non-constant access pattern | ||
*/ | ||
|
||
InputSequence: | ||
#pragma ii CONFIG_T::reuse_factor | ||
#pragma unroll | ||
for (int j = 0; j < CONFIG_T::n_in; j++) { | ||
DenseEmbedding: | ||
#pragma unroll | ||
for (int i = 0; i < CONFIG_T::n_out; i++) { | ||
res[j * CONFIG_T::n_out + i] = embeddings[data[j].to_uint() * CONFIG_T::n_out + i]; | ||
} | ||
} | ||
} | ||
|
||
} | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
/* | ||
* PLACEHOLDER - The common pass embedding.py includes both parallel and streaming implementations; streaming is currently not supported in Quartus | ||
*/ | ||
|
||
#ifndef NNET_EMBED_STREAM_H_ | ||
#define NNET_EMBED_STREAM_H_ | ||
|
||
namespace nnet {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a stub that raises an exception while we still don't have the implementation? Otherwise the compilation fails with a cryptic message (cannot substitute template argument or something like that). Or does that require the more streaming infrastructure in Quartus (#557)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file should never be invoked, as soon as Quartus attempts to use io_stream an exception is thrown from Python (although #557 attempts to add infrastructure for streaming, which means for some layers we should throw exceptions and for some not - this should be handled by Python and can be a part of the streaming PR) |
||
|
||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remind me, why is the
to_uint()
needed here?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data[j]
is a fixed point number by default, withoutto_uint()
compilation will fail asac_fixed
andint
are not interchangeable - may be allowed with the-fpermissive
flag, not sure.