Running the "Porting a PyTorch model to JAX" tutorial raises the following errors:
1. In the `MaxVit` implementation, calling `MaxVit` raises `AttributeError: 'NoneType' object has no attribute 'value'`. I think the fix is `if module.bias.value is not None:` --> `if module.bias is not None:`?
Full traceback for the `MaxVit` error:

```
AttributeError                            Traceback (most recent call last)
/tmp/ipython-input-49-3830862216.py in <cell line: 0>()
      1 x = jnp.ones((4, 224, 224, 3))
      2 
----> 3 mod = MaxVit(
      4     input_size=(224, 224),
      5     stem_channels=64,

/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py in __call__(cls, *args, **kwargs)
    139 
    140   def __call__(cls, *args: Any, **kwargs: Any) -> Any:
--> 141     return _graph_node_meta_call(cls, *args, **kwargs)
    142 
    143   def _object_meta_construct(cls, self, *args, **kwargs):

/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py in _graph_node_meta_call(cls, *args, **kwargs)
    148   node = cls.__new__(cls, *args, **kwargs)
    149   vars(node)['_object__state'] = ObjectState()
--> 150   cls._object_meta_construct(node, *args, **kwargs)
    151 
    152   return node

/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py in _object_meta_construct(cls, self, *args, **kwargs)
    142 
    143   def _object_meta_construct(cls, self, *args, **kwargs):
--> 144     self.__init__(*args, **kwargs)
    145 
    146 

/tmp/ipython-input-48-3474323775.py in __init__(self, input_size, stem_channels, partition_size, block_channels, block_layers, head_dim, stochastic_depth_prob, norm_layer, activation_layer, squeeze_ratio, expansion_ratio, mlp_ratio, mlp_dropout, attention_dropout, num_classes, rngs)
    133         )
    134 
--> 135         self._init_weights(rngs)
    136 
    137     def __call__(self, x: jax.Array) -> jax.Array:

/tmp/ipython-input-48-3474323775.py in _init_weights(self, rngs)
    148                     rngs(), module.kernel.value.shape, module.kernel.value.dtype
    149                 )
--> 150                 if module.bias.value is not None:
    151                     module.bias.value = jnp.zeros(
    152                         module.bias.value.shape, dtype=module.bias.value.dtype

AttributeError: 'NoneType' object has no attribute 'value'
```

2. `AssertionError`s in several tests, for example in `Conv2dNormActivation`. Traceback:
```
AssertionError                            Traceback (most recent call last)
/tmp/ipython-input-68-247890887.py in <cell line: 0>()
      5 nnx_module = Conv2dNormActivation(32, 64, 3, 2, 1)
      6 
----> 7 t2f.copy_module(torch_module, nnx_module)
      8 
      9 test_modules(nnx_module, torch_module, torch.randn(4, 32, 46, 46))

3 frames

/tmp/ipython-input-67-547272536.py in _copy_params_buffers(self, torch_nn_module, nnx_module)
     83             torch_value = getattr(torch_nn_module, torch_key)
     84             nnx_param = getattr(nnx_module, nnx_key)
---> 85             assert nnx_param is not None, (torch_key, nnx_key, nnx_module)
     86 
     87             if torch_value is None:

AssertionError: ('bias', 'bias', Conv( # Param: 18,432 (73.7 KB)
  kernel_shape=(3, 3, 32, 64),
  kernel=Param( # 18,432 (73.7 KB)
    value=Array(shape=(3, 3, 32, 64), dtype=dtype('float32'))
  ),
  bias=None,
  in_features=32,
  out_features=64,
  kernel_size=(3, 3),
  strides=(2, 2),
  padding=((1, 1), (1, 1)),
  input_dilation=1,
  kernel_dilation=(1, 1),
  feature_group_count=1,
  use_bias=False,
  mask=None,
  dtype=None,
  param_dtype=float32,
  precision=None,
  kernel_init=<function variance_scaling.<locals>.init at 0x7ffaa4b5d3a0>,
  bias_init=<function zeros at 0x7ffaa5b03100>,
  conv_general_dilated=<function conv_general_dilated at 0x7ffaa70cc040>,
  promote_dtype=<function promote_dtype at 0x7ffaa4b5c2c0>
))
```

Very similar errors from `_copy_params_buffers` occur in the tests for `MBConv`, `MaxVitLayer`, `MaxVitBlock`, and `MaxVit`.
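The assertion on line 85 fires even though both sides agree that the bias is absent: the NNX `Conv` repr shows `use_bias=False` and `bias=None`, and the matching PyTorch convolution presumably also has `bias=None` (PyTorch's `nn.Conv2d(..., bias=False)` registers `bias` as `None`). Because the helper asserts on `nnx_param` before reaching its own `if torch_value is None:` branch on line 87, this case never gets that far. A minimal sketch of one way to reorder the checks, written as a standalone hypothetical helper (`check_param_pair` is not part of the tutorial's `t2f` code):

```python
from typing import Any, Optional


def check_param_pair(
    torch_key: str, torch_value: Optional[Any], nnx_param: Optional[Any]
) -> bool:
    """Decide whether a PyTorch value should be copied into an NNX parameter.

    Hypothetical helper. Returns False when the parameter is absent on both
    sides (e.g. torch bias=False and NNX use_bias=False, where both attributes
    are None), True when both exist, and raises on a one-sided mismatch.
    """
    if torch_value is None and nnx_param is None:
        # Both frameworks agree the parameter does not exist: nothing to copy.
        return False
    if torch_value is None or nnx_param is None:
        # Only one side defines the parameter: the modules are configured differently.
        raise AssertionError(
            f"{torch_key!r}: torch value is None={torch_value is None}, "
            f"nnx param is None={nnx_param is None}"
        )
    return True
```

Assuming the lines shown in the traceback run inside a loop over parameter names, the `assert` on line 85 could become `if not check_param_pair(torch_key, torch_value, nnx_param): continue`, which skips missing biases while still failing loudly when only one framework defines the parameter.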