diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index c081b3cf..15f4fd04 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -232,7 +232,7 @@ def forward( ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match" end = None if start is None else (start + input_pos) % self.env.cache_len for layer, cache in zip(self.layers, caches): - with jax.named_scope("TransformerBlock"): + with jax.named_scope("TransformerBlock_Layer_" + str(layer.layer_id)): h = layer( h, freqs_cis,