Skip to content

QKeras bad predictions #437

Closed
Closed
@HenningCode

Description

@HenningCode

Hello Guys,
with the fix from yesterday I tried the "qkeras_mnist_dense" example model, but the predictions are pretty bad.

I then created my own QKeras model and tried it for a short amount of time on the MNIST dataset, just to see if that is also the case with my own model, and the results were pretty bad again. Am I doing something wrong?

This is the code to test the example model (for this I used the latest branch from GitHub):

# Fetch the hls4ml example configuration 'qkeras_mnist_dense.json', print it,
# and convert it into an HLS model.
config = hls4ml.utils.fetch_example_model('qkeras_mnist_dense.json')
print_dict(config)
hls_model = hls4ml.converters.keras_to_hls(config)


# Load MNIST. Both splits are preprocessed below, but only the test split is
# used for the prediction at the end of this snippet.
(x_train, y_train), (x_test, y_test) = mnist.load_data()


RESHAPED = 28*28  # length of each flattened input vector
NB_CLASSES = 10  # number of classes passed to to_categorical

x_train = x_train.astype("float32")
x_test = x_test.astype("float32")

# Flatten every sample into a vector of RESHAPED values.
x_train = x_train.reshape(x_train.shape[0],RESHAPED)
x_test = x_test.reshape(x_test.shape[0],RESHAPED)

# Input scaling: the example was tested both with and without these two lines.
# NOTE(review): presumably the pixel values are 0..255, in which case 255 is
# the more common divisor -- confirm which scaling the example model expects.
x_train /= 256 
x_test /= 256


# Convert the labels to categorical vectors with NB_CLASSES entries
# (presumably one-hot, per Keras `to_categorical` -- confirm).
y_train = to_categorical(y_train, NB_CLASSES)
y_test = to_categorical(y_test, NB_CLASSES)


# Compile the HLS model and print its predictions for the first five test
# samples next to the corresponding labels.
hls_model.compile()
print('Ground truth\n',y_test[0:5])
print('HLS Predict:\n',hls_model.predict(x_test[0:5]))

This is the script I am using to test the QKeras model. If I change the single QDense layer to a normal Dense layer, the predictions of the HLS model are pretty close to the predictions of the Keras model. Why is that the case? (For this I was using the pip installation.)

def count_errors(x,y):
    """Count the positions at which ``x`` and ``y`` hold different entries.

    For every index ``pos`` in ``range(len(x))`` the entries ``x[pos]`` and
    ``y[pos]`` are compared with ``np.array_equal``; each pair that is not
    equal contributes one to the result. Entries of ``y`` beyond ``len(x)``
    are never looked at.

    Returns:
        The number of mismatching positions as an ``int``.
    """
    # Summing booleans yields the number of True values, i.e. the mismatches.
    return sum(
        not np.array_equal(x[pos], y[pos])
        for pos in range(len(x))
    )
        

# Visual separator in the console output.
print("============================================================================\n\n\n")

# Hyper-parameters used by the model.compile / model.fit calls further down.
NB_EPOCH = 2  # passed as `epochs` to model.fit
BATCH_SIZE = 64  # passed as `batch_size` to model.fit
VERBOSE = 1  # passed as `verbose` to model.fit
NB_CLASSES = 10  # width of the final Dense layer and of the categorical labels
# NOTE(review): the `decay` keyword may not be accepted by every Keras
# optimizer version -- confirm against the installed Keras release.
OPTIMIZER = Adam(learning_rate=0.0001, decay=0.000025)
VALIDATION_SPLIT = 0.1  # passed as `validation_split` to model.fit
BUILDING = 0  # not referenced anywhere in the visible script

# Load the MNIST train and test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

RESHAPED = 784  # length of each flattened input vector

# Reference to the test data before conversion; not used again in the visible
# script.
x_test_orig = x_test

x_train = x_train.astype("float32")
x_test = x_test.astype("float32")


