
Embedding layer #449


Closed · wants to merge 4 commits
15 changes: 15 additions & 0 deletions hls4ml/converters/keras/core.py
@@ -121,3 +121,18 @@ def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader, c
        layer['n_filt'] = input_shapes[0][3]

    return layer, [shape for shape in input_shapes[0]]

@keras_handler('Embedding')
def parse_embedding_layer(keras_layer, input_names, input_shapes, data_reader, config):
    assert 'Embedding' in keras_layer['class_name']

    layer = parse_default_keras_layer(keras_layer, input_names)

    # The Keras 'embeddings' weight matrix has shape (vocab_size, n_out)
    weights_shape = data_reader.get_weights_shape(layer['name'], 'embeddings')
    layer['n_in'] = input_shapes[0][1]
    layer['vocab_size'] = weights_shape[0]
    layer['n_out'] = weights_shape[1]
    layer['weight_quantizer'] = None
    output_shape = input_shapes[0] + [layer['n_out']]

    return layer, output_shape
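
For orientation, here is a minimal sketch of the bookkeeping this handler performs, using hypothetical shapes in place of a real data_reader. A Keras Embedding(input_dim=50, output_dim=16, input_length=10) stores its lookup table as a (vocab_size, n_out) 'embeddings' weight:

weights_shape = (50, 16)      # hypothetical (vocab_size, n_out)
input_shapes = [[None, 10]]   # (batch, input_length)

layer = {}
layer['n_in'] = input_shapes[0][1]      # 10 indices per sample
layer['vocab_size'] = weights_shape[0]  # 50 rows in the lookup table
layer['n_out'] = weights_shape[1]       # 16-wide embedding vectors
output_shape = input_shapes[0] + [layer['n_out']]  # [None, 10, 16]
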
29 changes: 29 additions & 0 deletions hls4ml/model/hls_layers.py
@@ -1849,6 +1849,34 @@ def _get_transforms_config(self, params):

        params['sublayer_configs'] = '\n'.join(sublayer_configs)

class Embedding(Layer):
    def initialize(self):
        # Output shape is the input shape with the embedding width appended
        shape = self.get_input_variable().shape[:]
        shape += [self.attributes['n_out']]
        if len(shape) > 1:
            dims = ['N_LAYER_{}_{}'.format(i, self.index) for i in range(1, len(shape) + 1)]
        else:
            dims = ['N_LAYER_{}'.format(self.index)]
        self.add_output_variable(shape, dims)

        data = self.model.get_weights_data(self.name, 'embeddings')
        self.add_weights_variable(name='embeddings', var_name='w{index}', data=data, quantizer=self.get_attr('weight_quantizer'))

    def function_cpp(self):
        params = self._default_function_params()
        params['w'] = self.get_weights('embeddings').name

        return [self._function_template.format(**params)]

    def config_cpp(self):
        params = self._default_config_params()
        params['n_in'] = self.get_input_variable().size_cpp()
        params['n_out'] = self.attributes['n_out']
        params['vocab_size'] = self.attributes['vocab_size']
        params['weight_t'] = self.get_weights('embeddings').type.name

        return self._config_template.format(**params)
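
As a concrete example of the shape bookkeeping in initialize (a sketch, assuming a hypothetical layer index of 2 and the 100-element input used in the pytest below):

shape = [100] + [8]    # input shape plus n_out
dims = ['N_LAYER_{}_{}'.format(i, 2) for i in range(1, len(shape) + 1)]
# len(shape) > 1, so the output dims are ['N_LAYER_1_2', 'N_LAYER_2_2']
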

layer_map = {
    'Input' : Input,
    'InputLayer' : Input,
@@ -1894,6 +1922,7 @@ def _get_transforms_config(self, params):
    'Transpose' : Transpose,
    'GarNet' : GarNet,
    'GarNetStack' : GarNetStack,
    'Embedding' : Embedding,
    # TensorFlow-specific layers:
    'BiasAdd' : BiasAdd,
}
62 changes: 62 additions & 0 deletions hls4ml/templates/vivado/nnet_utils/nnet_embed.h
@@ -0,0 +1,62 @@
#ifndef NNET_EMBED_H_
#define NNET_EMBED_H_

#include "nnet_common.h"
#include "nnet_helpers.h"
#include "hls_stream.h"
#include <math.h>

namespace nnet {

struct embed_config
{
    // Internal data type definitions
    typedef float weight_t;

    // Layer sizes
    static const unsigned n_in = 10;
    static const unsigned n_out = 16;
    static const unsigned vocab_size = 50;

    // Resource reuse info
    static const unsigned io_type = io_parallel;
    static const unsigned reuse_factor = 1;
};

template<class data_T, class res_T, typename CONFIG_T>
void embedding(
    data_T data[CONFIG_T::n_in],
    res_T res[CONFIG_T::n_in*CONFIG_T::n_out],
    typename CONFIG_T::weight_t weights[CONFIG_T::vocab_size*CONFIG_T::n_out])
{
    // For each input index, copy the corresponding row of the weights lookup table
    for (int j = 0; j < CONFIG_T::n_in; j++) {
        for (int i = 0; i < CONFIG_T::n_out; i++) {
            #pragma HLS UNROLL
            res[j * CONFIG_T::n_out + i] = weights[data[j] * CONFIG_T::n_out + i];
        }
    }
}

template<class data_T, class res_T, typename CONFIG_T>
void embedding(
    hls::stream<data_T> &data,
    hls::stream<res_T> &res,
    typename CONFIG_T::weight_t weights[CONFIG_T::vocab_size*CONFIG_T::n_out])
{
    // Streaming variant: the full set of input indices is assumed to arrive in a
    // single packed word; one packed embedding vector is written per input index
    data_T in_data = data.read();
    res_T res_pack;
    #pragma HLS PIPELINE
    #pragma HLS DATA_PACK variable=res_pack
    for (int j = 0; j < data_T::size; j++) {
        for (int i = 0; i < CONFIG_T::n_out; i++) {
            #pragma HLS UNROLL
            res_pack[i] = weights[in_data[j] * CONFIG_T::n_out + i];
        }
        res.write(res_pack);
    }
}

} // namespace nnet

#endif
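
Functionally, both overloads perform a row gather from the flattened lookup table. A NumPy sketch of the same arithmetic, under the sizes from embed_config above:

import numpy as np

vocab_size, n_out, n_in = 50, 16, 10
weights = np.random.rand(vocab_size * n_out)       # flattened (vocab_size, n_out) table
data = np.random.randint(vocab_size, size=n_in)    # input indices

# res[j * n_out + i] = weights[data[j] * n_out + i], as in the loops above
res = np.concatenate([weights[j * n_out:(j + 1) * n_out] for j in data])
assert np.array_equal(res, weights.reshape(vocab_size, n_out)[data].ravel())
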
14 changes: 13 additions & 1 deletion hls4ml/templates/vivado_template.py
@@ -345,6 +345,15 @@
garnet_stack_config_template = (garnet_stack_base_config_template, garnet_stack_sublayer_config_template)


embed_config_template = """struct config{index} : nnet::embed_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned vocab_size = {vocab_size};
    static const unsigned io_type = nnet::{iotype};
    static const unsigned reuse_factor = {reuse};
    typedef {weight_t} weight_t;
}};\n"""
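
For illustration, filling the template with the values from the pytest model below (and a hypothetical layer index of 2) renders the config struct that lands in the generated project headers:

# Hypothetical fill-in, matching Embedding(13, 8, input_length=100) from test_embed.py
print(embed_config_template.format(index=2, n_in=100, n_out=8, vocab_size=13,
                                   iotype='io_parallel', reuse=1,
                                   weight_t='ap_fixed<16,6>'))
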


dense_function_template = 'nnet::dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
@@ -366,6 +375,7 @@
transpose_function_template = 'nnet::transpose_{dim}<{input_t}, {config}>({input}, {output});'
garnet_function_template = 'nnet::garnet{impl}<{input_t}, {integer_input_t}, {output_t}, {config}>({input}, {nvtx}, {output});'
garnet_stack_function_template = 'nnet::garnet_stack<{input_t}, {integer_input_t}, {output_t}, {config}>({input}, {nvtx}, {output});'
embed_function_template = 'nnet::embedding<{input_t}, {output_t}, {config}>({input}, {output}, {w});'

dense_include_list = ['nnet_utils/nnet_dense.h', 'nnet_utils/nnet_dense_compressed.h', 'nnet_utils/nnet_dense_stream.h']
batchnorm_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h']
@@ -380,6 +390,7 @@
resize_include_list = ['nnet_utils/nnet_image.h', 'nnet_utils/nnet_image_stream.h']
transpose_include_list = ['nnet_utils/nnet_array.h']
garnet_include_list = ['nnet_utils/nnet_garnet.h']
embed_include_list = ['nnet_utils/nnet_embed.h']

class VivadoBackend(Backend):
    def __init__(self, name='Vivado'):
@@ -410,7 +421,8 @@ def __init__(self, name='Vivado'):
        self.register_templates('Resize'      , resize_function_template, resize_config_template, resize_include_list)
        self.register_templates('Transpose'   , transpose_function_template, transpose_config_template, transpose_include_list)
        self.register_templates('GarNet'      , garnet_function_template, garnet_config_template, garnet_include_list)
        self.register_templates('GarNetStack' , garnet_stack_function_template, garnet_stack_config_template, garnet_include_list)
        self.register_templates('Embedding'   , embed_function_template, embed_config_template, embed_include_list)

    def create_initial_config(self, part='xcku115-flvb2104-2-i', board=None, clock_period=5, io_type='io_parallel'):
        config = {}
4 changes: 2 additions & 2 deletions hls4ml/utils/config.py
@@ -101,7 +101,7 @@ def config_from_keras_model(model, granularity='model', default_precision='ap_fi
    model_arch = json.loads(model.to_json())

    # Define supported layers
    core_layers = ['InputLayer', 'Dropout', 'Flatten', 'Reshape', 'Permute', 'Embedding']
    dense_layers = ['Dense', 'BinaryDense', 'TernaryDense']
    conv_layers = ['Conv1D', 'Conv2D', 'BinaryConv2D']
    pooling_layers = ['MaxPooling1D', 'MaxPooling2D', 'GlobalMaxPooling1D', 'GlobalMaxPooling2D', 'AveragePooling1D', 'AveragePooling2D', 'GlobalAveragePooling1D', 'GlobalAveragePooling2D']
@@ -342,4 +342,4 @@ def config_from_onnx_model(model, granularity='model', default_precision='ap_fix

    config['Model'] = model_config

    return config
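
Without this addition, config_from_keras_model would reject the model, since it raises on layer types outside the supported set. A quick check (a sketch, using the same model as the test below):

import hls4ml
from tensorflow.keras.layers import Embedding, Input
from tensorflow.keras.models import Model

inputs = Input(shape=(100,), name='embedding_input')
outputs = Embedding(13, 8, input_length=100, name='embedding')(inputs)
model = Model(inputs=inputs, outputs=outputs)

# With 'Embedding' in core_layers this now succeeds and returns per-layer entries
config = hls4ml.utils.config_from_keras_model(model, granularity='name')
print(sorted(config['LayerName'].keys()))  # includes 'embedding' and 'embedding_input'
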
45 changes: 45 additions & 0 deletions test/pytest/test_embed.py
@@ -0,0 +1,45 @@
import pytest
import hls4ml
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding

@pytest.fixture(scope='module')
def data():
    # Integer indices in [0, 10) fit in the ap_uint<4> input precision set below
    X = np.random.randint(10, size=(32, 100))
    return X

@pytest.fixture(scope='module')
def keras_model():
    inputs = Input(shape=(100,), name='embedding_input')
    embedding = Embedding(13, 8, input_length=100, name='embedding')(inputs)
    model = Model(inputs=inputs, outputs=embedding)
    return model

@pytest.fixture(params=['io_parallel', 'io_stream'])
def io_type(request):
    return request.param

@pytest.fixture
def hls_model(keras_model, io_type):
    hls_config = hls4ml.utils.config_from_keras_model(keras_model,
                                                      default_precision='ap_fixed<16,6>',
                                                      granularity='name')
    hls_config['LayerName']['embedding_input']['Precision']['result'] = 'ap_uint<4>'
    hls_model = hls4ml.converters.convert_from_keras_model(keras_model,
                                                           hls_config=hls_config,
                                                           io_type=io_type,
                                                           output_dir='hls4mlprj_embed_{}'.format(io_type))

    hls_model.compile()
    return hls_model

def test_accuracy(data, keras_model, hls_model):
    X = data
    model = keras_model
    # model under test predictions and accuracy
    y_keras = model.predict(X)
    y_hls4ml = hls_model.predict(X.astype(float)).reshape(y_keras.shape)
    # "accuracy" of hls4ml predictions vs keras
    np.testing.assert_allclose(y_keras, y_hls4ml, rtol=0, atol=1e-03, verbose=True)