From de88c263ae232b895be45f6057be13d694630145 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 26 Jun 2024 20:43:26 +0000 Subject: [PATCH] Add layer id for each layer TransformerBlock scope --- jetstream_pt/third_party/llama/model_exportable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index c081b3cf..15f4fd04 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -232,7 +232,7 @@ def forward( ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match" end = None if start is None else (start + input_pos) % self.env.cache_len for layer, cache in zip(self.layers, caches): - with jax.named_scope("TransformerBlock"): + with jax.named_scope("TransformerBlock_Layer_" + str(layer.layer_id)): h = layer( h, freqs_cis,