# Flatten every sample into a vector of RESHAPED values.
x_train = x_train.reshape(x_train.shape[0], RESHAPED)
x_test = x_test.reshape(x_test.shape[0], RESHAPED)

# Input scaling. Only tested with this scaling enabled: without it the
# accuracy is bad even when no QDense layer is used.
# NOTE(review): presumably the pixel values are 0..255, in which case 255 is
# the more common divisor -- confirm.
x_train /= 256
x_test /= 256

print('Train shape: ', x_train.shape)
print('Test shape: ', x_test.shape)

# Convert the labels to categorical vectors with NB_CLASSES entries
# (presumably one-hot, per Keras `to_categorical` -- confirm).
y_train = to_categorical(y_train, NB_CLASSES)
y_test = to_categorical(y_test, NB_CLASSES)


# Network: input -> QDense(64) -> ReLU -> Dense(NB_CLASSES) -> softmax.
x = x_in = Input((RESHAPED,), name="input")
# Alternative first layer: a plain (non-quantized) Dense layer with the same
# width and name, kept here so it can be swapped in for the QDense layer.
#x = Dense(64,name="dense0")(x)
# Quantized first layer; kernel and bias both use quantized_bits(16,6).
# NOTE(review): presumably 16 total bits with 6 integer bits -- confirm the
# exact meaning (sign-bit handling) against the QKeras documentation.
x = QDense(64,kernel_quantizer=quantized_bits(16,6),
        bias_quantizer=quantized_bits(16,6),name="dense0")(x)
x = Activation("relu", name="act0")(x)
x = Dense(NB_CLASSES,name="dense2")(x)
x = Activation("softmax", name="softmax")(x)

model = Model(inputs=[x_in], outputs=[x])
model.summary()
model.compile(
    loss="categorical_crossentropy", optimizer=OPTIMIZER, metrics=["accuracy"])

# NOTE(review): per the Keras `Model.fit` documentation, `epochs` is the index
# of the final epoch when combined with `initial_epoch`. With epochs=NB_EPOCH
# (2) and initial_epoch=1 this presumably trains for a single epoch only --
# confirm whether that is intended.
history = model.fit(
    x_train, y_train, batch_size=BATCH_SIZE,
    epochs=NB_EPOCH, initial_epoch=1, verbose=VERBOSE,
    validation_split=VALIDATION_SPLIT)

# Derive an hls4ml configuration from the trained Keras model with
# granularity 'name', then override the model-level strategy to 'Resource'.
config = hls4ml.utils.config_from_keras_model(model,granularity='name')
config['Model']['Strategy'] = 'Resource'
print_dict(config)
# Convert the Keras model into an HLS model using the configuration above,
# the given output directory, and the given FPGA part.
hls_model = hls4ml.converters.convert_from_keras_model(model,
                                            hls_config=config,
                                            output_dir='../output/model_std/hls4ml_prj',
                                            fpga_part='xc7z020clg400-1')
# The return value of compile() is deliberately discarded.
_ = hls_model.compile()

TEST_CASES = 5  # number of leading test samples to compare

# Keras predictions, plus a copy in which each row has a 1 at its argmax
# position and 0 everywhere else.
out_model = model.predict(x_test[0:TEST_CASES])
out_model_change = np.zeros_like(out_model)
out_model_change[np.arange(len(out_model)), out_model.argmax(1)] = 1

print("Output of Normal Model:\n", out_model)

# Same procedure for the HLS model predictions.
out_hls = hls_model.predict(x_test[0:TEST_CASES])
out_hls_change = np.zeros_like(out_hls)
out_hls_change[np.arange(len(out_hls)), out_hls.argmax(1)] = 1

print("Output of HLS Model:\n", out_hls)

# Count the rows whose argmax-encoded prediction differs from the label row.
print('Error Normal: ', count_errors(out_model_change,y_test[0:TEST_CASES]))
print('Error HLS: ', count_errors(out_hls_change,y_test[0:TEST_CASES]))

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions