Skip to content

Commit 8a94f4a

Browse files
committed
Better sin/cos LUT implementations
1 parent f2687df commit 8a94f4a

File tree

3 files changed

+170
-6
lines changed

3 files changed

+170
-6
lines changed

hls4ml/backends/symbolic/passes/expr_templates.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,19 @@
1313

1414
expr_include_list = ['hls_math.h', 'nnet_utils/nnet_math.h']
1515

16+
built_in_luts = ['sin_lut', 'cos_lut']
17+
1618
class HLSCodePrinter(CXX11CodePrinter):
1719
_ns = 'hls::'
1820

19-
def __init__(self, layer, lut_functions, settings=None):
21+
def __init__(self, layer, lut_functions, use_built_in_luts=False, settings=None):
2022
if lut_functions is not None:
23+
if use_built_in_luts:
24+
# Check if user's LUTs override built-in LUTs
25+
for lut_name in lut_functions.keys():
26+
if lut_name in built_in_luts:
27+
print(f'WARNING: User-specified LUT function {lut_name} overrides built-in LUT function.')
28+
2129
if settings is None:
2230
settings = { 'user_functions': lut_functions }
2331
else:
@@ -27,6 +35,7 @@ def __init__(self, layer, lut_functions, settings=None):
2735

2836
super().__init__(settings)
2937
self.layer = layer
38+
self.use_built_in_luts = use_built_in_luts
3039

