Skip to content

Commit aa6b79d

Browse files
authored
Fix check of unecessary packages (issue #37626) (#37825)
* Fix check of unecessary packages (issue #37626) * Reformat using ruff * And a condition to avoind the risk of matching a random object in `import_utils` * Reformat
1 parent 517367f commit aa6b79d

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

src/transformers/dynamic_module_utils.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,17 +151,24 @@ def get_imports(filename: Union[str, os.PathLike]) -> list[str]:
151151
content = f.read()
152152
imported_modules = set()
153153

154+
import transformers.utils
155+
154156
def recursive_look_for_imports(node):
155157
if isinstance(node, ast.Try):
156-
return # Don't recurse into Try blocks and ignore imports in them
158+
return # Don't recurse into Try blocks and ignore imports in them
157159
elif isinstance(node, ast.If):
158160
test = node.test
159161
for condition_node in ast.walk(test):
160-
if isinstance(condition_node, ast.Call) and getattr(condition_node.func, "id", "").startswith(
161-
"is_flash_attn"
162-
):
163-
# Don't recurse into "if flash_attn_available()" blocks and ignore imports in them
164-
return
162+
if isinstance(condition_node, ast.Call):
163+
check_function = getattr(condition_node.func, "id", "")
164+
if (
165+
check_function.endswith("available")
166+
and check_function.startswith("is_flash_attn")
167+
or hasattr(transformers.utils.import_utils, check_function)
168+
):
169+
# Don't recurse into "if flash_attn_available()" or any "if library_available" blocks
170+
# that appears in `transformers.utils.import_utils` and ignore imports in them
171+
return
165172
elif isinstance(node, ast.Import):
166173
# Handle 'import x' statements
167174
for alias in node.names:

0 commit comments

Comments
 (0)