Running the JAX_for_LLM_pretraining.ipynb in colab fails #265

@Davidnet

Running the example notebook from the website in Colab:

https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html

fails in the cell that creates the model:

ValueError                                Traceback (most recent call last)
/tmp/ipython-input-3025996350.py in <cell line: 0>()
----> 1 model = create_model(rngs=nnx.Rngs(0))
      2 optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
      3 metrics = nnx.MultiMetric(
      4     loss=nnx.metrics.Average("loss"),
      5 )

18 frames
/tmp/ipython-input-3890966448.py in create_model(rngs)
    183 # Creates the miniGPT model with 4 transformer blocks.
    184 def create_model(rngs):
--> 185     return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks=4, rngs=rngs)
...
/usr/local/lib/python3.12/dist-packages/flax/core/spmd.py in shard_value(value, sharding_names, sharding_rules, mesh)
     38     return value
     39   if not mesh and not meta.global_mesh_defined():
---> 40     raise ValueError(
     41       'An auto mesh context or metadata is required if creating a variable'
     42       f' with annotation {sharding_names=}. '

With the following error:

ValueError: An auto mesh context or metadata is required if creating a variable with annotation sharding_names=NamedSharding(mesh=Mesh('batch': 1, 'model': 1, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device). If running this on CPU for debugging, make a dummy mesh like `jax.make_mesh(((1, 1)), (<your axis names>))`. If running on explicit mode, remove `sharding_names=` annotation.