3140
for k in ('Abs Sqrt exp exp2 expm1 log log10 log2 log1p Cbrt hypot fma'
3241
' loggamma sin cos tan asin acos atan atan2 sinh cosh tanh asinh acosh '
@@ -82,7 +91,14 @@ def _print_math(self, expr):
8291
cast = f'({hls_type.name})'
8392
args = ', '.join(map(lambda arg: self._print(arg), expr.args))
8493

85-
return f'{self._ns}{name}{template}({cast}({args}))'
94+
if self.use_built_in_luts and name + '_lut' in built_in_luts:
95+
ns = 'nnet::'
96+
name = name + '_lut'
97+
template = f'<{hls_type.name}>'
98+
else:
99+
ns = self._ns
100+
101+
return f'{ns}{name}{template}({cast}({args}))'
86102

87103
def _print_Symbol(self, expr):
88104
name = super()._print_Symbol(expr)
@@ -96,7 +112,8 @@ def __init__(self):
96112
def format(self, node):
97113
params = self._default_function_params(node)
98114

99-
printer = HLSCodePrinter(node, lut_functions={ lut_fun.name : lut_fun.name for lut_fun in params['lut_functions'] })
115+
lut_functions = { lut_fun.name : lut_fun.name for lut_fun in params['lut_functions'] }
116+
printer = HLSCodePrinter(node, lut_functions=lut_functions, use_built_in_luts=node.attributes['use_built_in_luts'])
100117

101118
fn_templates = []
102119
for i, expr in enumerate(node.attributes['expression']):

hls4ml/converters/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def convert_from_symbolic_expression(
400400
expr,
401401
n_symbols=None,
402402
lut_functions=None,
403+
use_built_in_lut_functions=False,
403404
output_dir='my-hls-test',
404405
project_name='myproject',
405406
input_data_tb=None,
@@ -440,6 +441,7 @@ def convert_from_symbolic_expression(
440441
expr_layer['expression'] = [str(e) for e in expr]
441442
expr_layer['n_symbols'] = n_symbols
442443
expr_layer['lut_functions'] = lut_functions
444+
expr_layer['use_built_in_luts'] = use_built_in_lut_functions
443445
layer_list.append(expr_layer)
444446

445447
config = create_config(

hls4ml/templates/vivado/nnet_utils/nnet_math.h

Lines changed: 148 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,155 @@ T atan(T x) {
4545
};
4646

4747
template<typename T>
48-
T atan2(T x) {
49-
return (T) hls::atan2(x);
48+
T atan2(T x, T y) {
49+
return (T) hls::atan2(x, y);
5050
};
5151

52+
template<class T, int W, int I>
53+
void init_sincos_table(T table[1<<(W - I - 3)][2]) {
54+
unsigned int NTE = 1<<(W - I - 3); //No of table entries
55+
double step = M_PI/(4*NTE); //Interval between angles
56+
double y = 0;
57+
//double scaled_angle = 0;
58+
59+
for (unsigned int i=0; i < NTE; i++) {
60+
table[i][0] = std::cos(y);
61+
table[i][1] = std::sin(y);
62+
y += step;
63+
//scaled_angle = y/(2*M_PI);
64+
//printf("cos(%f) = %23.22f, sin(%f) = %23.22f index = %d, scaled angle = %13.12f \n", y, cos(y), y, sin(y), i, scaled_angle);
65+
}
66+
67+
}
68+
69+
template<class T>
70+
void sincos_lut(const T &input, T output[2]) {
71+
72+
#pragma HLS INLINE
73+
74+
static bool flag = true;
75+
if (flag && T::width-T::iwidth > 12) {
76+
#if !defined(__SYNTHESIS__) && defined(SINCOS_LUT_DEBUG)
77+
std::cout << "FILE : " << __FILE__ << ", LINE : " << __LINE__ << std::endl;
78+
std::cout << "Warning: The output of sincos_lut will not be accurate" << std::endl;
79+
#endif
80+
flag = false;
81+
}
82+
// Datatype for lookup table entries
83+
typedef ap_ufixed <T::width, T::iwidth, AP_RND> luttype;
84+
// Datatype for posinput which is used to handle negative inputs
85+
typedef ap_ufixed<T::width-T::iwidth, 0> posinputtype;
86+
87+
typedef ap_uint<9> lutindextype; // 9 bits required for indexing into 512 entry table
88+
typedef ap_uint<3> octanttype; // 3 bits required for octant value range of 0 thru 7
89+
T outputtemp[2];
90+
lutindextype luTdex = 0;
91+
posinputtype posinput = input;
92+
93+
// Initialize the lookup table
94+
#ifdef __SYNTHESIS__
95+
bool initialized = false;
96+
luttype sincos[512][2];
97+
#else
98+
static bool initialized = false;
99+
static luttype sincos[512][2];
100+
#endif
101+
if (!initialized) {
102+
init_sincos_table<luttype, 12, 0>(sincos);
103+
initialized = true;
104+
}
105+
106+
// Leaving this commented out makes the table go to BRAM
107+
//#pragma HLS ARRAY_PARTITION variable=sincos complete dim=0
108+
109+
typedef ap_uint<AP_MAX(T::width-T::iwidth-3, 1)> lutindextype1;
110+
// Extracting (MSB-3:LSB) bits of scaled input to determine the lookup table index
111+
lutindextype1 luTdex1 = posinput.range(AP_MAX(T::width-T::iwidth-3, 1), 0); // Extracting the lookup table index
112+
113+
if (T::width-T::iwidth>=4 && T::width-T::iwidth<=12) {
114+
luTdex(8, 12- (T::width - T::iwidth)) = luTdex1; // stride
115+
}
116+
// Approximation for scaled inputs whose number of fractional bits is greater than 12
117+
else if (T::width-T::iwidth>12) {
118+
// Lookup table index for scaled inputs whose number of fractional bits is greater than 12
119+
luTdex = luTdex1/(1<<(AP_MAX(T::width-T::iwidth-12, 0)));
120+
if ((luTdex1 % (1<<(AP_MAX(T::width-T::iwidth-12,0)))) > (1<<(AP_MAX(T::width-T::iwidth-13,0)))) {
121+
luTdex = luTdex + 1;
122+
}
123+
typedef ap_ufixed<AP_MAX((AP_MAX(T::width-T::iwidth-3, 1) + T::width-T::iwidth-12), 1), AP_MAX(T::width-T::iwidth-3, 1)> datatype;
124+
datatype x = (datatype)luTdex1;
125+
x = x >> AP_MAX(T::width-T::iwidth-12, 0);
126+
if (x > 511.5) { luTdex = 511; }
127+
if (luTdex1 <= 1<<(AP_MAX(T::width-T::iwidth-13,0)) && luTdex1 != 0) { luTdex = 1; }
128+
}
129+
130+
if (T::width-T::iwidth>=3) {
131+
// Getting the octant 0-7 by extracting the first 3 bits from MSB side of scaled input where
132+
// octant 0 corresponds to [0-PI/4),
133+
// octant 1 corresponds to [PI/4-2PI/4),
134+
// octant 2 corresponds to [2PI/4-3PI/4) and so on
135+
//octanttype octant = posinput.template slc<3>(T::width-T::iwidth-3);
136+
octanttype octant = posinput(T::width-T::iwidth-1, T::width-T::iwidth-3);
137+
luTdex = (octant[0] == 1)?(lutindextype)(512-luTdex):(lutindextype)(luTdex);
138+
// imaginary part is sine
139+
outputtemp[1] = ((octant==0) | (octant==3)) ? (T) sincos[luTdex][1]:
140+
((octant==2) | (octant==1)) ? (T) sincos[luTdex][0]:
141+
((octant==7) | (octant==4)) ? (T)-sincos[luTdex][1]:
142+
(T)-sincos[luTdex][0];
143+
// real part is cosine
144+
outputtemp[0] = ((octant==6) | (octant==1)) ? (T) sincos[luTdex][1]:
145+
((octant==3) | (octant==4)) ? (T)-sincos[luTdex][0]:
146+
((octant==2) | (octant==5)) ? (T)-sincos[luTdex][1]:
147+
(T) sincos[luTdex][0];
148+
// The two assignments below handle scaled inputs of 0.125, 0.375, 0.625 and 0.875 (odd multiples of PI/4), where sine and cosine are +/- 0.7071... and there is no matching entry in the lookup table
149+
output[1] = ((posinput==0.125) | (posinput==0.375)) ? T( 0.7071067811865475244008):
150+
((posinput==0.625) | (posinput==0.875)) ? T(-0.7071067811865475244008):
151+
outputtemp[1];
152+
output[0] = ((posinput==0.125) | (posinput==0.875)) ? T( 0.7071067811865475244008):
153+
((posinput==0.375) | (posinput==0.625)) ? T(-0.7071067811865475244008):
154+
outputtemp[0];
155+
}
156+
157+
if (T::width-T::iwidth <= 2) {
158+
output[1] = (posinput==0 ) ? (T) 0:
159+
(posinput==0.25) ? (T) 1:
160+
(posinput==0.5 ) ? (T) 0:
161+
(posinput==0.75) ? (T)-1:
162+
outputtemp[1];
163+
output[0] = (posinput==0 ) ? (T) 1:
164+
(posinput==0.25) ? (T) 0:
165+
(posinput==0.5 ) ? (T)-1:
166+
(posinput==0.75) ? (T) 0:
167+
outputtemp[0];
168+
}
169+
170+
#if !defined(__SYNTHESIS__) && defined(SINCOS_LUT_DEBUG)
171+
std::cout << "FILE : " << __FILE__ << ", LINE : " << __LINE__ << std::endl;
172+
std::cout << "============AP_FIXED SINCOS======================" << std::endl;
173+
std::cout << "positive input is = " << posinput << std::endl;
174+
std::cout << "lut index is = " << luTdex << std::endl;
175+
std::cout << "sin value is = " << output[1] << std::endl;
176+
std::cout << "cos value is = " << output[0] << std::endl;
177+
std::cout << "=================================================" << std::endl;
178+
#endif
179+
}
180+
181+
template<class T>
182+
T sin_lut(const T input) {
183+
T sincos_res[2];
184+
T scaled_input = input * ap_ufixed<16,0>(0.15915494309); // 1/(2*pi)
185+
sincos_lut(scaled_input, sincos_res);
186+
return sincos_res[1];
187+
}
188+
189+
template<class T>
190+
T cos_lut(const T input) {
191+
T sincos_res[2];
192+
T scaled_input = input * ap_ufixed<16,0>(0.15915494309); // 1/(2*pi)
193+
sincos_lut(scaled_input, sincos_res);
194+
return sincos_res[0];
195+
}
196+
52197
}
53198

54-
#endif
199+
#endif

0 commit comments

Comments
 (0)