Skip to content

Commit bf5762d

Browse files
committed
Quartus Softmax optimize LUT to store only used values
1 parent c8e8f75 commit bf5762d

File tree

4 files changed

+33
-32
lines changed

4 files changed

+33
-32
lines changed

hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,17 @@ void sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
130130
enum class softmax_implementation {latency=0, legacy=1, stable=2};
131131

132132
template<class data_T, typename CONFIG_T>
133-
inline unsigned softmax_idx_from_real_val(const data_T x){
133+
inline unsigned softmax_stable_idx_from_real_val(const data_T x){
134+
// Number of address bits for table
135+
static constexpr int N = ceillog2(CONFIG_T::table_size);
136+
137+
// Slice the top N bits of the input
138+
hls_register ac_int<N, false> y = x.template slc<N>(x.width-N-1);
139+
return y.to_uint();
140+
}
141+
142+
template<class data_T, typename CONFIG_T>
143+
inline unsigned softmax_latency_idx_from_real_val(const data_T x){
134144
// Number of address bits for table
135145
static constexpr int N = ceillog2(CONFIG_T::table_size);
136146

@@ -148,27 +158,20 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
148158
// Find maximum
149159
Op_max<data_T> op_max;
150160
hls_register data_T x_max = reduce<data_T, CONFIG_T::n_in, Op_max<data_T>>(data, op_max);
151-
152-
// Calculate differences from the maximum, forcing rounding and saturation for better accuracy
153-
hls_register ac_fixed<data_T::width, data_T::i_width, true, AC_RND, AC_SAT> d_xi_xmax[CONFIG_T::n_in];
154-
#pragma unroll
155-
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
156-
d_xi_xmax[i] = data[i] - x_max;
157-
}
158161

159162
// Calculate all the e^x's
160163
hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
161164
#pragma unroll
162165
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
163-
exp_res[i] = exp_table[softmax_idx_from_real_val<data_T, CONFIG_T>(d_xi_xmax[i])];
166+
exp_res[i] = exp_table[softmax_stable_idx_from_real_val<data_T, CONFIG_T>(data[i] - x_max)];
164167
}
165168

166169
// Explicitly sum previously calculated exponentials with an adder tree
167170
Op_add<typename CONFIG_T::exp_table_t> op_add;
168171
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);
169172

170173
// Multiply previously calculated exponentials with the reciprocal of the sum
171-
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
174+
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_stable_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
172175
#pragma unroll
173176
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
174177
res[i] = exp_res[i] * inv_exp_sum;
@@ -178,31 +181,22 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
178181
// TODO - Improve accuracy
179182
template <class data_T, class res_T, typename CONFIG_T>
180183
void softmax_latency(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
181-
/*
182-
* Note: The latency tables are equivalent to stable tables
183-
* However, the compiler cannot include the same table twice
184-
* Therefore, an out-of-scope exception is thrown in one of the functions
185-
* Temporary solution - Create the same table twice in quartus_writer.py
186-
* Long-term solution - Only create tables needed by the network;
187-
* Currently, quartus-writer.py generates LUTs for all activations,
188-
* Regardless if they are present in the network or not
189-
*/
190184
#include "activation_tables/exp_table_latency.tb"
191185
#include "activation_tables/invert_table_latency.tb"
192186

193187
// Calculate all the e^x's
194188
hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
195189
#pragma unroll
196190
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
197-
exp_res[i] = exp_table_latency[softmax_idx_from_real_val<data_T, CONFIG_T>(data[i])];
191+
exp_res[i] = exp_table_latency[softmax_latency_idx_from_real_val<data_T, CONFIG_T>(data[i])];
198192
}
199193

200194
// Explicitly sum the results with an adder tree.
201195
Op_add<typename CONFIG_T::exp_table_t> op_add;
202196
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);
203197

