Skip to content

Commit bcd022b

Browse files
committed
Fix KLLoss layer example
1 parent 688c887 commit bcd022b

File tree

3 files changed

+21
-19
lines changed

3 files changed

+21
-19
lines changed

contrib/kl_layer/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ It was developed specifically for [AD@L1 CMS paper](https://www.nature.com/artic
1010

1111
# Usage
1212

13-
`test_extensions` function in `kl_layer.py` contains the example of how to use the KL layer.
13+
`kl_layer.py` contains the example of how to use the KL layer.
1414
To run do
1515

1616
```

contrib/kl_layer/kl_layer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct distance_config {
2020

2121
// Internal info
2222
static const unsigned table_size = 1024;
23-
static constexpr float exp_range = 1024;
23+
static constexpr unsigned exp_range = 8;
2424
};
2525

2626
template <typename CONFIG_T, int N_TABLE> void init_klloss_exp_table(typename CONFIG_T::exp_table_t table_out[N_TABLE]) {

contrib/kl_layer/kl_layer.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
import hls4ml
2323
from hls4ml.converters.keras_to_hls import parse_default_keras_layer
24-
from hls4ml.model.types import FixedPrecisionType, NamedType
24+
from hls4ml.model.attributes import ConfigurableAttribute, TypeAttribute
25+
from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode
2526

2627

2728
# Keras implementation of a KL layer
@@ -47,21 +48,23 @@ def _merge_function(self, inputs):
4748
class HKLLoss(hls4ml.model.layers.Layer):
4849
'''hls4ml implementation of a KL loss custom layer'''
4950

51+
_expected_attributes = [
52+
ConfigurableAttribute('table_size', default=1024),
53+
ConfigurableAttribute('exp_range', default=8),
54+
TypeAttribute('accum'),
55+
TypeAttribute(
56+
'sum',
57+
default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
58+
),
59+
TypeAttribute(
60+
'exp_table',
61+
default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
62+
),
63+
]
64+
5065
def initialize(self):
5166
self.add_output_variable(shape=[1], dim_names=[f'KL_LOSS_{self.index}'])
5267

53-
print(self.attributes)
54-
if 'sum_t' not in self.attributes:
55-
self.set_attr('sum_t', self.get_attr('accum_t'))
56-
if 'exp_table_t' not in self.attributes:
57-
self.set_attr(
58-
'exp_table_t', NamedType(name=self.name + '_exp_table_t', precision=FixedPrecisionType(width=18, integer=8))
59-
)
60-
if 'table_size' not in self.attributes:
61-
self.set_attr('table_size', 1024)
62-
if 'exp_range' not in self.attributes:
63-
self.set_attr('exp_range', 8)
64-
6568

6669
# Templates
6770
distance_config_template = """struct config{index} : nnet::distance_config {{
@@ -73,8 +76,8 @@ def initialize(self):
7376
static const unsigned table_size = {table_size};
7477
static constexpr float exp_range = {exp_range};
7578
}};\n"""
76-
distance_function_template = 'nnet::{distance}<{input1_t}, {input2_t}, {output_t}, {config}>({input1}, {input2}, {output});'
77-
distance_include_list = ['../../../contrib/kl_layer/kl_layer.h']
79+
distance_function_template = 'nnet::klloss<{input1_t}, {input2_t}, {output_t}, {config}>({input1}, {input2}, {output});'
80+
distance_include_list = ['nnet_utils/kl_layer.h']
7881

7982

8083
class HKLLossConfigTemplate(hls4ml.backends.template.LayerConfigTemplate):
@@ -96,7 +99,6 @@ def __init__(self):
9699

97100
def format(self, node):
98101
params = {}
99-
params['distance'] = 'klloss'
100102
params['config'] = f'config{node.index}'
101103
params['input1_t'] = node.get_input_variable(node.inputs[0]).type.name
102104
params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
@@ -134,7 +136,7 @@ def main():
134136
backend.register_template(HKLLossFunctionTemplate)
135137

136138
# Register HLS implementation
137-
p = Path('kl_layer.h')
139+
p = Path(__file__).parent / 'kl_layer.h'
138140
backend.register_source(p)
139141

140142
# Test if it works

0 commit comments

Comments
 (0)