
Commit d2a5153

larryliu0820 authored and facebook-github-bot committed
See what happens if we export with max_seq_len (#11611)
Summary: From testing, it appears we cannot set the token dimension's dynamic-shape max to `max_seq_len` if we export with only `tokens`. However, if we export with both `tokens` and `input_pos`, the token dimension's max can be set to `max_seq_len`. This diff fixes two things:

* Choose the dynamic shape based on which inputs are exported.
* Update the .pte metadata `get_max_seq_len` and `get_max_context_len` to match the token dimension's max value in the dynamic shape.

Differential Revision: D76530379

Pulled By: larryliu0820
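For illustration, here is a minimal sketch (not code from this diff) of the two dynamic-shape layouts the summary describes; the `max_seq_len` value is an assumed example:

import torch

max_seq_len = 128  # assumed example value

# tokens only: the token dimension's max must stay below max_seq_len
tokens_only_shapes = (
    {1: torch.export.Dim("token_dim", max=max_seq_len - 1)},
)

# tokens + input_pos: the token dimension's max can be max_seq_len itself,
# while input_pos remains a static, length-1 input
tokens_and_input_pos_shapes = (
    {1: torch.export.Dim("token_dim", max=max_seq_len)},
    {"input_pos": {0: 1}},
)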
1 parent 045b3a5 commit d2a5153

File tree

extension/llm/export/builder.py
extension/llm/export/test/test_builder.py

2 files changed: +24 −5 lines changed

extension/llm/export/builder.py

Lines changed: 22 additions & 3 deletions
@@ -133,6 +133,20 @@ def __init__(
         self.output_dir = "."
         self._saved_pte_filename = None
 
+    def __post_init__(self):
+        """
+        Post init function to update metadata based on dynamic shape
+        """
+        dynamic_shape = self._get_dynamic_shape()
+        if dynamic_shape is not None:
+            token_dim = dynamic_shape[0][1]
+            if self.verbose:
+                logging.info(
+                    f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}"
+                )
+            self.metadata["get_max_seq_len"] = token_dim.max
+            self.metadata["get_max_context_len"] = token_dim.max
+
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
         Set the directory where the .pte file will be saved.
@@ -180,14 +194,19 @@ def _get_dynamic_shape(self) -> Any:
         if self.dynamic_shapes:
             return self.dynamic_shapes
 
-        dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
         if self.enable_dynamic_shape:
             if not self.use_kv_cache:
                 # Only one input argument: tokens
-                self.dynamic_shapes = ({1: dim},)
+                # For some reason if with tokens, we can't go all the way to max_seq_len, otherwise export will fail.
+                self.dynamic_shapes = (
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
+                )
             else:
                 # Two input arguments: tokens and input_pos but input_pos is static shape
-                self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
+                self.dynamic_shapes = (
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
+                    {"input_pos": {0: 1}},
+                )
         else:
             # Two input arguments: tokens and input_pos but both are of static shape
             self.dynamic_shapes = None
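A minimal sketch of what the new __post_init__ hook derives, assuming the dynamic-shape layout returned by _get_dynamic_shape; the standalone `metadata` dict below is hypothetical and only mirrors the two keys the method writes:

import torch

# The first entry's key 1 is the dynamic token dimension.
dynamic_shape = ({1: torch.export.Dim("token_dim", max=128)},)
token_dim = dynamic_shape[0][1]

metadata = {}
# Both metadata entries take the token dimension's max from the dynamic shape.
metadata["get_max_seq_len"] = token_dim.max
metadata["get_max_context_len"] = token_dim.max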

extension/llm/export/test/test_builder.py

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_no_kv_cache(self) -> None:
         self.assertIsInstance(result[0], dict)
         self.assertIn(1, result[0])
         # Check that the value at key 1 is a torch.export.Dim with the correct max value
-        self.assertEqual(result[0][1].max, self.max_seq_len - 1)
+        self.assertEqual(result[0][1].max, self.max_seq_len)
 
     def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> None:
         """Test _get_dynamic_shape when enable_dynamic_shape=True and use_kv_cache=True."""
@@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> None:
         # Check first element (tokens dimension)
         self.assertIsInstance(result[0], dict)
         self.assertIn(1, result[0])
-        self.assertEqual(result[0][1].max, self.max_seq_len - 1)
+        self.assertEqual(result[0][1].max, self.max_seq_len)
 
         # Check second element (input_pos dimension)
         self.assertIsInstance(result[1], dict)
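For context, a hypothetical end-to-end usage of the tokens-only dynamic shape with torch.export.export; the Toy module and example shapes below are illustrative assumptions, not code from this repo:

import torch

class Toy(torch.nn.Module):
    def forward(self, tokens):
        # Placeholder computation standing in for a real LLM forward pass.
        return tokens * 2

max_seq_len = 128  # assumed example value
example_tokens = torch.zeros(1, 8, dtype=torch.long)  # batch x seq

# Token dimension capped below max_seq_len, matching the tokens-only branch above.
dynamic_shapes = ({1: torch.export.Dim("token_dim", max=max_seq_len - 1)},)

exported = torch.export.export(Toy(), (example_tokens,), dynamic_shapes=dynamic_shapes)
print(exported)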
