Running the example from the website:
https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html
fails at model creation with the following traceback:
ValueError Traceback (most recent call last)
/tmp/ipython-input-3025996350.py in <cell line: 0>()
----> 1 model = create_model(rngs=nnx.Rngs(0))
2 optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
3 metrics = nnx.MultiMetric(
4 loss=nnx.metrics.Average("loss"),
5 )
18 frames
/tmp/ipython-input-3890966448.py in create_model(rngs)
183 # Creates the miniGPT model with 4 transformer blocks.
184 def create_model(rngs):
--> 185 return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks=4, rngs=rngs)
...
/usr/local/lib/python3.12/dist-packages/flax/core/spmd.py in shard_value(value, sharding_names, sharding_rules, mesh)
38 return value
39 if not mesh and not meta.global_mesh_defined():
---> 40 raise ValueError(
41 'An auto mesh context or metadata is required if creating a variable'
42 f' with annotation {sharding_names=}. '
The full error message is:
ValueError: An auto mesh context or metadata is required if creating a variable with annotation sharding_names=NamedSharding(mesh=Mesh('batch': 1, 'model': 1, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device). If running this on CPU for debugging, make a dummy mesh like `jax.make_mesh(((1, 1)), (<your axis names>))`. If running on explicit mode, remove `sharding_names=` annotation.