Skip to content

Commit 193fc2f

Browse files
committed
automatically rewrite to shlex.join in --py38-plus
1 parent 908b055 commit 193fc2f

File tree

3 files changed

+143
-0
lines changed

3 files changed

+143
-0
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,15 @@ Availability:
634634
...
635635
```
636636

637+
### shlex.join
638+
639+
Availability:
640+
- `--py38-plus` is passed on the commandline.
641+
642+
```diff
643+
-' '.join(shlex.quote(arg) for arg in cmd)
644+
+shlex.join(cmd)
645+
```
637646

638647
### replace `@functools.lru_cache(maxsize=None)` with shorthand
639648

pyupgrade/_plugins/shlex_join.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import functools
5+
from typing import Iterable
6+
7+
from tokenize_rt import NON_CODING_TOKENS
8+
from tokenize_rt import Offset
9+
from tokenize_rt import Token
10+
11+
from pyupgrade._ast_helpers import ast_to_offset
12+
from pyupgrade._data import register
13+
from pyupgrade._data import State
14+
from pyupgrade._data import TokenFunc
15+
from pyupgrade._token_helpers import find_open_paren
16+
from pyupgrade._token_helpers import find_token
17+
from pyupgrade._token_helpers import victims
18+
19+
20+
def _fix_shlex_join(i: int, tokens: list[Token], *, arg: ast.expr) -> None:
21+
j = find_open_paren(tokens, i)
22+
comp_victims = victims(tokens, j, arg, gen=True)
23+
k = find_token(tokens, comp_victims.arg_index, 'in') + 1
24+
while tokens[k].name in NON_CODING_TOKENS:
25+
k += 1
26+
tokens[comp_victims.ends[0]:comp_victims.ends[-1] + 1] = [Token('OP', ')')]
27+
tokens[i:k] = [Token('CODE', 'shlex.join'), Token('OP', '(')]
28+
29+
30+
@register(ast.Call)
31+
def visit_Call(
32+
state: State,
33+
node: ast.Call,
34+
parent: ast.AST,
35+
) -> Iterable[tuple[Offset, TokenFunc]]:
36+
if state.settings.min_version < (3, 8):
37+
return
38+
39+
if (
40+
isinstance(node.func, ast.Attribute) and
41+
isinstance(node.func.value, ast.Constant) and
42+
isinstance(node.func.value.value, str) and
43+
node.func.attr == 'join' and
44+
not node.keywords and
45+
len(node.args) == 1 and
46+
isinstance(node.args[0], (ast.ListComp, ast.GeneratorExp)) and
47+
isinstance(node.args[0].elt, ast.Call) and
48+
isinstance(node.args[0].elt.func, ast.Attribute) and
49+
isinstance(node.args[0].elt.func.value, ast.Name) and
50+
node.args[0].elt.func.value.id == 'shlex' and
51+
node.args[0].elt.func.attr == 'quote' and
52+
not node.args[0].elt.keywords and
53+
len(node.args[0].elt.args) == 1 and
54+
isinstance(node.args[0].elt.args[0], ast.Name) and
55+
len(node.args[0].generators) == 1 and
56+
isinstance(node.args[0].generators[0].target, ast.Name) and
57+
not node.args[0].generators[0].ifs and
58+
not node.args[0].generators[0].is_async and
59+
node.args[0].elt.args[0].id == node.args[0].generators[0].target.id
60+
):
61+
func = functools.partial(_fix_shlex_join, arg=node.args[0])
62+
yield ast_to_offset(node), func

tests/features/shlex_join_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from pyupgrade._data import Settings
6+
from pyupgrade._main import _fix_plugins
7+
8+
9+
@pytest.mark.parametrize(
10+
('s', 'version'),
11+
(
12+
pytest.param(
13+
'from shlex import quote\n'
14+
'" ".join(quote(arg) for arg in cmd)\n',
15+
(3, 8),
16+
id='quote from-imported',
17+
),
18+
pytest.param(
19+
'import shlex\n'
20+
'" ".join(shlex.quote(arg) for arg in cmd)\n',
21+
(3, 7),
22+
id='3.8+ feature',
23+
),
24+
),
25+
)
26+
def test_shlex_join_noop(s, version):
27+
assert _fix_plugins(s, settings=Settings(min_version=version)) == s
28+
29+
30+
@pytest.mark.parametrize(
31+
('s', 'expected'),
32+
(
33+
pytest.param(
34+
'import shlex\n'
35+
'" ".join(shlex.quote(arg) for arg in cmd)\n',
36+
37+
'import shlex\n'
38+
'shlex.join(cmd)\n',
39+
40+
id='generator expression',
41+
),
42+
pytest.param(
43+
'import shlex\n'
44+
'" ".join([shlex.quote(arg) for arg in cmd])\n',
45+
46+
'import shlex\n'
47+
'shlex.join(cmd)\n',
48+
49+
id='list comprehension',
50+
),
51+
pytest.param(
52+
'import shlex\n'
53+
'" ".join([shlex.quote(arg) for arg in cmd],)\n',
54+
55+
'import shlex\n'
56+
'shlex.join(cmd)\n',
57+
58+
id='removes trailing comma',
59+
),
60+
pytest.param(
61+
'import shlex\n'
62+
'" ".join([shlex.quote(arg) for arg in ["a", "b", "c"]],)\n',
63+
64+
'import shlex\n'
65+
'shlex.join(["a", "b", "c"])\n',
66+
67+
id='more complicated iterable',
68+
),
69+
),
70+
)
71+
def test_shlex_join_fixes(s, expected):
72+
assert _fix_plugins(s, settings=Settings(min_version=(3, 8))) == expected

0 commit comments

Comments
 (0)