Hi there, @NEGU93. Thanks for the great effort in making this library. It has really accelerated my research on a signal recognition task, and the TF 2.0 version makes it possible to deploy on edge devices with the help of TFLite. However, I found that ComplexBatchNormalization() terribly slows down training. Here is an example to reproduce:
import numpy as np
from tensorflow.keras.models import Model
import tensorflow
from cvnn.layers import ComplexConv1D, ComplexInput, ComplexDense, ComplexBatchNormalization, ComplexFlatten, complex_input
X_train = np.random.rand(18000, 4096, 2)
Y_train = np.random.randint(0, 9, 18000)
X_test = np.random.rand(2000, 4096, 2)
Y_test = np.random.randint(0, 9, 2000)
inputs = complex_input(shape=X_train.shape[1:])
outs = inputs
outs = (ComplexConv1D(16, 6, strides=1, padding='same', activation='cart_relu'))(outs)
outs = (ComplexBatchNormalization())(outs)
outs = (ComplexConv1D(32, 3, strides=1, padding='same', activation='cart_relu'))(outs)
outs = (ComplexBatchNormalization())(outs)
outs = (ComplexFlatten())(outs)
DL_feature = (ComplexDense(128, activation='cart_relu'))(outs)
outs = (ComplexDense(256, activation='cart_relu'))(DL_feature)
outs = (ComplexDense(256, activation='cart_relu'))(outs)
predictions = (ComplexDense(9, activation='cast_to_real'))(outs)  # units were missing in the original snippet; 9 assumed from the 9 label classes
model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer=tensorflow.keras.optimizers.Adam(learning_rate=1e-4),
loss=tensorflow.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
history = model.fit(X_train, Y_train, batch_size=32, epochs=3, verbose=1, validation_data=(X_test, Y_test))
# callbacks=[checkpoint, earlystopping, learn_rate] omitted here: those callback objects are defined outside this snippet
Training one epoch takes almost 10 minutes. But when I replace ComplexBatchNormalization() with the standard BatchNormalization(), one epoch only takes about half a minute. Any ideas?
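In case it helps to reproduce the comparison in isolation, here is a minimal timing sketch (not part of my training code) that measures one forward/backward pass of each normalization layer on a tensor shaped like the first conv output; the tensor shape, iteration count, and dummy loss are assumptions, and the real BatchNormalization is fed only the real part for a rough baseline.
import time
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization
from cvnn.layers import ComplexBatchNormalization

# Dummy complex activations with the shape produced by the first ComplexConv1D above (assumed: batch 32, 4096 steps, 16 channels)
x = tf.complex(tf.random.normal((32, 4096, 16)), tf.random.normal((32, 4096, 16)))

def time_layer(layer, inputs, n=20):
    layer(inputs)  # build / warm up the layer once
    start = time.perf_counter()
    for _ in range(n):
        with tf.GradientTape() as tape:
            y = layer(inputs, training=True)
            loss = tf.reduce_mean(tf.abs(y))  # dummy scalar loss just to get gradients
        tape.gradient(loss, layer.trainable_variables)
    return (time.perf_counter() - start) / n  # average seconds per forward/backward pass

print("ComplexBatchNormalization:", time_layer(ComplexBatchNormalization(), x))
print("BatchNormalization (real baseline):", time_layer(BatchNormalization(), tf.math.real(x)))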