Skip to content

Commit 1d426f8

Browse files
yiwenHUUctios
authored andcommitted
fix(tools): fix torch_fx_tool string format
1 parent fea8d47 commit 1d426f8

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

recis/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44

55

6-
__version__ = "1.0.17"
6+
__version__ = "1.0.18"
77

88
pkg_path = os.path.dirname(os.path.realpath(__file__))
99
lib_path = os.path.join(pkg_path, "lib")

recis/utils/torch_fx_tool/export_torch_fx_tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def add_dynamic_forward(cls):
365365
func_signature = ", ".join(args)
366366
dict_creation_lines = [f" '{arg}': {arg}," for arg in args]
367367
func_body = (
368-
" inputs = {{\n" + "\n".join(dict_creation_lines) + "\n }\n"
368+
" inputs = {\n" + "\n".join(dict_creation_lines) + "\n }\n"
369369
)
370370
func_body += " return self.user_model(inputs)\n"
371371
func_code = f"def forward(self, {func_signature}):\n{func_body}"
@@ -415,7 +415,7 @@ def add_dynamic_forward(cls):
415415
# 构建函数体
416416
dict_creation_lines = [f" '{arg}': {arg}," for arg in valid_args]
417417
func_body = (
418-
" inputs = {{\n" + "\n".join(dict_creation_lines) + "\n }\n"
418+
" inputs = {\n" + "\n".join(dict_creation_lines) + "\n }\n"
419419
)
420420

421421
# 为invalid_placeholder添加空tensor

0 commit comments

Comments
 (0)