Execution errors in "Porting a PyTorch model to JAX" #218

@pavithraes

Executing the "Porting a PyTorch model to JAX" tutorial raises the following errors:

  1. Under the MaxVit implementation, calling MaxVit raises AttributeError: 'NoneType' object has no attribute 'value'. I think the fix is: if module.bias.value is not None: --> if module.bias is not None: ?
Traceback:
AttributeError                            Traceback (most recent call last)
[/tmp/ipython-input-49-3830862216.py](https://localhost:8080/#) in <cell line: 0>()
      1 x = jnp.ones((4, 224, 224, 3))
      2 
----> 3 mod = MaxVit(
      4     input_size=(224, 224),
      5     stem_channels=64,

4 frames
[/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py](https://localhost:8080/#) in __call__(cls, *args, **kwargs)
    139 
    140     def __call__(cls, *args: Any, **kwargs: Any) -> Any:
--> 141       return _graph_node_meta_call(cls, *args, **kwargs)
    142 
    143   def _object_meta_construct(cls, self, *args, **kwargs):

[/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py](https://localhost:8080/#) in _graph_node_meta_call(cls, *args, **kwargs)
    148   node = cls.__new__(cls, *args, **kwargs)
    149   vars(node)['_object__state'] = ObjectState()
--> 150   cls._object_meta_construct(node, *args, **kwargs)
    151 
    152   return node

[/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py](https://localhost:8080/#) in _object_meta_construct(cls, self, *args, **kwargs)
    142 
    143   def _object_meta_construct(cls, self, *args, **kwargs):
--> 144     self.__init__(*args, **kwargs)
    145 
    146 

[/tmp/ipython-input-48-3474323775.py](https://localhost:8080/#) in __init__(self, input_size, stem_channels, partition_size, block_channels, block_layers, head_dim, stochastic_depth_prob, norm_layer, activation_layer, squeeze_ratio, expansion_ratio, mlp_ratio, mlp_dropout, attention_dropout, num_classes, rngs)
    133         )
    134 
--> 135         self._init_weights(rngs)
    136 
    137     def __call__(self, x: jax.Array) -> jax.Array:

[/tmp/ipython-input-48-3474323775.py](https://localhost:8080/#) in _init_weights(self, rngs)
    148                     rngs(), module.kernel.value.shape, module.kernel.value.dtype
    149                 )
--> 150                 if module.bias.value is not None:
    151                     module.bias.value = jnp.zeros(
    152                         module.bias.value.shape, dtype=module.bias.value.dtype

AttributeError: 'NoneType' object has no attribute 'value'
  2. AssertionErrors in several tests. For example, in Conv2dNormActivation.
Traceback:
  AssertionError                            Traceback (most recent call last)
  [/tmp/ipython-input-68-247890887.py](https://localhost:8080/#) in <cell line: 0>()
        5 nnx_module = Conv2dNormActivation(32, 64, 3, 2, 1)
        6 
  ----> 7 t2f.copy_module(torch_module, nnx_module)
        8 
        9 test_modules(nnx_module, torch_module, torch.randn(4, 32, 46, 46))
  
  3 frames
  [/tmp/ipython-input-67-547272536.py](https://localhost:8080/#) in _copy_params_buffers(self, torch_nn_module, nnx_module)
       83             torch_value = getattr(torch_nn_module, torch_key)
       84             nnx_param = getattr(nnx_module, nnx_key)
  ---> 85             assert nnx_param is not None, (torch_key, nnx_key, nnx_module)
       86 
       87             if torch_value is None:
  
  AssertionError: ('bias', 'bias', Conv( # Param: 18,432 (73.7 KB)
    kernel_shape=(3, 3, 32, 64),
    kernel=Param( # 18,432 (73.7 KB)
      value=Array(shape=(3, 3, 32, 64), dtype=dtype('float32'))
    ),
    bias=None,
    in_features=32,
    out_features=64,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding=((1, 1), (1, 1)),
    input_dilation=1,
    kernel_dilation=(1, 1),
    feature_group_count=1,
    use_bias=False,
    mask=None,
    dtype=None,
    param_dtype=float32,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x7ffaa4b5d3a0>,
    bias_init=<function zeros at 0x7ffaa5b03100>,
    conv_general_dilated=<function conv_general_dilated at 0x7ffaa70cc040>,
    promote_dtype=<function promote_dtype at 0x7ffaa4b5c2c0>
  ))

Very similar errors are raised from _copy_params_buffers in the tests for MBConv, MaxVitLayer, MaxVitBlock, and MaxVit.

  3. The final prediction scores don't match:
Prediction for the Dog:
- PyTorch model result: ['n02113023', 'Pembroke'], score: 0.7800846099853516
- Flax model result: ['n02113023', 'Pembroke'], score: 0.0008441798854619265
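
For the first two errors, here is a minimal sketch of the kind of guard I have in mind. It does not use the tutorial's actual classes: reset_bias, copy_param, and the SimpleNamespace objects are hypothetical stand-ins for a layer whose bias attribute is either a parameter holding a .value or None (as in the Conv repr above, where use_bias=False gives bias=None). I am also assuming that in the failing tests the PyTorch side has no bias either, which I have not verified for every module.

```python
from types import SimpleNamespace


def reset_bias(module):
    """Zero the bias only when the layer actually has one."""
    # Test the attribute itself: when the layer has no bias it is None,
    # so reading `.value` on it raises AttributeError.
    if module.bias is not None:
        module.bias.value = [0.0] * len(module.bias.value)
        return True
    return False


def copy_param(torch_value, nnx_param):
    """Copy one parameter, skipping the case where neither side has it."""
    if torch_value is None and nnx_param is None:
        return False  # nothing to copy, e.g. a conv without bias on both sides
    # A parameter present on only one side is still a real mismatch.
    assert torch_value is not None and nnx_param is not None, "bias mismatch"
    nnx_param.value = list(torch_value)
    return True


# Hypothetical stand-ins for a layer with and without a bias.
with_bias = SimpleNamespace(bias=SimpleNamespace(value=[1.0, 2.0]))
no_bias = SimpleNamespace(bias=None)

reset_bias(with_bias)  # zeros the existing bias
reset_bias(no_bias)    # no-op instead of AttributeError
```

The same ordering idea would apply in _copy_params_buffers: check whether both values are None before asserting that nnx_param is not None.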
