Skip to content
This repository was archived by the owner on Jun 23, 2025. It is now read-only.

Making keras-contrib compatible with tf.keras #387

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
07c17fe
Changed travis
gabrieldemarmiesse Dec 30, 2018
f80d8e0
Added the script to convert
gabrieldemarmiesse Dec 30, 2018
a70847d
fix indentation.
gabrieldemarmiesse Dec 30, 2018
97219b5
Trying fewer changes.
gabrieldemarmiesse Dec 30, 2018
746f9a2
Added the ignore in pytest.
gabrieldemarmiesse Dec 30, 2018
19b0734
fix order of replacement.
gabrieldemarmiesse Dec 30, 2018
a09fabc
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Dec 30, 2018
67924a3
Removed the hacky replace.
gabrieldemarmiesse Dec 30, 2018
49619b0
Used the normal backend for normalize data format.
gabrieldemarmiesse Dec 30, 2018
c251865
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Dec 30, 2018
9724487
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Dec 30, 2018
752f1ef
Making really sure we don't get keras_contrib.
gabrieldemarmiesse Dec 30, 2018
7025a6d
Ignore backend tests for tf.keras.
gabrieldemarmiesse Dec 30, 2018
cc68fa6
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Dec 30, 2018
564af63
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Dec 30, 2018
33058bb
fixed some imports and ignored test for tf_keras.
gabrieldemarmiesse Dec 30, 2018
ee99fa9
used pip install to install tf-keras.
gabrieldemarmiesse Dec 30, 2018
3d167a3
Maybe the option is not at the right place.
gabrieldemarmiesse Dec 30, 2018
dbe92e7
Trying a separate file.
gabrieldemarmiesse Dec 30, 2018
ae27d1b
import the contrib backend.
gabrieldemarmiesse Dec 30, 2018
79bb6d8
Moved to the setup.py
gabrieldemarmiesse Dec 30, 2018
98b507d
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Dec 30, 2018
4ca208a
Let's not use import *
gabrieldemarmiesse Dec 31, 2018
9b6b0dd
Added compute_output_shape.
gabrieldemarmiesse Dec 31, 2018
82aa1dc
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Dec 31, 2018
cbbf06d
Used a more recent version of tf.
gabrieldemarmiesse Dec 31, 2018
9e380e8
Used custom layer to ensure it works with tf.keras.
gabrieldemarmiesse Dec 31, 2018
a0f86a5
Fixing return forgotten.
gabrieldemarmiesse Dec 31, 2018
8a6b5cf
Changed line endings to unix.
gabrieldemarmiesse Dec 31, 2018
543a744
moving to a separate file.
gabrieldemarmiesse Dec 31, 2018
a10f4f0
Added verbose option.
gabrieldemarmiesse Dec 31, 2018
7f28ce0
Made a more robust check.
gabrieldemarmiesse Dec 31, 2018
612c2a3
hopefully all bug fixed.
gabrieldemarmiesse Dec 31, 2018
8a3ffe3
Let's not use the setup.py to convert to tf.keras.
gabrieldemarmiesse Dec 31, 2018
33b8150
Hacked my way through
gabrieldemarmiesse Dec 31, 2018
abafa3a
fixed stupid mistake.
gabrieldemarmiesse Dec 31, 2018
9dbe097
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Jan 10, 2019
9c623f3
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Jan 10, 2019
cb6131f
Some pep8 fixes.
gabrieldemarmiesse Jan 10, 2019
516cc14
Fixed imports in tensorboard.
gabrieldemarmiesse Jan 10, 2019
7427849
Fixed import again.
gabrieldemarmiesse Jan 10, 2019
114be91
Added some xfails.
gabrieldemarmiesse Jan 10, 2019
ac931af
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Jan 10, 2019
fc2dd4d
Fixed the padam error.
gabrieldemarmiesse Jan 10, 2019
4ec19d3
Hopefully fixed tensorboard
gabrieldemarmiesse Jan 10, 2019
edcc433
Adding an xfail.
gabrieldemarmiesse Jan 10, 2019
0cf37de
Added some xfail.
gabrieldemarmiesse Jan 10, 2019
79a3490
Fixed pep8 and removed xfail.
Jan 14, 2019
75db9e9
Merge branch 'master' into tf_keras_again
Jan 14, 2019
f4203ae
Used skipif.
Jan 14, 2019
e04ab3f
Removed useless diff.
Jan 14, 2019
99fb7ba
Merge branch 'master' into tf_keras_again
Jan 18, 2019
a2152c5
Merge branch 'master' into tf_keras_again
Jan 18, 2019
f8058c1
Simplified the base layer.
Jan 18, 2019
bbe1641
Added the back and forth conversion for tf.keras.
Jan 18, 2019
d1a9aaf
Removed Useless function in base_layer.
Jan 18, 2019
2fa39d3
Simplified the hack to work with tensorshapes.
Jan 18, 2019
4e33f67
Fix pep8
Jan 18, 2019
4064e11
Removed small diff.
Jan 18, 2019
0e598b0
Put the script in the setup.py.
Jan 18, 2019
8eec1f8
Removed the import changes from the setup.py.
Jan 18, 2019
b262f9f
Clarified the docstring of `to_tuple`.
Jan 18, 2019
9d02f74
Added install details.
Jan 18, 2019
a15a01b
/bin/bash: q: command not found
gabrieldemarmiesse Jan 20, 2019
95310a2
Typos.
Feb 6, 2019
dd1e3c6
Merge branch 'master' into tf_keras_again
Feb 6, 2019
8757f67
Changed the hard reset to a git stash.
Feb 6, 2019
701f5eb
Adding to_tuple to capsule.py.
Feb 6, 2019
c7505ad
skipping when tf.keras.
Feb 6, 2019
b7c9769
Removing unused imports.
Feb 6, 2019
9fdc858
Merge branch 'master' into tf_keras_again
gabrieldemarmiesse Feb 12, 2019
5e37a62
Merge branch 'tf_keras_again' of github.com:gabrieldemarmiesse/keras-…
gabrieldemarmiesse Feb 12, 2019
7d662fa
Revert some changes.
gabrieldemarmiesse Feb 12, 2019
45318fa
Forgot the to_tuple.
gabrieldemarmiesse Feb 14, 2019
b3366c7
Removed custom objects.
gabrieldemarmiesse Feb 14, 2019
3d7fb42
Removed git reset.
gabrieldemarmiesse Feb 14, 2019
775020c
Added a test.
gabrieldemarmiesse Feb 14, 2019
e637aab
Fixed typo.
gabrieldemarmiesse Feb 14, 2019
b959b8c
Some fixes here and there.
gabrieldemarmiesse Feb 14, 2019
4472d24
Pep8
gabrieldemarmiesse Feb 14, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ matrix:
env: KERAS_BACKEND=tensorflow
- python: 3.6
env: KERAS_BACKEND=tensorflow
- python: 3.6
env: KERAS_BACKEND=tensorflow USE_TF_KERAS=1 PYTEST_IGNORE='--ignore=tests/test_doc_auto_generation.py --ignore=tests/keras_contrib/backend --ignore=tests/keras_contrib/utils/save_load_utils_test.py'
- python: 3.6
env: KERAS_BACKEND=theano THEANO_FLAGS=optimizer=fast_compile
# - python: 3.6
Expand All @@ -37,7 +39,9 @@ install:
- source activate test-environment

