Skip to content

Commit dea24dc

Browse files
authored
Merge pull request #660 from vloncar/sr
Symbolic expressions in hls4ml
2 parents 4b4b5a0 + 2f09c0f commit dea24dc

File tree

19 files changed

+1084
-3
lines changed

19 files changed

+1084
-3
lines changed

hls4ml/backends/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from hls4ml.backends.backend import Backend, get_available_backends, get_backend, register_backend # noqa: F401
22
from hls4ml.backends.fpga.fpga_backend import FPGABackend # noqa: F401
33
from hls4ml.backends.quartus.quartus_backend import QuartusBackend
4+
from hls4ml.backends.symbolic.symbolic_backend import SymbolicExpressionBackend
45
from hls4ml.backends.vivado.vivado_backend import VivadoBackend
56
from hls4ml.backends.vivado_accelerator.vivado_accelerator_backend import VivadoAcceleratorBackend
67
from hls4ml.backends.vivado_accelerator.vivado_accelerator_config import VivadoAcceleratorConfig # noqa: F401
@@ -11,3 +12,4 @@
1112
register_backend('VivadoAccelerator', VivadoAcceleratorBackend)
1213
register_backend('Vitis', VitisBackend)
1314
register_backend('Quartus', QuartusBackend)
15+
register_backend('SymbolicExpression', SymbolicExpressionBackend)

hls4ml/backends/symbolic/__init__.py

Whitespace-only changes.

hls4ml/backends/symbolic/passes/__init__.py

