Additional cleanup of the codebase #750

Merged (9 commits, Apr 11, 2023)
15 changes: 6 additions & 9 deletions contrib/garnet.py
@@ -3,21 +3,18 @@
 """
 
 import tensorflow.keras as keras
+from qkeras import QActivation, QDense, ternary
 
 K = keras.backend
 
-try:
-    from qkeras import QActivation, QDense, ternary
 
-    class NamedQDense(QDense):
-        def add_weight(self, name=None, **kwargs):
-            return super().add_weight(name=f'{self.name}_{name}', **kwargs)
+class NamedQDense(QDense):
+    def add_weight(self, name=None, **kwargs):
+        return super().add_weight(name=f'{self.name}_{name}', **kwargs)
 
-    def ternary_1_05():
-        return ternary(alpha=1.0, threshold=0.5)
+def ternary_1_05():
+    return ternary(alpha=1.0, threshold=0.5)
 
-except ImportError:
-    pass
 
 # Hack keras Dense to propagate the layer name into saved weights
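Side note on the garnet.py change: qkeras is now imported unconditionally, so `NamedQDense` and `ternary_1_05` always exist instead of silently disappearing when qkeras is missing. The `add_weight` override applies the same trick the trailing comment describes for plain keras `Dense`: the layer name is prepended to every weight name so saved weights stay identifiable. A minimal standalone sketch of the effect (assumes qkeras is installed; the layer name `encoder_dense` is made up for illustration):

```python
from qkeras import QDense


class NamedQDense(QDense):
    # same override as in contrib/garnet.py: prepend the layer name to weight names
    def add_weight(self, name=None, **kwargs):
        return super().add_weight(name=f'{self.name}_{name}', **kwargs)


layer = NamedQDense(4, name='encoder_dense')
layer.build((None, 8))
print([w.name for w in layer.weights])
# expected: names like 'encoder_dense_kernel' instead of the bare 'kernel'
```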
2 changes: 1 addition & 1 deletion hls4ml/backends/quartus/passes/merge_templates.py
@@ -68,7 +68,7 @@ def __init__(self):
     def format(self, node):
         inp1 = node.get_input_variable(node.inputs[0])
         inp2 = node.get_input_variable(node.inputs[1])
-        params = node._default_config_params()
+        params = self._default_config_params(node)
         params['n_out'] = 1
         params['n_in'] = inp1.shape[0]
         params['product_type'] = get_backend('quartus').product_type(inp1.type.precision, inp2.type.precision)
2 changes: 1 addition & 1 deletion hls4ml/backends/vivado/passes/merge_templates.py
@@ -65,7 +65,7 @@ def __init__(self):
     def format(self, node):
         inp1 = node.get_input_variable(node.inputs[0])
         inp2 = node.get_input_variable(node.inputs[1])
-        params = node._default_config_params()
+        params = self._default_config_params(node)
         params['n_out'] = 1
         params['n_in'] = inp1.shape[0]
         params['product_type'] = get_backend('vivado').product_type(inp1.type.precision, inp2.type.precision)
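Both backend templates get the same one-line fix: `_default_config_params` is a method of the config template, not of the layer node, so the old `node._default_config_params()` lookup would fail as soon as a Dot merge layer was configured; the template now calls its own helper and passes the node in. A self-contained, illustrative-only sketch of that calling pattern (apart from `_default_config_params`, the names here are made up rather than hls4ml's real API):

```python
class FakeNode:
    # stands in for an hls4ml layer node carrying its attributes
    attributes = {'n_in': 10, 'product_type': 'mult'}


class DotConfigTemplate:
    template = 'n_in: {n_in}, n_out: {n_out}, product: {product_type}'

    def _default_config_params(self, node):
        # seed the config dict from the node's attributes
        return dict(node.attributes)

    def format(self, node):
        params = self._default_config_params(node)  # was: node._default_config_params()
        params['n_out'] = 1
        return self.template.format(**params)


print(DotConfigTemplate().format(FakeNode()))  # n_in: 10, n_out: 1, product: mult
```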
3 changes: 2 additions & 1 deletion hls4ml/converters/__init__.py
@@ -64,7 +64,8 @@
                         elif model_type == 'onnx':
                             register_onnx_layer_handler(layer, func)
 
-        except ImportError:
+        except ImportError as err:
+            print(f'WARNING: Failed to import handlers from {module}: {err.msg}.')
             continue


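With this change a missing optional dependency no longer makes converter handlers vanish silently; the registration loop now reports which module it could not import. A condensed sketch of the pattern (the module list is illustrative; the real loop discovers converter modules on disk):

```python
import importlib

for module in ['keras.core', 'keras.convolution', 'keras.qkeras']:
    try:
        lib = importlib.import_module(f'hls4ml.converters.{module}')
        # ... register the layer handlers found in lib ...
    except ImportError as err:
        print(f'WARNING: Failed to import handlers from {module}: {err.msg}.')
        continue
```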
35 changes: 1 addition & 34 deletions hls4ml/converters/keras/core.py
@@ -1,7 +1,5 @@
-import numpy as np
-
 from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer
-from hls4ml.model.types import IntegerPrecisionType, Quantizer
+from hls4ml.model.types import BinaryQuantizer, IntegerPrecisionType, TernaryQuantizer
@@ -25,37 +23,6 @@ def parse_input_layer(keras_layer, input_names, input_shapes, data_reader):
     return layer, output_shape
 
 
-class BinaryQuantizer(Quantizer):
-    def __init__(self, bits=2):
-        if bits == 1:
-            hls_type = IntegerPrecisionType(width=1, signed=False)
-        elif bits == 2:
-            hls_type = IntegerPrecisionType(width=2)
-        else:
-            raise Exception(f'BinaryQuantizer suppots 1 or 2 bits, but called with bits={bits}')
-        super().__init__(bits, hls_type)
-
-    def __call__(self, data):
-        zeros = np.zeros_like(data)
-        ones = np.ones_like(data)
-        quant_data = data
-        if self.bits == 1:
-            quant_data = np.where(data > 0, ones, zeros).astype('int')
-        if self.bits == 2:
-            quant_data = np.where(data > 0, ones, -ones)
-        return quant_data
-
-
-class TernaryQuantizer(Quantizer):
-    def __init__(self):
-        super().__init__(2, IntegerPrecisionType(width=2))
-
-    def __call__(self, data):
-        zeros = np.zeros_like(data)
-        ones = np.ones_like(data)
-        return np.where(data > 0.5, ones, np.where(data <= -0.5, -ones, zeros))
-
-
 dense_layers = ['Dense', 'BinaryDense', 'TernaryDense']


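`BinaryQuantizer` and `TernaryQuantizer` are not removed, only relocated to `hls4ml.model.types` (see the updated import at the top of the file). A quick sanity sketch of their behaviour after the move, following the implementations shown above:

```python
import numpy as np

from hls4ml.model.types import BinaryQuantizer, TernaryQuantizer

data = np.array([-0.9, -0.4, 0.0, 0.4, 0.9])
print(BinaryQuantizer(bits=2)(data))  # [-1. -1. -1.  1.  1.]  (sign; zero maps to -1)
print(TernaryQuantizer()(data))       # [-1.  0.  0.  0.  1.]  (values between the +/-0.5 thresholds map to 0)
```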
207 changes: 131 additions & 76 deletions hls4ml/converters/keras/qkeras.py
@@ -1,81 +1,9 @@
-import tensorflow as tf
 from qkeras.quantizers import get_quantizer
 
-from hls4ml.converters.keras.core import BinaryQuantizer
-from hls4ml.model.types import ExponentPrecisionType, FixedPrecisionType, IntegerPrecisionType, Quantizer, XnorPrecisionType
-
-
-class QKerasQuantizer(Quantizer):
-    def __init__(self, config):
-        self.quantizer_fn = get_quantizer(config)
-        self.alpha = config['config'].get('alpha', None)
-        if config['class_name'] == 'quantized_bits':
-            self.bits = config['config']['bits']
-            self.hls_type = get_type(config)
-        # ! includes stochastic_ternary
-        elif 'ternary' in config['class_name']:
-            self.bits = 2
-            self.hls_type = IntegerPrecisionType(width=2, signed=True)
-        # ! includes stochastic_binary
-        elif 'binary' in config['class_name']:
-            self.bits = 1
-            self.hls_type = XnorPrecisionType()
-        else:
-            print("Unsupported quantizer: " + config['class_name'])
-            self.bits = 16
-            self.hls_type = FixedPrecisionType(width=16, integer=6, signed=True)
-
-    def __call__(self, data):
-        tf_data = tf.convert_to_tensor(data)
-        return self.quantizer_fn(tf_data).numpy()
-        # return self.quantizer_fn(data)
-
-
-class QKerasBinaryQuantizer:
-    def __init__(self, config, xnor=False):
-        self.bits = 1 if xnor else 2
-        self.hls_type = XnorPrecisionType() if xnor else IntegerPrecisionType(width=2, signed=True)
-        self.alpha = config['config']['alpha']
-        # Use the QKeras quantizer to handle any stochastic / alpha stuff
-        self.quantizer_fn = get_quantizer(config)
-        # Then we use our BinaryQuantizer to convert to '0,1' format
-        self.binary_quantizer = BinaryQuantizer(1) if xnor else BinaryQuantizer(2)
-
-    def __call__(self, data):
-        x = tf.convert_to_tensor(data)
-        y = self.quantizer_fn(x).numpy()
-        return self.binary_quantizer(y)
-
-
-class QKerasPO2Quantizer:
-    def __init__(self, config):
-        self.bits = config['config']['bits']
-        self.quantizer_fn = get_quantizer(config)
-        self.hls_type = ExponentPrecisionType(width=self.bits, signed=True)
-
-    def __call__(self, data):
-        '''
-        Weights are quantized to nearest power of two
-        '''
-        x = tf.convert_to_tensor(data)
-        y = self.quantizer_fn(x)
-        if hasattr(y, 'numpy'):
-            y = y.numpy()
-        return y
-
-
-def get_type(quantizer_config):
-    width = quantizer_config['config']['bits']
-    integer = quantizer_config['config'].get('integer', 0)
-    if quantizer_config['class_name'] == 'quantized_po2':
-        return ExponentPrecisionType(width=width, signed=True)
-    if width == integer:
-        if width == 1:
-            return XnorPrecisionType()
-        else:
-            return IntegerPrecisionType(width=width, signed=True)
-    else:
-        return FixedPrecisionType(width=width, integer=integer + 1, signed=True)
+from hls4ml.converters.keras.convolution import parse_conv1d_layer, parse_conv2d_layer
+from hls4ml.converters.keras.core import parse_batchnorm_layer, parse_dense_layer
+from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer
+from hls4ml.model.types import FixedPrecisionType, QKerasBinaryQuantizer, QKerasPO2Quantizer, QKerasQuantizer
 
 
 def get_quantizer_from_config(keras_layer, quantizer_var):
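The deleted `get_type` helper moves to `hls4ml.model.types` together with the quantizer classes; it is what maps a `quantized_bits` config onto an hls4ml precision type. A small worked example of that mapping on a hypothetical 6-bit config (re-created here from the logic shown above, not a call into the relocated helper):

```python
from hls4ml.model.types import FixedPrecisionType

# quantized_bits(bits=6, integer=0): width != integer, so get_type yields a signed
# fixed-point type with one extra integer bit for the sign.
cfg = {'class_name': 'quantized_bits', 'config': {'bits': 6, 'integer': 0}}
hls_type = FixedPrecisionType(width=cfg['config']['bits'], integer=cfg['config']['integer'] + 1, signed=True)
print(hls_type)  # a 6-bit signed fixed-point type with 1 integer bit
```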
@@ -88,3 +16,130 @@ def get_quantizer_from_config(keras_layer, quantizer_var):
         return QKerasPO2Quantizer(quantizer_config)
     else:
         return QKerasQuantizer(quantizer_config)
+
+
+@keras_handler('QDense')
+def parse_qdense_layer(keras_layer, input_names, input_shapes, data_reader):
+
+    layer, output_shape = parse_dense_layer(keras_layer, input_names, input_shapes, data_reader)
+
+    layer['weight_quantizer'] = get_quantizer_from_config(keras_layer, 'kernel')
+    if keras_layer['config']['bias_quantizer'] is not None:
+        layer['bias_quantizer'] = get_quantizer_from_config(keras_layer, 'bias')
+    else:
+        layer['bias_quantizer'] = None
+
+    return layer, output_shape
+
+
+@keras_handler('QConv1D', 'QConv2D')
+def parse_qconv_layer(keras_layer, input_names, input_shapes, data_reader):
+    assert 'QConv' in keras_layer['class_name']
+
+    if '1D' in keras_layer['class_name']:
+        layer, output_shape = parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader)
+    elif '2D' in keras_layer['class_name']:
+        layer, output_shape = parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader)
+
+    layer['weight_quantizer'] = get_quantizer_from_config(keras_layer, 'kernel')
+    if keras_layer['config']['bias_quantizer'] is not None:
+        layer['bias_quantizer'] = get_quantizer_from_config(keras_layer, 'bias')
+    else:
+        layer['bias_quantizer'] = None
+
+    return layer, output_shape
+
+
+@keras_handler('QActivation')
+def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader):
+    assert keras_layer['class_name'] == 'QActivation'
+    supported_activations = [
+        'quantized_relu',
+        'quantized_tanh',
+        'binary_tanh',
+        'ternary_tanh',
+        'quantized_sigmoid',
+        'quantized_bits',
+        'binary',
+        'ternary',
+    ]
+
+    layer = parse_default_keras_layer(keras_layer, input_names)
+
+    activation_config = keras_layer['config']['activation']
+    quantizer_obj = get_quantizer(activation_config)
+    activation_config = {}
+    # some activations are classes
+    if hasattr(quantizer_obj, 'get_config'):
+        activation_config['class_name'] = quantizer_obj.__class__.__name__
+        if activation_config['class_name'] == 'ternary' or activation_config['class_name'] == 'binary':
+            activation_config['class_name'] += '_tanh'
+        activation_config['config'] = quantizer_obj.get_config()
+    # some activation quantizers are just functions with no config
+    else:
+        activation_config['config'] = {}
+        if 'binary' in quantizer_obj.__name__:
+            activation_config['class_name'] = 'binary_tanh'
+            activation_config['config']['bits'] = 1
+            activation_config['config']['integer'] = 1
+        elif 'ternary' in quantizer_obj.__name__:
+            activation_config['class_name'] = 'ternary_tanh'
+            activation_config['config']['bits'] = 2
+            activation_config['config']['integer'] = 2
+        else:
+            activation_config['class_name'] = 'unknown'
+
+    if activation_config['class_name'] not in supported_activations:
+        raise Exception('Unsupported QKeras activation: {}'.format(activation_config['class_name']))
+
+    if activation_config['class_name'] == 'quantized_bits':
+        activation_config['class_name'] = 'linear'
+
+    if activation_config['class_name'] == 'ternary_tanh':
+        layer['class_name'] = 'TernaryTanh'
+        layer['threshold'] = activation_config.get('config', {}).get('threshold', 0.33)
+        if layer['threshold'] is None:
+            layer['threshold'] = 0.33  # the default ternary tanh threshold for QKeras
+        layer['activation'] = 'ternary_tanh'
+    elif (
+        activation_config['class_name'] == 'quantized_sigmoid'
+        and not activation_config['config'].get('use_real_sigmoid', False)
+    ) or (
+        activation_config['class_name'] == 'quantized_tanh' and not activation_config['config'].get('use_real_tanh', False)
+    ):
+        layer['class_name'] = 'HardActivation'
+        layer['slope'] = 0.5  # the default values in QKeras
+        layer['shift'] = 0.5
+        # Quartus seems to have trouble if the width is 1.
+        layer['slope_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
+        layer['shift_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
+        layer['activation'] = activation_config['class_name'].replace('quantized_', 'hard_')
+    else:
+        layer['class_name'] = 'Activation'
+        layer['activation'] = activation_config['class_name'].replace('quantized_', '')
+
+    layer['activation_quantizer'] = activation_config
+    return layer, [shape for shape in input_shapes[0]]
+
+
+@keras_handler('QBatchNormalization')
+def parse_qbatchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
+
+    layer, output_shape = parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader)
+
+    layer['mean_quantizer'] = get_quantizer_from_config(keras_layer, 'mean')
+    layer['variance_quantizer'] = get_quantizer_from_config(keras_layer, 'variance')
+    layer['beta_quantizer'] = get_quantizer_from_config(keras_layer, 'beta')
+    layer['gamma_quantizer'] = get_quantizer_from_config(keras_layer, 'gamma')
+
+    return layer, output_shape
+
+
+@keras_handler('QConv2DBatchnorm')
+def parse_qconv2dbatchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
+    intermediate_shape = list()
+    conv_layer, shape_qconv = parse_qconv_layer(keras_layer, input_names, input_shapes, data_reader)
+    intermediate_shape.append(shape_qconv)
+    temp_shape = intermediate_shape
+    batch_layer, out_shape = parse_batchnorm_layer(keras_layer, input_names, temp_shape, data_reader)
+    return {**conv_layer, **batch_layer}, out_shape
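The QKeras layer handlers consolidated here all sit on top of `get_quantizer_from_config`. As a rough usage sketch, a hypothetical QDense config (only the fields the code above reads are filled in; the layout follows QKeras's `get_config()` output) would be turned into a weight quantizer like this, assuming the helper looks up the `<var>_quantizer` entry of the layer config as elsewhere in hls4ml:

```python
from hls4ml.converters.keras.qkeras import get_quantizer_from_config

# Hypothetical QDense layer config for illustration only.
keras_layer = {
    'class_name': 'QDense',
    'config': {
        'name': 'fc1',
        'units': 16,
        'kernel_quantizer': {'class_name': 'quantized_bits', 'config': {'bits': 6, 'integer': 0, 'alpha': 1}},
        'bias_quantizer': None,
    },
}

quantizer = get_quantizer_from_config(keras_layer, 'kernel')
print(type(quantizer).__name__, quantizer.bits)  # QKerasQuantizer 6
```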