Skip to content

Commit f652cb9

Browse files
address review comments
1 parent d544a0b commit f652cb9

File tree

11 files changed

+53
-130
lines changed

11 files changed

+53
-130
lines changed

guides/distributed_training_with_jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from jax.sharding import Mesh
5454
from jax.sharding import NamedSharding
5555
from jax.sharding import PartitionSpec as P
56-
from keras.src.backend.config import is_nnx_backend_enabled
56+
from keras.src.backend.config import is_nnx_enabled
5757
from keras.src.utils.jax_utils import jit
5858
from flax import nnx
5959

@@ -189,7 +189,7 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y):
189189

190190

191191
# Training step, Keras provides a pure functional optimizer.stateless_apply
192-
@jit()
192+
@jit
193193
def train_step(train_state, x, y):
194194
(
195195
trainable_variables,

keras/api/_tf_keras/keras/config/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
from keras.src.backend.config import (
1818
is_flash_attention_enabled as is_flash_attention_enabled,
1919
)
20-
from keras.src.backend.config import (
21-
is_nnx_backend_enabled as is_nnx_backend_enabled,
22-
)
20+
from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled
2321
from keras.src.backend.config import max_epochs as max_epochs
2422
from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch
2523
from keras.src.backend.config import set_epsilon as set_epsilon

keras/api/config/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
from keras.src.backend.config import (
1818
is_flash_attention_enabled as is_flash_attention_enabled,
1919
)
20-
from keras.src.backend.config import (
21-
is_nnx_backend_enabled as is_nnx_backend_enabled,
22-
)
20+
from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled
2321
from keras.src.backend.config import max_epochs as max_epochs
2422
from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch
2523
from keras.src.backend.config import set_epsilon as set_epsilon

keras/src/backend/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@ def is_flash_attention_enabled():
233233
return global_state.get_global_attribute("flash_attention", default=None)
234234

235235

236-
@keras_export("keras.config.is_nnx_backend_enabled")
237-
def is_nnx_backend_enabled():
236+
@keras_export("keras.config.is_nnx_enabled")
237+
def is_nnx_enabled():
238238
"""Checks whether NNX specific features are enabled for the JAX backend.
239239
240240
Returns:

keras/src/backend/jax/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from keras.src.backend.config import is_nnx_backend_enabled
1+
from keras.src.backend.config import is_nnx_enabled
22
from keras.src.backend.jax import core
33
from keras.src.backend.jax import distribution_lib
44
from keras.src.backend.jax import image
@@ -11,11 +11,7 @@
1111
from keras.src.backend.jax.core import IS_THREAD_SAFE
1212
from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS
1313
from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS
14-
15-
if is_nnx_backend_enabled():
16-
from keras.src.backend.jax.core import NnxVariable as Variable
17-
else:
18-
from keras.src.backend.jax.core import JaxVariable as Variable
14+
from keras.src.backend.jax.core import Variable as Variable
1915
from keras.src.backend.jax.core import cast
2016
from keras.src.backend.jax.core import compute_output_spec
2117
from keras.src.backend.jax.core import cond

keras/src/backend/jax/core.py

Lines changed: 33 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ def __jax_array__(self):
5858
return self.value
5959

6060

61-
_JAX_VARIABLE_TYPE = JaxVariable
62-
if config.is_nnx_backend_enabled():
61+
Variable = JaxVariable
62+
if config.is_nnx_enabled():
6363
from flax import nnx
6464

65-
class NnxVariable(KerasVariable, nnx.Variable):
65+
class NnxVariable(JaxVariable, nnx.Variable):
6666
def __init__(
6767
self,
6868
initializer,
@@ -77,43 +77,12 @@ def __init__(
7777
mutable=None,
7878
**nnx_metadata,
7979
):
80-
# Determine NNX mutability. This needs to be known for
81-
# nnx.Variable.__init__.
82-
if mutable is None:
83-
actual_nnx_mutable = (
84-
trainable # Keras 'trainable' maps to NNX 'mutable'
85-
)
86-
else:
87-
actual_nnx_mutable = mutable
88-
89-
# Ensure 'mutable' is in nnx_metadata, but explicit 'mutable'
90-
# param takes precedence.
91-
if "mutable" in nnx_metadata and mutable is not None:
92-
nnx_metadata["mutable"] = actual_nnx_mutable
93-
elif "mutable" not in nnx_metadata:
94-
nnx_metadata["mutable"] = actual_nnx_mutable
95-
80+
nnx_metadata["mutable"] = trainable if mutable is None else mutable
9681
# Initialize nnx.Variable first.
9782
# Determine the dtype for the placeholder.
98-
_placeholder_value = None
99-
if shape is not None:
100-
if dtype is not None:
101-
_placeholder_value = jnp.zeros(
102-
shape, dtype=standardize_dtype(dtype)
103-
)
104-
else:
105-
_placeholder_value = jnp.zeros(
106-
shape, dtype=standardize_dtype(config.floatx())
107-
)
108-
else:
109-
if dtype is not None:
110-
_placeholder_value = jnp.array(
111-
0.0, dtype=standardize_dtype(dtype)
112-
)
113-
else:
114-
_placeholder_value = jnp.array(
115-
0.0, dtype=standardize_dtype(config.floatx())
116-
)
83+
_placeholder_value = jnp.zeros(
84+
shape or (), dtype=standardize_dtype(dtype)
85+
)
11786

11887
# Call nnx.Variable.__init__ directly.
11988
nnx.Variable.__init__(
@@ -152,10 +121,10 @@ def __getstate__(self):
152121
# Get the state from KerasVariable (attributes in __dict__)
153122
# KerasVariable does not have a custom __getstate__, so we mimic
154123
# default behavior.
155-
keras_state = self.__dict__.copy()
124+
keras_state = KerasVariable.__getstate__(self)
156125

157126
# Get the state from nnx.Variable
158-
nnx_specific_state = super(KerasVariable, self).__getstate__()
127+
nnx_specific_state = nnx.Variable.__getstate__(self)
159128

160129
# Merge them. Keras state is primary. NNX specific state adds
161130
# to it.
@@ -170,10 +139,6 @@ def __getstate__(self):
170139
"_var_metadata"
171140
]
172141

173-
# Remove elements that might be problematic or redundant if
174-
# nnx.Variable's __getstate__
175-
keras_state.pop("raw_value", None)
176-
177142
return keras_state
178143

179144
def __setstate__(self, state):
@@ -202,38 +167,20 @@ def __setstate__(self, state):
202167

203168
# Ensure Keras's self._value is also consistent with the
204169
# restored raw_value
205-
object.__setattr__(self, "_value", nnx_raw_value)
170+
self._value = nnx_raw_value
206171

207172
if hasattr(self, "_shape") and self._shape is not None:
208173
self._ndim = len(self._shape)
209174
else:
210175
# Fallback if shape isn't immediately available.
211176
self._ndim = len(self.raw_value.shape)
212177

213-
def _initialize(self, value):
214-
# Note that variable.shape is needed by distribution_lib
215-
self._shape = self._validate_shape(value.shape)
216-
# We can't import the keras/distribution/distribution_lib
217-
# due to circular dependency.
218-
distribution = global_state.get_global_attribute("distribution")
219-
if self._layout is None and distribution is not None:
220-
tensor_layout = distribution.get_variable_layout(self)
221-
from keras.src.distribution import TensorLayout
222-
223-
if isinstance(tensor_layout, TensorLayout):
224-
self._layout = tensor_layout.backend_layout
225-
else:
226-
self._layout = tensor_layout
227-
self._direct_assign(value)
228-
229178
def _direct_assign(self, value):
230179
# Apply JAX-specific distribution if layout is present
231180
if self._layout is not None:
232-
processed_value = distribution_lib.distribute_variable(
181+
value = distribution_lib.distribute_variable(
233182
value, self._layout
234183
)
235-
else:
236-
processed_value = value
237184

238185
# Ensure that nnx.Variable part is initialized
239186
if not hasattr(self, "_var_metadata"):
@@ -245,48 +192,31 @@ def _direct_assign(self, value):
245192
hasattr(self, "_var_metadata")
246193
and "on_set_value" in self._var_metadata
247194
):
248-
final_value = self._var_metadata["on_set_value"](
249-
self, processed_value
250-
)
251-
else:
252-
final_value = processed_value
253-
254-
# Directly set raw_value. nnx.Variable handles mutable array
255-
# updates
256-
object.__setattr__(self, "raw_value", final_value)
257-
258-
def _convert_to_tensor(self, value, dtype=None):
259-
return convert_to_tensor(value, dtype=dtype, sparse=False)
260-
261-
# Overload native accessor.
262-
def __jax_array__(self):
263-
return self.value
195+
value = self._var_metadata["on_set_value"](self, value)
264196

265197
@property
266198
def value(self):
267199
if not hasattr(self, "raw_value"):
268-
if not hasattr(self, "_value") or self._value is None:
269-
if self._initializer is not None:
270-
initial_value = self._initializer(
271-
self._shape, dtype=self._dtype
272-
)
273-
return self._maybe_autocast(initial_value)
274-
else:
275-
raise AttributeError(
276-
"Variable is not properly initialized and has"
277-
" no initializer."
278-
)
279-
current_value = self._value
280-
else:
281-
current_value = self.raw_value
282-
if (
283-
hasattr(self, "_var_metadata")
284-
and "on_get_value" in self._var_metadata
285-
):
286-
current_value = self._var_metadata["on_get_value"](
287-
self, current_value
200+
if self._initializer is not None:
201+
self._initialize(
202+
self._initializer(self.shape, dtype=self.dtype)
203+
)
204+
else:
205+
# This implies nnx.Variable didn't set placeholder or init failed.
206+
raise AttributeError(
207+
"Variable is not properly initialized (raw_value missing) "
208+
"and has no initializer."
288209
)
210+
# Now, self.raw_value must exist. It's the source of truth.
211+
current_value = self.raw_value
289212

213+
if (
214+
hasattr(self, "_var_metadata")
215+
and "on_get_value" in self._var_metadata
216+
):
217+
current_value = self._var_metadata["on_get_value"](
218+
self, current_value
219+
)
290220
if in_stateless_scope():
291221
scope = get_stateless_scope()
292222
stateless_value = scope.get_current_value(self)
@@ -298,7 +228,7 @@ def value(self):
298228
def __hash__(self):
299229
return id(self)
300230

301-
_JAX_VARIABLE_TYPE = NnxVariable
231+
Variable = NnxVariable
302232

303233

304234
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
@@ -314,7 +244,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
314244
# an existing distributed jax array will raise error.
315245
return x
316246

317-
if isinstance(x, _JAX_VARIABLE_TYPE):
247+
if isinstance(x, Variable):
318248
if dtype is not None and x.dtype != dtype:
319249
return x.value.astype(dtype)
320250
return x.value
@@ -598,7 +528,7 @@ def fori_loop(lower, upper, body_fun, init_val):
598528

599529

600530
def stop_gradient(variable):
601-
if isinstance(variable, _JAX_VARIABLE_TYPE):
531+
if isinstance(variable, Variable):
602532
variable = variable.value
603533
return jax.lax.stop_gradient(variable)
604534

keras/src/backend/jax/core_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import keras
99
from keras.src import backend
1010
from keras.src import testing
11-
from keras.src.backend.config import is_nnx_backend_enabled
11+
from keras.src.backend.config import is_nnx_enabled
1212

13-
if is_nnx_backend_enabled():
13+
if is_nnx_enabled():
1414
from flax import nnx
1515

1616
from keras.src.backend.jax.core import NnxVariable
@@ -21,7 +21,7 @@
2121
reason="JAX backend specific test for core Variable integration with NNX.",
2222
)
2323
@pytest.mark.skipif(
24-
not is_nnx_backend_enabled(),
24+
not is_nnx_enabled(),
2525
reason="Test requires NNX backend to be enabled by default for setup.",
2626
)
2727
class JaxCoreVariableTest(testing.TestCase):

keras/src/backend/jax/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def concatenate(outputs):
234234
return output
235235

236236
if not self.run_eagerly and self.jit_compile:
237-
concatenate = jit()(concatenate)
237+
concatenate = jit(concatenate)
238238

239239
def iterator_step(state, iterator):
240240
data = next(iterator)

keras/src/layers/layer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from keras.src.backend.common.name_scope import current_path
3939
from keras.src.backend.common.remat import get_current_remat_mode
4040
from keras.src.backend.common.symbolic_scope import in_symbolic_scope
41-
from keras.src.backend.config import is_nnx_backend_enabled
41+
from keras.src.backend.config import is_nnx_enabled
4242
from keras.src.distribution import distribution_lib
4343
from keras.src.dtype_policies import DTypePolicyMap
4444
from keras.src.layers import input_spec
@@ -54,7 +54,7 @@
5454
if backend.backend() == "tensorflow":
5555
from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer
5656
elif backend.backend() == "jax":
57-
if is_nnx_backend_enabled():
57+
if is_nnx_enabled():
5858
from keras.src.backend.jax.layer import NnxLayer as BackendLayer
5959
else:
6060
from keras.src.backend.jax.layer import JaxLayer as BackendLayer
@@ -1543,10 +1543,11 @@ def __setattr__(self, name, value):
15431543
# NNX-specific bypass for `_called` and `built` attributes
15441544
if (
15451545
backend.backend() == "jax"
1546-
and is_nnx_backend_enabled()
1546+
and is_nnx_enabled()
15471547
and (name == "_called" or name == "built")
15481548
):
15491549
object.__setattr__(self, name, value)
1550+
self._parent_path = current_path()
15501551
return
15511552

15521553
super().__setattr__(

keras/src/ops/operation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from keras.src import tree
77
from keras.src.api_export import keras_export
88
from keras.src.backend.common.keras_tensor import any_symbolic_tensors
9-
from keras.src.backend.config import is_nnx_backend_enabled
9+
from keras.src.backend.config import is_nnx_enabled
1010
from keras.src.ops.node import Node
1111
from keras.src.utils import python_utils
1212
from keras.src.utils import traceback_utils
@@ -122,7 +122,7 @@ def __new__(cls, *args, **kwargs):
122122
to manually implement `get_config()`.
123123
"""
124124
instance = super(Operation, cls).__new__(cls)
125-
if backend.backend() == "jax" and is_nnx_backend_enabled():
125+
if backend.backend() == "jax" and is_nnx_enabled():
126126
from flax import nnx
127127

128128
vars(instance)["_object__state"] = nnx.object.ObjectState()

keras/src/utils/jax_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from keras.src import backend
2-
from keras.src.backend.config import is_nnx_backend_enabled
2+
from keras.src.backend.config import is_nnx_enabled
33

44

55
def is_in_jax_tracing_scope(x=None):
@@ -14,7 +14,7 @@ def is_in_jax_tracing_scope(x=None):
1414

1515
def jit(*args, **kwargs):
1616
def decorator(func):
17-
if is_nnx_backend_enabled():
17+
if is_nnx_enabled():
1818
from flax import nnx
1919

2020
return nnx.jit(func, *args, **kwargs)

0 commit comments

Comments
 (0)