- travis_retry pip install --only-binary=numpy,scipy,pandas numpy nose scipy h5py theano mkdocs pytest pytest-pep8 pandas pygithub --progress-bar off
- pip install git+https://github.com/keras-team/keras.git --progress-bar off
- if [[ "$USE_TF_KERAS" == "" ]]; then
pip install git+https://github.com/keras-team/keras.git --progress-bar off;
fi

# set library path
- export LD_LIBRARY_PATH=$HOME/miniconda/envs/test-environment/lib/:$LD_LIBRARY_PATH
Expand All @@ -46,6 +50,9 @@ install:
travis_retry conda install -q mkl mkl-service;
fi

- if [[ "$USE_TF_KERAS" == "1" ]]; then
python convert_to_tf_keras.py;
fi
- pip install -e .[tests] --progress-bar off

# install TensorFlow (CPU version).
Expand All @@ -61,15 +68,15 @@ install:
script:
- export MKL_THREADING_LAYER="GNU"
# run keras backend init to initialize backend config
- python -c "import keras.backend"
- python -c "import keras_contrib.backend"
# create dataset directory to avoid concurrent directory creation at runtime
- mkdir ~/.keras/datasets
# set up keras backend
- sed -i -e 's/"backend":[[:space:]]*"[^"]*/"backend":\ "'$KERAS_BACKEND'/g' ~/.keras/keras.json;
- echo -e "Running tests with the following config:\n$(cat ~/.keras/keras.json)"
- if [[ "$TEST_MODE" == "PEP8_DOC" ]]; then
PYTHONPATH=$PWD:$PYTHONPATH py.test --pep8 -m pep8 -n0 && py.test tests/tooling/ && cd contrib_docs && python autogen.py && mkdocs build;
PYTHONPATH=$PWD:$PYTHONPATH py.test --pep8 -m pep8 -n0 && py.test tests/tooling/ convert_to_tf_keras.py && cd contrib_docs && python autogen.py && mkdocs build;
else
PYTHONPATH=$PWD:$PYTHONPATH py.test tests/ --ignore=tests/tooling/
PYTHONPATH=$PWD:$PYTHONPATH py.test tests/ $PYTEST_IGNORE --ignore=tests/tooling/
--cov-config .coveragerc --cov=keras_contrib tests/;
fi
40 changes: 40 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,46 @@ We love pull requests. Here's a quick guide:

