Skip to content

Commit 5f19b38

Browse files
committed
io_stream
1 parent 6c01d6b commit 5f19b38

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

hls4ml/templates/vivado/nnet_utils/nnet_embed.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,30 @@ struct embed_config
3333
for (int j = 0; j < CONFIG_T::n_in; j++) {
3434
for (int i = 0; i < CONFIG_T::n_out; i++) {
3535
#pragma HLS UNROLL
36-
res[j * CONFIG_T::n_out + i] = (res_T) weights[data[j] * CONFIG_T::n_out + i];
36+
res[j * CONFIG_T::n_out + i] = weights[data[j] * CONFIG_T::n_out + i];
3737
}
3838
}
3939
}
4040

41+
template<class data_T, class res_T, typename CONFIG_T>
42+
void embedding(
43+
hls::stream<data_T> &data,
44+
hls::stream<res_T> &res,
45+
typename CONFIG_T::weight_t weights[CONFIG_T::vocab_size*CONFIG_T::n_out])
46+
{
47+
// copy over the corresponding row in the weights lookup table
48+
data_T in_data = data.read();
49+
res_T res_pack;
50+
#pragma HLS PIPELINE
51+
#pragma HLS DATA_PACK variable=res_pack
52+
for (int j = 0; j < data_T::size; j++) {
53+
for (int i = 0; i < CONFIG_T::n_out; i++) {
54+
#pragma HLS UNROLL
55+
res_pack[i] = weights[in_data[j] * CONFIG_T::n_out + i];
56+
}
57+
res.write(res_pack);
58+
}
59+
}
4160
}
4261

4362
#endif

0 commit comments

Comments
 (0)