
Commit d7d00f6

larryliu0820 authored and facebook-github-bot committed
Update metadata max_seq_len based on the max range of dynamic shapes (#11611)
Summary: Testing shows that we cannot export with the token dimension's max set to `max_seq_len` in the dynamic shape if we export with only `tokens`. However, if we export with both `tokens` and `input_pos`, the token dimension's max can be set to `max_seq_len`. This diff fixes two things:

* Change the dynamic shape based on which inputs are present.
* Change the pte's metadata `get_max_seq_len` and `get_max_context_len` based on the token dimension's max value in the dynamic shape.

Reviewed By: kimishpatel

Differential Revision: D76530379

Pulled By: larryliu0820
1 parent: a6d8440
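To make the limitation concrete, here is a minimal, self-contained sketch of the export pattern the summary describes. The toy module, shapes, and names are illustrative assumptions, not code from this commit: with both `tokens` and `input_pos` as inputs, the token dimension's `Dim` can use `max=max_seq_len`; with `tokens` alone, the linked gist notes export only succeeds with `max=max_seq_len - 1`.

```python
import torch


class ToyTokenModel(torch.nn.Module):
    """Hypothetical stand-in for the exported LLM; not from this commit."""

    def __init__(self, vocab_size: int = 32, dim: int = 8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)

    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        # input_pos is consumed only so export sees two inputs.
        return self.embed(tokens) + input_pos.to(torch.float)


max_seq_len = 128
example_tokens = torch.zeros(1, 3, dtype=torch.long)  # (batch, seq)
example_input_pos = torch.zeros(1, dtype=torch.long)  # static shape (1,)

# With both inputs present, the token dimension can range up to max_seq_len.
token_dim = torch.export.Dim("token_dim", max=max_seq_len)
ep = torch.export.export(
    ToyTokenModel(),
    (example_tokens, example_input_pos),
    dynamic_shapes=({1: token_dim}, {0: 1}),  # input_pos pinned to size 1
)
print(ep.graph_signature)
```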

File tree

1 file changed: +21, -3 lines

extension/llm/export/builder.py

Lines changed: 21 additions & 3 deletions
@@ -133,6 +133,19 @@ def __init__(
         self.output_dir = "."
         self._saved_pte_filename = None

+    def __post_init__(self):
+        """
+        Post init function to update metadata based on dynamic shape
+        """
+        dynamic_shape = self._get_dynamic_shape()
+        if dynamic_shape is not None:
+            token_dim = dynamic_shape[0][1]
+            if self.verbose:
+                logging.info(
+                    f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}"
+                )
+            self.metadata["get_max_seq_len"] = token_dim.max
+
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
         Set the directory where the .pte file will be saved.
@@ -180,14 +193,19 @@ def _get_dynamic_shape(self) -> Any:
         if self.dynamic_shapes:
             return self.dynamic_shapes

-        dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
         if self.enable_dynamic_shape:
             if not self.use_kv_cache:
                 # Only one input argument: tokens
-                self.dynamic_shapes = ({1: dim},)
+                # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
+                self.dynamic_shapes = (
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
+                )
             else:
                 # Two input arguments: tokens and input_pos but input_pos is static shape
-                self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
+                self.dynamic_shapes = (
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
+                    {"input_pos": {0: 1}},
+                )
         else:
             # Two input arguments: tokens and input_pos but both are of static shape
             self.dynamic_shapes = None
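Under these assumptions, the new `__post_init__` hook simply reads the token `Dim`'s max back out of the dynamic-shape spec and mirrors it into the metadata. A minimal illustration of that indexing, with hypothetical values standing in for a real `LLMEdgeManager`:

```python
import torch

# dynamic_shape[0] is the spec for `tokens`; key 1 is the token dimension.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=128)},
    {"input_pos": {0: 1}},
)
token_dim = dynamic_shape[0][1]
metadata = {"get_max_seq_len": token_dim.max}
print(metadata)  # {'get_max_seq_len': 128}
```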