11. Submit your PR. If your changes have been approved in a previous discussion, and if you have complete (and passing) unit tests, your PR is likely to be merged promptly. Otherwise, well...

## About keras-team/keras and tensorflow.keras

This repo supports both keras-team/keras and tensorflow.keras. The way this is done is by changing all the imports in the code by parsing it. This is checked with travis.ci every time you push a commit in a pull request.

There are a number of reasons why your code would work with keras-team/keras but not with tf.keras. The most common is that you use keras' private API. Since both keras are only similar in behavior with respect to their public API, you should only use this. Otherwise it's likely that the function you are using is not in the same place in tf.keras (or does not even exist at all).

Another gotcha is that when creating custom layers and implementing the `build` function, keras-team/keras expects as `input_shape` a tuple of ints. With tf.keras, `input_shape` is a tuple with `Dimensions` objects. This is likely to make the code incompatible. To solve this problem, you should do:

```python
from keras.layers import Layer
from keras_contrib.utils.test_utils import to_tuple


class MyLayer(Layer):
...

def build(self, input_shape):
input_shape = to_tuple(input_shape)
# now `input_shape` is a tuple of ints or None like in keras-team/keras
...
```

To change all the imports in your code to tf.keras to test compatibility, you can do:
```
python convert_to_tf_keras.py
```

To convert your codebase back to keras-team/keras, do:
```
python convert_to_tf_keras.py --revert
```

Note that you are strongly encouraged to commit your code before in case the parsing would go wrong. To discard all the changes you made since the previous commit:
```
# saves a copy of your current codebase in the git stash and comes back to the previous commit
git stash

git stash pop # get your copy back from the git stash if you need to.
```

## A Note for Contributors

Both Keras-Contrib and Keras operate under the [MIT License](LICENSE). At the discretion of the maintainers of both repositories, code may be moved from Keras-Contrib to Keras and vice versa.
Expand Down
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ As the community contributions in Keras-Contrib are tested, used, validated, and
---
## Installation