204198
// Multiply previously calculated exponentials with the reciprocal of the sum
205-
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
199+
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_latency_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
206200
#pragma unroll
207201
for(unsigned i = 0; i < CONFIG_T::n_in; i++){
208202
res[i] = exp_res[i] * inv_exp_sum;

hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,15 +283,15 @@ void softmax_stable(stream<data_T> &data, stream<res_T> &res) {
283283
hls_register typename CONFIG_T::exp_table_t exp_res[data_T::size];
284284
#pragma unroll
285285
for(unsigned j = 0; j < data_T::size; j++) {
286-
exp_res[j] = exp_table[softmax_idx_from_real_val<typename data_T::value_type, CONFIG_T>(d_xi_xmax[j])];
286+
exp_res[j] = exp_table[softmax_stable_idx_from_real_val<typename data_T::value_type, CONFIG_T>(d_xi_xmax[j])];
287287
}
288288

289289
// Explicitly sum the results with an adder tree.
290290
// Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing
291291
Op_add<typename CONFIG_T::exp_table_t> op_add;
292292
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, data_T::size, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);
293293

294-
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
294+
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_stable_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
295295
res_T out_pack;
296296

297297
SoftmaxInvPackLoop:
@@ -327,7 +327,7 @@ void softmax_latency(stream<data_T> &data, stream<res_T> &res){
327327
SoftmaxExpPackLoop:
328328
#pragma unroll
329329
for(unsigned j = 0; j < data_T::size; j++) {
330-
exp_res[j] = exp_table_latency[softmax_idx_from_real_val<typename data_T::value_type, CONFIG_T>(in_pack[j])];
330+
exp_res[j] = exp_table_latency[softmax_latency_idx_from_real_val<typename data_T::value_type, CONFIG_T>(in_pack[j])];
331331
}
332332

333333
// Explicitly sum the results with an adder tree.
@@ -336,7 +336,7 @@ void softmax_latency(stream<data_T> &data, stream<res_T> &res){
336336
hls_register typename CONFIG_T::exp_table_t exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);
337337

338338
// Multiply previously calculated exponentials with the reciprocal of the sum
339-
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
339+
hls_register typename CONFIG_T::inv_table_t inv_exp_sum = invert_table_latency[softmax_latency_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)];
340340

341341
res_T out_pack;
342342
SoftmaxInvPackLoop:

hls4ml/utils/fixed_point_utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
import sys
22
import math
3-
from unicodedata import decimal
4-
5-
from numpy import integer
63

74
'''
85
A helper class for handling fixed point methods
@@ -70,7 +67,6 @@ def set_msb_bits(self, bits):
7067
self.integer_bits[i] = bits[i]
7168
elif i >= self.I and i<self.N:
7269
self.decimal_bits[i-self.I] = bits[i]
73-
# print('Len bits ' + str(len(bits)) + ' Inside FPU ' + str(self.integer_bits) + str(self.decimal_bits))
7470

7571
'''
7672
Returns e^x, where x is the current fixed point number

hls4ml/writer/quartus_writer.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -918,12 +918,19 @@ def __write_exp_table(self, model, path):
918918
except:
919919
# FixedPrecisionType wasn't correctly stored in layer attributes, use default values
920920
pass
921+
if fp_signed is False:
922+
raise Exception('Softmax types need to be signed')
921923

922924
sep = ''
923925
N = ceil_log2(table_size)
924926
for i in range(table_size):
925927
f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed)
926-
f.set_msb_bits(uint_to_binary(i, N))
928+
b = uint_to_binary(i, N)
929+
if i == 0:
930+
b.insert(0, 0)
931+
else:
932+
b.insert(0, 1)
933+
f.set_msb_bits(b)
927934
real_val = f.exp_float()
928935
h_file.write(sep + str(real_val))
929936
sep = ", "
@@ -957,19 +964,23 @@ def __write_invert_table(self, model, path):
957964
except:
958965
# FixedPrecisionType wasn't correctly stored in layer attributes, use default values
959966
pass
967+
if fp_signed is False:
968+
raise Exception('Softmax types need to be signed')
960969

961970
sep = ''
962971
N = ceil_log2(table_size)
963972
for i in range(table_size):
964973
f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed)
965-
f.set_msb_bits(uint_to_binary(i, N))
974+
b = uint_to_binary(i, N)
975+
b.insert(0, 0)
976+
f.set_msb_bits(b)
966977
real_val = f.inv_float()
967978
h_file.write(sep + str(real_val))
968979
sep = ", "
969980

970981
h_file.write('};\n')
971982
h_file.close()
972-
983+
973984
def __write_exp_table_latency(self, model, path):
974985
table_name = 'exp_table_latency'
975986
table_size = self.__get_table_size(model, 'softmax')

0 commit comments

Comments
 (0)