Skip to content

Commit 762857c

Browse files
authored
Add layer id in scope for each TransformerBlock layer (#136)
Add layer id for each layer TransformerBlock scope
1 parent ec66a75 commit 762857c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def forward(
232232
), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match"
233233
end = None if start is None else (start + input_pos) % self.env.cache_len
234234
for layer, cache in zip(self.layers, caches):
235-
with jax.named_scope("TransformerBlock"):
235+
with jax.named_scope("TransformerBlock_Layer_" + str(layer.layer_id)):
236236
h = layer(
237237
h,
238238
freqs_cis,

0 commit comments

Comments
 (0)