#### Install keras_contrib for keras-team/keras
For instructions on how to install Keras,
see [the Keras installation page](https://keras.io/#installation).

Expand All @@ -24,6 +25,25 @@ Alternatively, using pip:
sudo pip install git+https://www.github.com/keras-team/keras-contrib.git
```

to uninstall:
```pip
pip uninstall keras_contrib
```

#### Install keras_contrib for tensorflow.keras

```shell
git clone https://www.github.com/keras-team/keras-contrib.git
cd keras-contrib
python convert_to_tf_keras.py
USE_TF_KERAS=1 python setup.py install
```

to uninstall:
```shell
pip uninstall tf_keras_contrib
```

For contributor guidelines see [CONTRIBUTING.md](https://github.com/keras-team/keras-contrib/blob/master/CONTRIBUTING.md)

---
Expand Down
98 changes: 98 additions & 0 deletions convert_to_tf_keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
import sys

list_conversions = [('import keras.', 'import tensorflow.keras.'),
('import keras ', 'from tensorflow import keras '),
('import keras\n', 'from tensorflow import keras\n'),
('from keras.', 'from tensorflow.keras.'),
('from keras ', 'from tensorflow.keras ')]


def replace_imports_in_text(string, revert):
if revert:
list_imports_to_change = [x[::-1] for x in list_conversions]
else:
list_imports_to_change = list_conversions

text_updated = string
for old_str, new_str in list_imports_to_change:
text_updated = text_updated.replace(old_str, new_str)
return text_updated


def replace_imports_in_file(file_path, revert):
if not file_path.endswith('.py'):
return False
if os.path.abspath(file_path) == os.path.abspath(__file__):
return False
with open(file_path, 'r') as f:
text = f.read()

text_updated = replace_imports_in_text(text, revert)

with open(file_path, 'w+') as f:
f.write(text_updated)

return text_updated != text


def convert_codebase(revert):
nb_of_files_changed = 0
keras_dir = os.path.dirname(os.path.abspath(__file__))
for root, dirs, files in os.walk(keras_dir):
for name in files:
if replace_imports_in_file(os.path.join(root, name), revert):
nb_of_files_changed += 1
print('Changed imports in ' + str(nb_of_files_changed) + ' files.')
print('Those files were found in the directory ' + keras_dir)


def convert_to_tf_keras():
"""Convert the codebase to tf.keras"""
convert_codebase(False)


def convert_to_keras_team_keras():
"""Convert the codebase from tf.keras to keras-team/keras"""
convert_codebase(True)


def test_replace_imports():
python_code = """
import keras
from keras import backend as K
import os
import keras_contrib
import keras_contrib.layers as lay
import keras.layers
from keras.layers import Dense

if K.backend() == 'tensorflow':
import tensorflow as tf
function = tf.max
"""

expected_code = """
from tensorflow import keras
from tensorflow.keras import backend as K
import os
import keras_contrib
import keras_contrib.layers as lay
import tensorflow.keras.layers
from tensorflow.keras.layers import Dense

if K.backend() == 'tensorflow':
import tensorflow as tf
function = tf.max
"""

code_with_replacement = replace_imports_in_text(python_code, False)
assert expected_code == code_with_replacement
assert python_code == replace_imports_in_text(code_with_replacement, True)


if __name__ == '__main__':
if '--revert' in sys.argv:
convert_to_keras_team_keras()
else:
convert_to_tf_keras()
2 changes: 2 additions & 0 deletions keras_contrib/layers/advanced_activations/pelu.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from keras.layers import Layer, InputSpec
from keras import initializers, regularizers, constraints
import keras.backend as K
from keras_contrib.utils.test_utils import to_tuple


class PELU(Layer):
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(self, alpha_initializer='ones',
self.shared_axes = list(shared_axes)

def build(self, input_shape):
input_shape = to_tuple(input_shape)
param_shape = list(input_shape[1:])
self.param_broadcast = [False] * len(param_shape)
if self.shared_axes is not None:
Expand Down
2 changes: 2 additions & 0 deletions keras_contrib/layers/advanced_activations/srelu.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from keras.layers import Layer, InputSpec
from keras import initializers
import keras.backend as K
from keras_contrib.utils.test_utils import to_tuple


class SReLU(Layer):
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(self, t_left_initializer='zeros',
self.shared_axes = list(shared_axes)

def build(self, input_shape):
input_shape = to_tuple(input_shape)
param_shape = list(input_shape[1:])
self.param_broadcast = [False] * len(param_shape)
if self.shared_axes is not None:
Expand Down
2 changes: 2 additions & 0 deletions keras_contrib/layers/capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from keras import initializers
from keras import constraints
from keras.layers import Layer
from keras_contrib.utils.test_utils import to_tuple


class Capsule(Layer):
Expand Down Expand Up @@ -129,6 +130,7 @@ def __init__(self,
self.constraint = constraints.get(constraint)

def build(self, input_shape):
input_shape = to_tuple(input_shape)
input_dim_capsule = input_shape[-1]
if self.share_weights:
self.W = self.add_weight(name='capsule_kernel',
Expand Down
2 changes: 2 additions & 0 deletions keras_contrib/layers/convolutional/cosineconvolution2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from keras.layers import InputSpec
from keras_contrib.utils.conv_utils import conv_output_length
from keras_contrib.utils.conv_utils import normalize_data_format
from keras_contrib.utils.test_utils import to_tuple
import numpy as np


Expand Down Expand Up @@ -128,6 +129,7 @@ def __init__(self, filters, kernel_size,
super(CosineConvolution2D, self).__init__(**kwargs)

def build(self, input_shape):
input_shape = to_tuple(input_shape)
if self.data_format == 'channels_first':
stack_size = input_shape[1]
self.kernel_shape = (self.filters, stack_size, self.nb_row, self.nb_col)
Expand Down
2 changes: 2 additions & 0 deletions keras_contrib/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from keras import constraints
from keras.layers import InputSpec
from keras.layers import Layer
from keras_contrib.utils.test_utils import to_tuple


class CosineDense(Layer):
Expand Down Expand Up @@ -107,6 +108,7 @@ def __init__(self, units, kernel_initializer='glorot_uniform',
super(CosineDense, self).__init__(**kwargs)

def build(self, input_shape):
input_shape = to_tuple(input_shape)
ndim = len(input_shape)
assert ndim >= 2
input_dim = input_shape[-1]
Expand Down
2 changes: 2 additions & 0 deletions keras_contrib/layers/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_marginal_accuracy
from keras_contrib.metrics import crf_viterbi_accuracy
from keras_contrib.utils.test_utils import to_tuple


class CRF(Layer):
Expand Down Expand Up @@ -247,6 +248,7 @@ def __init__(self, units,
self.unroll = unroll

def build(self, input_shape):
input_shape = to_tuple(input_shape)
self.input_spec = [InputSpec(shape=input_shape)]
self.input_dim = input_shape[-1]

Expand Down
1 change: 0 additions & 1 deletion keras_contrib/optimizers/lars.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from keras import backend as K
from keras.optimizers import Optimizer
from keras.utils.generic_utils import get_custom_objects


class LARS(Optimizer):
Expand Down
26 changes: 26 additions & 0 deletions keras_contrib/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from numpy.testing import assert_allclose
import inspect

import keras
from keras.layers import Input
from keras import Model
from keras import backend as K
Expand Down Expand Up @@ -177,3 +178,28 @@ def unpack_singleton(x):
if len(x) == 1:
return x[0]
return x


if keras.__name__ == 'keras':
is_tf_keras = False
elif keras.__name__ == 'tensorflow.keras':
is_tf_keras = True
else:
raise KeyError('Cannot detect if using keras or tf.keras.')


def to_tuple(shape):
"""This functions is here to fix an inconsistency between keras and tf.keras.

In tf.keras, the input_shape argument is an tuple with `Dimensions` objects.
In keras, the input_shape is a simple tuple of ints or `None`.

We'll work with tuples of ints or `None` to be consistent
with keras-team/keras. So we must apply this function to
all input_shapes of the build methods in custom layers.
"""
if is_tf_keras:
import tensorflow as tf
return tuple(tf.TensorShape(shape).as_list())
else:
return shape
12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from setuptools import setup
from setuptools import find_packages
import os


setup(name='keras_contrib',
if os.environ.get('USE_TF_KERAS', None) == '1':
name = 'tf_keras_contrib'
install_requires = []
else:
name = 'keras_contrib'
install_requires = ['keras']

setup(name=name,
version='2.0.8',
description='Keras Deep Learning for Python, Community Contributions',
author='Fariz Rahman',
author_email='[email protected]',
url='https://github.com/farizrahman4u/keras-contrib',
license='MIT',
install_requires=['keras'],
install_requires=install_requires,
extras_require={
'h5py': ['h5py'],
'visualize': ['pydot>=1.2.0'],
Expand Down
Loading