Whitespace-only changes.
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import re
2+
3+
from sympy.core import S
4+
from sympy.core.numbers import Integer
5+
from sympy.printing.cxx import CXX11CodePrinter
6+
7+
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
8+
from hls4ml.model.layers import SymbolicExpression
9+
10+
# Expression templates
11+
12+
expr_function_template = 'y[{y_index}] = {expr_str};'
13+
14+
expr_include_list = ['hls_math.h', 'nnet_utils/nnet_math.h']
15+
16+
built_in_luts = ['sin_lut', 'cos_lut']
17+
18+
19+
class HLSCodePrinter(CXX11CodePrinter):
20+
_ns = 'hls::'
21+
22+
def __init__(self, layer, lut_functions, use_built_in_luts=False, settings=None):
23+
if lut_functions is not None:
24+
if use_built_in_luts:
25+
# Check if user's LUTs override built-in LUTs
26+
for lut_name in lut_functions.keys():
27+
if lut_name in built_in_luts:
28+
print(f'WARNING: User-specified LUT function {lut_name} overrides built-in LUT function.')
29+
30+
if settings is None:
31+
settings = {'user_functions': lut_functions}
32+
else:
33+
user_functions = settings.get('user_functions', {})
34+
user_functions.update(lut_functions)
35+
settings['user_functions'] = user_functions
36+
37+
super().__init__(settings)
38+
self.layer = layer
39+
self.use_built_in_luts = use_built_in_luts
40+
41+
for k in (
42+
'Abs Sqrt exp exp2 expm1 log log10 log2 log1p Cbrt hypot fma'
43+
' loggamma sin cos tan asin acos atan atan2 sinh cosh tanh asinh acosh '
44+
'atanh erf erfc loggamma gamma ceiling floor'
45+
).split():
46+
setattr(HLSCodePrinter, '_print_%s' % k, HLSCodePrinter._print_math)
47+
48+
def _symbol_to_array(self, name):
49+
return re.sub(r'([a-zA-Z]+)(\d+)', r'\1[\2]', name)
50+
51+
def _wrap_with_type_name(self, expr_str):
52+
type_name = self.layer.types['result_t'].name
53+
return f'{type_name}({expr_str})'
54+
55+
def _print_Integer(self, expr):
56+
int_str = super()._print_Integer(expr)
57+
return self._wrap_with_type_name(int_str)
58+
59+
def _print_Float(self, flt):
60+
float_str = super()._print_Float(flt)
61+
return self._wrap_with_type_name(float_str)
62+
63+
def _print_Rational(self, expr):
64+
p, q = int(expr.p), int(expr.q)
65+
p_q_str = f'{p}.0/{q}.0'
66+
return self._wrap_with_type_name(p_q_str)
67+
68+
def _print_Pow(self, expr):
69+
type_name = self.layer.types['result_t'].name
70+
type_precision = self.layer.types['result_t'].precision
71+
if isinstance(expr.exp, Integer):
72+
l_brac, r_brac = ('(', ')') if len(expr.base.args) > 1 else ('', '')
73+
if expr.exp > 1:
74+
return (
75+
'('
76+
+ '*'.join([l_brac + self._symbol_to_array(self._print(expr.base)) + r_brac for _ in range(expr.exp)])
77+
+ ')'
78+
)
79+
elif expr.exp == -1: # 1/x
80+
base = l_brac + self._symbol_to_array(self._print(expr.base)) + r_brac
81+
return f'hls::recip<{type_precision.width}, {type_precision.integer}>(({type_name}){base})'
82+
else:
83+
return super()._print_Pow(expr)
84+
else:
85+
base = self._print(expr.base)
86+
if expr.exp == 0.5:
87+
return f'{self._ns}sqrt<{type_precision.width}, {type_precision.integer}>(({type_name})({base}))'
88+
elif expr.exp == S.One / 3:
89+
return f'{self._ns}cbrt<{type_precision.width}, {type_precision.integer}>(({type_name})({base}))'
90+
else:
91+
exp = self._print(expr.exp)
92+
return f'{self._ns}pow<{type_precision.width}, {type_precision.integer}>(({type_name})({base}), {exp})'
93+
94+
def _print_math(self, expr):
95+
name = self.known_functions[expr.__class__.__name__]
96+
if not isinstance(name, str):
97+
for cb, fname in name:
98+
if cb(*expr.args):
99+
name = fname
100+
break
101+
else:
102+
raise ValueError("No matching printer")
103+
104+
# Setting precision of math functions required some rethinking
105+
# Doing e.g., hls::pow<x.width, x.iwidth>(x, y) passes C sim, but fails synthesis, need to use hls::pow<16,6>(x,y)
106+
type_name = self.layer.types['result_t'].name
107+
type_precision = self.layer.types['result_t'].precision
108+
template = f'<{type_precision.width}, {type_precision.integer}>'
109+
cast = f'({type_name})'
110+
args = ', '.join(map(lambda arg: self._print(arg), expr.args))
111+
112+
if self.use_built_in_luts and name + '_lut' in built_in_luts:
113+
ns = 'nnet::'
114+
name = name + '_lut'
115+
template = f'<{type_name}>'
116+
else:
117+
ns = self._ns
118+
119+
return f'{ns}{name}{template}({cast}({args}))'
120+
121+
def _print_Symbol(self, expr):
122+
name = super()._print_Symbol(expr)
123+
return self._symbol_to_array(name)
124+
125+
126+
class ExpressionFunctionTemplate(FunctionCallTemplate):
127+
def __init__(self):
128+
super().__init__(SymbolicExpression, include_header=expr_include_list)
129+
self.template = expr_function_template
130+
131+
def format(self, node):
132+
params = self._default_function_params(node)
133+
134+
lut_functions = {lut_fun.name: lut_fun.name for lut_fun in params['lut_functions']}
135+
printer = HLSCodePrinter(node, lut_functions=lut_functions, use_built_in_luts=node.attributes['use_built_in_luts'])
136+
137+
fn_templates = []
138+
for i, expr in enumerate(node.attributes['expression']):
139+
params['expr_str'] = printer.doprint(expr)
140+
params['y_index'] = str(i)
141+
fn_templates.append(self.template.format(**params))
142+
143+
return fn_templates
144+
145+
146+
class ExpressionConfigTemplate(LayerConfigTemplate):
147+
def __init__(self):
148+
super().__init__(SymbolicExpression)
149+
150+
def format(self, node):
151+
params = self._default_config_params(node)
152+
153+
lut_defs = []
154+
for lut_fun in params['lut_functions']:
155+
type_name = params['result_t'].name
156+
if lut_fun.math_func in ['sinpi', 'cospi', 'sin', 'cos', 'asin', 'acos', 'atan', 'atan2']:
157+
# We have return type overrides for these functions
158+
namespace = 'nnet::'
159+
else:
160+
namespace = 'hls::'
161+
lut_def = (
162+
f'nnet::lookup_table<{type_name}, '
163+
f'{lut_fun.table_size}, '
164+
f'{namespace}'
165+
f'{lut_fun.math_func}> '
166+
f'{lut_fun.name}'
167+
f'({lut_fun.range_start}, '
168+
f'{lut_fun.range_end});'
169+
)
170+
lut_defs.append(lut_def)
171+
172+
return '\n'.join(lut_defs)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from hls4ml.model.layers import SymbolicExpression
2+
from hls4ml.model.optimizer import ConfigurableOptimizerPass
3+
4+
5+
class ValidateUserLookupTable(ConfigurableOptimizerPass):
6+
'''Validates the precision of user-defined LUTs is adequate'''
7+
8+
def __init__(self):
9+
self.raise_exception = False
10+
11+
def match(self, node):
12+
return isinstance(node, SymbolicExpression) and len(node.get_attr('lut_functions', [])) > 0
13+
14+
def transform(self, model, node):
15+
precision = node.get_output_variable().type.precision
16+
range = 2 ** (precision.integer - precision.signed)
17+
frac_step = 1 / 2**precision.fractional
18+
19+
for lut_fn in node.get_attr('lut_functions'):
20+
lut_range = lut_fn.range_end - lut_fn.range_start
21+
lut_step = lut_range / lut_fn.table_size
22+
23+
if lut_step < frac_step:
24+
msg = f'LUT function {lut_fn.name} requires more fractional bits.'
25+
if self.raise_exception:
26+
raise Exception(msg)
27+
else:
28+
print('WARNING:', msg)
29+
30+
if lut_range > range:
31+
msg = f'LUT function {lut_fn.name} requires more integer bits.'
32+
if self.raise_exception:
33+
raise Exception(msg)
34+
else:
35+
print('WARNING:', msg)
36+
37+
return False
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import os
2+
import sys
3+
4+
from hls4ml.backends import FPGABackend
5+
from hls4ml.model.flow import register_flow
6+
from hls4ml.report import parse_vivado_report
7+
8+
9+
class SymbolicExpressionBackend(FPGABackend):
10+
def __init__(self):
11+
super().__init__('SymbolicExpression')
12+
self._register_flows()
13+
14+
def _register_flows(self):
15+
vivado_types = [
16+
'vivado:transform_types',
17+
]
18+
vivado_types_flow = register_flow('specific_types', vivado_types, requires=None, backend=self.name)
19+
20+
validation_passes = [
21+
'symbolicexpression:validate_user_lookup_table',
22+
]
23+
validation_flow = register_flow('validation', validation_passes, requires=None, backend=self.name)
24+
25+
template_flow = register_flow('apply_templates', self._get_layer_templates, requires=None, backend=self.name)
26+
27+
writer_passes = ['make_stamp', 'symbolicexpression:write_hls']
28+
self._writer_flow = register_flow('write', writer_passes, requires=['vivado:ip'], backend=self.name)
29+
30+
ip_flow_requirements = [vivado_types_flow, validation_flow, template_flow]
31+
ip_flow_requirements = list(filter(None, ip_flow_requirements))
32+
33+
self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name)
34+
35+
def get_default_flow(self):
36+
return self._default_flow
37+
38+
def get_writer_flow(self):
39+
return self._writer_flow
40+
41+
def create_initial_config(
42+
self,
43+
part='xcvu9p-flga2577-2-e',
44+
clock_period=5,
45+
io_type='io_parallel',
46+
compiler='vivado_hls',
47+
hls_include_path=None,
48+
hls_libs_path=None,
49+
):
50+
config = {}
51+
52+
config['Part'] = part if part is not None else 'xcvu9p-flga2577-2-e'
53+
config['ClockPeriod'] = clock_period
54+
config['IOType'] = io_type
55+
config['Compiler'] = compiler if compiler is not None else 'vivado_hls'
56+
if not all([hls_include_path, hls_libs_path]):
57+
# Try to infer the include path from Vivado path
58+
bin_path = os.popen(f'command -v {compiler}').read().strip()
59+
if hls_include_path is None:
60+
hls_include_path = bin_path.replace(f'/bin/{compiler}', '/include')
61+
if not os.path.exists(hls_include_path + '/hls_math.h'):
62+
raise Exception(
63+
'Vivado HLS header files not found. Make sure you pass the proper path '
64+
'to the "include" directory (for example "/opt/Xilinx/Vivado/2020.1/include").'
65+
)
66+
elif hls_include_path == '':
67+
print(
68+
'No HLS include path provided, using HLS math functions from Python (i.e., predict()) will not work. '
69+
'Consider using only LUT approximations.'
70+
)
71+
if hls_libs_path is None:
72+
hls_libs_path = bin_path.replace(f'/bin/{compiler}', '/lnx64')
73+
if not os.path.exists(hls_libs_path + '/lib/csim/libhlsmc++-GCC46.so'):
74+
raise Exception(
75+
'Vivado HLS libraries not found. Make sure you pass the proper path '
76+
'to the "lnx64" directory (for example "/opt/Xilinx/Vivado/2020.1/lnx64").'
77+
)
78+
config['HLSIncludePath'] = hls_include_path
79+
config['HLSLibsPath'] = hls_libs_path
80+
config['HLSConfig'] = {}
81+
82+
return config
83+
84+
def build(self, model, reset=False, csim=True, synth=True, cosim=False, validation=False, export=False, vsynth=False):
85+
if 'linux' in sys.platform:
86+
found = os.system('command -v vivado_hls > /dev/null')
87+
if found != 0:
88+
raise Exception('Vivado HLS installation not found. Make sure "vivado_hls" is on PATH.')
89+
90+
curr_dir = os.getcwd()
91+
os.chdir(model.config.get_output_dir())
92+
vivado_cmd = (
93+
f'vivado_hls -f build_prj.tcl "reset={reset} '
94+
f'csim={csim} '
95+
f'synth={synth} '
96+
f'cosim={cosim} '
97+
f'validation={validation} '
98+
f'export={export} '
99+
f'vsynth={vsynth}"'
100+
)
101+
os.system(vivado_cmd)
102+
os.chdir(curr_dir)
103+
104+
return parse_vivado_report(model.config.get_output_dir())

0 commit comments

Comments
 (0)