Commit 4a43f64

committed
Quartus Softmax optimize LUT to store only used values
1 parent d98861d · commit 4a43f64

2 files changed: +30 -24 lines changed

hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h

Lines changed: 16 additions & 21 deletions
@@ -130,7 +130,18 @@ void sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
 enum class softmax_implementation {latency=0, legacy=1, stable=2};

 template<class data_T, typename CONFIG_T>
-inline unsigned softmax_idx_from_real_val(const data_T x){
+inline unsigned softmax_stable_idx_from_real_val(const data_T x){
+    // Number of address bits for table
+    static constexpr int N = ceillog2(CONFIG_T::table_size);
+
+    // Slice the top N bits of the input
+    hls_register ac_int<N, false> y = x.template slc<N>(x.width-N-1);
+    return y.to_uint();
+}
+
+
+template<class data_T, typename CONFIG_T>
+inline unsigned softmax_latency_idx_from_real_val(const data_T x){
     // Number of address bits for table
     static constexpr int N = ceillog2(CONFIG_T::table_size);

@@ -148,27 +159,20 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     // Find maximum
     Op_max<data_T> op_max;
     hls_register data_T x_max = reduce<data_T, CONFIG_T::n_in, Op_max<data_T>>(data, op_max);
-
-    // Calculate differences from the maximum, forcing rounding and saturation for better accuracy
-    hls_register ac_fixed<data_T::width, data_T::i_width, true, AC_RND, AC_SAT> d_xi_xmax[CONFIG_T::n_in];
-    #pragma unroll
-    for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
-        d_xi_xmax[i] = data[i] - x_max;
-    }

     // Calculate all the e^x's
     hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
     #pragma unroll
     for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
-        exp_res[i] = exp_table[softmax_idx_from_real_val<data_T, CONFIG_T>(d_xi_xmax[i])];
+        exp_res[i] = exp_table[softmax_stable_idx_from_real_val<data_T, CONFIG_T>(data[i] - x_max)];
     }

     // Explicitly sum previously calculated exponentials with an adder tree
     Op_add<typename CONFIG_T::exp_table_t> op_add;
     hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);

     // Multiply previously calculated exponetials with the reciprocal of the sum
-    hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
+    hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_stable_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
     #pragma unroll
     for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
         res[i] = exp_res[i] * inv_exp_sum;
@@ -178,31 +182,22 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
 // TODO - Improve accuracy
 template <class data_T, class res_T, typename CONFIG_T>
 void softmax_latency(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
-    /*
-    * Note: The latency tables are equivalent to stable tables
-    * However, the compiler cannot include the same table twice
-    * Therefore, an out-of-scope exception is thrown in one of the functions
-    * Temporary solution - Create the same table twice in quartus_writer.py
-    * Long-term solution - Only create tables needed by the network;
-    * Currently, quartus-writer.py generates LUTs for all activations,
-    * Regardless if they are present in the network or not
-    */
     #include "activation_tables/exp_table_latency.tb"
     #include "activation_tables/invert_table_latency.tb"

     // Calculate all the e^x's
     hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
     #pragma unroll
     for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
-        exp_res[i] = exp_table_latency[softmax_idx_from_real_val<data_T, CONFIG_T>(data[i])];
+        exp_res[i] = exp_table_latency[softmax_latency_idx_from_real_val<data_T, CONFIG_T>(data[i])];
     }

     // Explicitly sum the results with an adder tree.
     Op_add<typename CONFIG_T::exp_table_t> op_add;
     hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);

     // Multiply previously calculated exponetials with the reciprocal of the sum
-    hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
+    hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_latency_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
     #pragma unroll
     for(unsigned i = 0; i < CONFIG_T::n_in; i++){
         res[i] = exp_res[i] * inv_exp_sum;
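
Editor's note: a minimal Python model of the indexing scheme introduced above, i.e. taking the table address from the N bits immediately below the sign bit instead of including the sign bit itself. In the stable softmax every looked-up value has a known sign (the differences x_i - x_max are never positive, the sum of exponentials is never negative), so the sign bit carries no information and dropping it lets the whole table cover only values that can actually occur, roughly doubling the effective resolution for the same table size. The widths (16 total bits, 6 integer bits, 10 address bits) and helper names (to_fixed_bits, stable_idx_from_real_val) are illustrative, not part of hls4ml.

def to_fixed_bits(value, width, int_bits):
    """Quantize a float into a width-bit two's-complement fixed-point word."""
    frac_bits = width - int_bits
    raw = int(round(value * (1 << frac_bits)))
    return raw & ((1 << width) - 1)            # wrap into two's complement

def stable_idx_from_real_val(value, width=16, int_bits=6, table_bits=10):
    """Model of softmax_stable_idx_from_real_val: return the table_bits bits
    immediately below the sign bit of the fixed-point representation."""
    word = to_fixed_bits(value, width, int_bits)
    shift = width - 1 - table_bits             # drop the sign bit, keep the next N bits
    return (word >> shift) & ((1 << table_bits) - 1)

print(stable_idx_from_real_val(0.0))           # 0    -> exp table entry exp(0) = 1
print(stable_idx_from_real_val(-0.125))        # 1020 -> exp table entry exp(-0.125)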

hls4ml/writer/quartus_writer.py

Lines changed: 14 additions & 3 deletions
@@ -653,12 +653,19 @@ def __write_exp_table(self, model, path):
         except:
             # FixedPrecisionType wasn't correctly stored in layer attributes, use default values
             pass
+        if fp_signed is False:
+            raise Exception('Softmax types need to be signed')

         sep = ''
         N = ceil_log2(table_size)
         for i in range(table_size):
             f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed)
-            f.set_msb_bits(uint_to_binary(i, N))
+            b = uint_to_binary(i, N)
+            if i == 0:
+                b.insert(0, 0)
+            else:
+                b.insert(0, 1)
+            f.set_msb_bits(b)
             real_val = f.exp_float()
             h_file.write(sep + str(real_val))
             sep = ", "
@@ -693,20 +700,24 @@ def __write_invert_table(self, model, path):
         except:
             # FixedPrecisionType wasn't correctly stored in layer attributes, use default values
             pass
+        if fp_signed is False:
+            raise Exception('Softmax types need to be signed')

         sep = ''
         N = ceil_log2(table_size)
         for i in range(table_size):
             f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed)
-            f.set_msb_bits(uint_to_binary(i, N))
+            b = uint_to_binary(i, N)
+            b.insert(0, 0)
+            f.set_msb_bits(b)
             real_val = f.inv_float()
             h_file.write(sep + str(real_val))
             sep = ", "

         h_file.write('};\n')
         h_file.write('\n#endif\n')
         h_file.close()
-
+
     def __write_exp_table_latency(self, model, path):
         table_name = 'exp_table_latency'
         table_size = self.__get_table_size(model, 'softmax')
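
Editor's note: a standalone sketch of what the generated tables contain after this change. It mirrors only the sign-bit handling added above: the exp table re-attaches sign bit 1 to every non-zero index (the stable softmax only looks up non-positive differences, and index 0 stands for a difference of exactly 0), while the invert table re-attaches sign bit 0 (the sum of exponentials is always positive). FixedPointEmulator is replaced by plain integer arithmetic, and the fixed-point format (16 total bits, 6 integer bits) is an illustrative assumption, not a value taken from this commit.

import math

def entry_value(i, N, width=16, int_bits=6, sign_bit=1):
    """Real value of the word whose sign bit is sign_bit, whose next N bits are
    the binary expansion of i, and whose remaining low bits are zero."""
    frac_bits = width - int_bits
    raw = (sign_bit << (width - 1)) | (i << (width - 1 - N))
    if raw >= 1 << (width - 1):                # reinterpret as two's complement
        raw -= 1 << width
    return raw / (1 << frac_bits)

def exp_table(N, **kw):
    # Index 0 represents the difference x_i - x_max == 0 (sign bit 0); every
    # other index represents a strictly negative difference (sign bit 1).
    return [math.exp(entry_value(i, N, sign_bit=0 if i == 0 else 1, **kw))
            for i in range(1 << N)]

def invert_table(N, **kw):
    # The sum of exponentials is always positive, so the sign bit is always 0.
    vals = [entry_value(i, N, sign_bit=0, **kw) for i in range(1 << N)]
    return [1.0 / v if v != 0 else 0.0 for v in vals]

table = exp_table(N=10)
print(table[0], table[1020])                   # 1.0 and exp(-0.125) with these widths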
