Skip to content

Commit 99d4307

Browse files
fix test
1 parent c7b2347 commit 99d4307

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

keras/src/ops/operation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import inspect
22
import textwrap
33

4-
from flax import nnx
5-
64
from keras.src import backend
75
from keras.src import dtype_policies
86
from keras.src import tree
@@ -125,6 +123,8 @@ def __new__(cls, *args, **kwargs):
125123
"""
126124
instance = super(Operation, cls).__new__(cls)
127125
if backend.backend() == "jax" and is_nnx_backend_enabled():
126+
from flax import nnx
127+
128128
vars(instance)["_object__state"] = nnx.object.ObjectState()
129129
# Generate a config to be returned by default by `get_config()`.
130130
arg_names = inspect.getfullargspec(cls.__init__).args

0 commit comments

Comments
 (0)