@@ -58,11 +58,11 @@ def __jax_array__(self):
         return self.value


-_JAX_VARIABLE_TYPE = JaxVariable
-if config.is_nnx_backend_enabled():
+Variable = JaxVariable
+if config.is_nnx_enabled():
     from flax import nnx

-    class NnxVariable(KerasVariable, nnx.Variable):
+    class NnxVariable(JaxVariable, nnx.Variable):
         def __init__(
             self,
             initializer,
@@ -77,43 +77,12 @@ def __init__(
             mutable=None,
             **nnx_metadata,
         ):
-            # Determine NNX mutability. This needs to be known for
-            # nnx.Variable.__init__.
-            if mutable is None:
-                actual_nnx_mutable = (
-                    trainable  # Keras 'trainable' maps to NNX 'mutable'
-                )
-            else:
-                actual_nnx_mutable = mutable
-
-            # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable'
-            # param takes precedence.
-            if "mutable" in nnx_metadata and mutable is not None:
-                nnx_metadata["mutable"] = actual_nnx_mutable
-            elif "mutable" not in nnx_metadata:
-                nnx_metadata["mutable"] = actual_nnx_mutable
-
+            nnx_metadata["mutable"] = trainable if mutable is None else mutable
             # Initialize nnx.Variable first.
             # Determine the dtype for the placeholder.
-            _placeholder_value = None
-            if shape is not None:
-                if dtype is not None:
-                    _placeholder_value = jnp.zeros(
-                        shape, dtype=standardize_dtype(dtype)
-                    )
-                else:
-                    _placeholder_value = jnp.zeros(
-                        shape, dtype=standardize_dtype(config.floatx())
-                    )
-            else:
-                if dtype is not None:
-                    _placeholder_value = jnp.array(
-                        0.0, dtype=standardize_dtype(dtype)
-                    )
-                else:
-                    _placeholder_value = jnp.array(
-                        0.0, dtype=standardize_dtype(config.floatx())
-                    )
+            _placeholder_value = jnp.zeros(
+                shape or (), dtype=standardize_dtype(dtype)
+            )

             # Call nnx.Variable.__init__ directly.
             nnx.Variable.__init__(
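A quick sanity check (not part of the patch) on why the four placeholder branches above collapse into a single call: the sketch below assumes `standardize_dtype(None)` falls back to `config.floatx()`, as the Keras helper in `keras.src.backend.common.variables` does, and relies on `shape or ()` mapping a missing shape to a scalar shape.

import jax.numpy as jnp
from keras.src.backend.common.variables import standardize_dtype

def make_placeholder(shape=None, dtype=None):
    # standardize_dtype(None) resolves to config.floatx(), covering the two
    # dtype=None branches; `shape or ()` turns shape=None into a scalar
    # shape, covering the two jnp.array(0.0, ...) branches.
    return jnp.zeros(shape or (), dtype=standardize_dtype(dtype))

assert make_placeholder().shape == ()                # scalar, floatx dtype
assert make_placeholder((2, 3), "int32").dtype == "int32"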
@@ -152,10 +121,10 @@ def __getstate__(self):
             # Get the state from KerasVariable (attributes in __dict__)
             # KerasVariable does not have a custom __getstate__, so we mimic
             # default behavior.
-            keras_state = self.__dict__.copy()
+            keras_state = KerasVariable.__getstate__(self)

             # Get the state from nnx.Variable
-            nnx_specific_state = super(KerasVariable, self).__getstate__()
+            nnx_specific_state = nnx.Variable.__getstate__(self)

             # Merge them. Keras state is primary. NNX specific state adds
             # to it.
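Outside the diff, a minimal sketch (toy classes, not the Keras/NNX ones) of the pickling pattern the new `__getstate__` uses: calling each base class explicitly instead of hopping through `super(KerasVariable, self)`, which silently depends on MRO order.

class BaseA:
    def __getstate__(self):
        return {"a_state": dict(self.__dict__)}

class BaseB:
    def __getstate__(self):
        return {"b_state": "raw"}

class Child(BaseA, BaseB):
    def __getstate__(self):
        # Explicit per-base calls make it obvious which state comes from
        # where; super(BaseA, self).__getstate__() only reaches BaseB if it
        # happens to be next in the MRO, which breaks if bases are reordered.
        state = BaseA.__getstate__(self)
        state.update(BaseB.__getstate__(self))
        return state

c = Child()
c.x = 1
print(c.__getstate__())  # {'a_state': {'x': 1}, 'b_state': 'raw'}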
@@ -170,10 +139,6 @@ def __getstate__(self):
                     "_var_metadata"
                 ]

-            # Remove elements that might be problematic or redundant if
-            # nnx.Variable's __getstate__
-            keras_state.pop("raw_value", None)
-
             return keras_state

         def __setstate__(self, state):
@@ -202,38 +167,20 @@ def __setstate__(self, state):

             # Ensure Keras's self._value is also consistent with the
             # restored raw_value
-            object.__setattr__(self, "_value", nnx_raw_value)
+            self._value = nnx_raw_value

             if hasattr(self, "_shape") and self._shape is not None:
                 self._ndim = len(self._shape)
             else:
                 # Fallback if shape isn't immediately available.
                 self._ndim = len(self.raw_value.shape)

-        def _initialize(self, value):
-            # Note that variable.shape is needed by distribution_lib
-            self._shape = self._validate_shape(value.shape)
-            # We can't import the keras/distribution/distribution_lib
-            # due to circular dependency.
-            distribution = global_state.get_global_attribute("distribution")
-            if self._layout is None and distribution is not None:
-                tensor_layout = distribution.get_variable_layout(self)
-                from keras.src.distribution import TensorLayout
-
-                if isinstance(tensor_layout, TensorLayout):
-                    self._layout = tensor_layout.backend_layout
-                else:
-                    self._layout = tensor_layout
-            self._direct_assign(value)
-
         def _direct_assign(self, value):
             # Apply JAX-specific distribution if layout is present
             if self._layout is not None:
-                processed_value = distribution_lib.distribute_variable(
+                value = distribution_lib.distribute_variable(
                     value, self._layout
                 )
-            else:
-                processed_value = value

             # Ensure that nnx.Variable part is initialized
             if not hasattr(self, "_var_metadata"):
@@ -245,48 +192,31 @@ def _direct_assign(self, value):
                 hasattr(self, "_var_metadata")
                 and "on_set_value" in self._var_metadata
             ):
-                final_value = self._var_metadata["on_set_value"](
-                    self, processed_value
-                )
-            else:
-                final_value = processed_value
-
-            # Directly set raw_value. nnx.Variable handles mutable array
-            # updates
-            object.__setattr__(self, "raw_value", final_value)
-
-        def _convert_to_tensor(self, value, dtype=None):
-            return convert_to_tensor(value, dtype=dtype, sparse=False)
-
-        # Overload native accessor.
-        def __jax_array__(self):
-            return self.value
+                value = self._var_metadata["on_set_value"](self, value)

         @property
         def value(self):
             if not hasattr(self, "raw_value"):
-                if not hasattr(self, "_value") or self._value is None:
-                    if self._initializer is not None:
-                        initial_value = self._initializer(
-                            self._shape, dtype=self._dtype
-                        )
-                        return self._maybe_autocast(initial_value)
-                    else:
-                        raise AttributeError(
-                            "Variable is not properly initialized and has"
-                            " no initializer."
-                        )
-                current_value = self._value
-            else:
-                current_value = self.raw_value
-                if (
-                    hasattr(self, "_var_metadata")
-                    and "on_get_value" in self._var_metadata
-                ):
-                    current_value = self._var_metadata["on_get_value"](
-                        self, current_value
+                if self._initializer is not None:
+                    self._initialize(
+                        self._initializer(self.shape, dtype=self.dtype)
+                    )
+                else:
+                    # This implies nnx.Variable didn't set placeholder or init failed.
+                    raise AttributeError(
+                        "Variable is not properly initialized (raw_value missing) "
+                        "and has no initializer."
                     )
+            # Now, self.raw_value must exist. It's the source of truth.
+            current_value = self.raw_value

+            if (
+                hasattr(self, "_var_metadata")
+                and "on_get_value" in self._var_metadata
+            ):
+                current_value = self._var_metadata["on_get_value"](
+                    self, current_value
+                )
             if in_stateless_scope():
                 scope = get_stateless_scope()
                 stateless_value = scope.get_current_value(self)
@@ -298,7 +228,7 @@ def value(self):
         def __hash__(self):
             return id(self)

-    _JAX_VARIABLE_TYPE = NnxVariable
+    Variable = NnxVariable


 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
@@ -314,7 +244,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
         # an existing distributed jax array will raise error.
         return x

-    if isinstance(x, _JAX_VARIABLE_TYPE):
+    if isinstance(x, Variable):
         if dtype is not None and x.dtype != dtype:
             return x.value.astype(dtype)
         return x.value
@@ -598,7 +528,7 @@ def fori_loop(lower, upper, body_fun, init_val):


 def stop_gradient(variable):
-    if isinstance(variable, _JAX_VARIABLE_TYPE):
+    if isinstance(variable, Variable):
         variable = variable.value
     return jax.lax.stop_gradient(variable)
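The renaming hunks above drop the private `_JAX_VARIABLE_TYPE` sentinel and instead rebind the module-level `Variable` name. A toy sketch of that conditional-alias pattern (illustrative names only, not the actual Keras backend code):

class JaxVariable:
    """Plain JAX-backed variable."""

class NnxVariable(JaxVariable):
    """NNX-aware subclass, only used when the NNX backend is enabled."""

nnx_enabled = False  # stand-in for config.is_nnx_enabled()

# One exported name; isinstance checks in convert_to_tensor, stop_gradient,
# etc. work against whichever class was bound at import time.
Variable = NnxVariable if nnx_enabled else JaxVariable

assert isinstance(JaxVariable(), Variable)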