Skip to content

Commit f910a54

Browse files
mypyc: add ascii format fast path (#17)
* mypyc: add ascii format fast path
* mypyc: extend ascii constant-fold tests
1 parent b5dfda6 commit f910a54

File tree

4 files changed

+79
-9
lines changed

4 files changed

+79
-9
lines changed

mypyc/irbuild/format_str_tokenizer.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from mypy.errors import Errors
1414
from mypy.messages import MessageBuilder
15-
from mypy.nodes import Context, Expression, StrExpr
15+
from mypy.nodes import Context, Expression
1616
from mypy.options import Options
1717
from mypyc.ir.ops import Integer, Value
1818
from mypyc.ir.rtypes import (
@@ -23,9 +23,10 @@
2323
is_str_rprimitive,
2424
)
2525
from mypyc.irbuild.builder import IRBuilder
26+
from mypyc.irbuild.constant_fold import constant_fold_expr
2627
from mypyc.primitives.bytes_ops import bytes_build_op
2728
from mypyc.primitives.int_ops import int_to_str_op
28-
from mypyc.primitives.str_ops import str_build_op, str_op
29+
from mypyc.primitives.str_ops import ascii_op, str_build_op, str_op
2930

3031

3132
@unique
@@ -41,6 +42,7 @@ class FormatOp(Enum):
4142

4243
STR = "s"
4344
INT = "d"
45+
ASCII = "a"
4446
BYTES = "b"
4547

4648

@@ -52,14 +54,25 @@ def generate_format_ops(specifiers: list[ConversionSpecifier]) -> list[FormatOp]
5254
format_ops = []
5355
for spec in specifiers:
5456
# TODO: Match specifiers instead of using whole_seq
55-
if spec.whole_seq == "%s" or spec.whole_seq == "{:{}}":
57+
# Conversion flags for str.format/f-strings (e.g. {!a}); only if no format spec.
58+
if spec.conversion and not spec.format_spec:
59+
if spec.conversion == "!a":
60+
format_op = FormatOp.ASCII
61+
else:
62+
return None
63+
# printf-style tokens and special f-string lowering patterns.
64+
elif spec.whole_seq == "%s" or spec.whole_seq == "{:{}}":
5665
format_op = FormatOp.STR
5766
elif spec.whole_seq == "%d":
5867
format_op = FormatOp.INT
68+
elif spec.whole_seq == "%a":
69+
format_op = FormatOp.ASCII
5970
elif spec.whole_seq == "%b":
6071
format_op = FormatOp.BYTES
72+
# Any other non-empty spec means we can't optimize; fall back to runtime formatting.
6173
elif spec.whole_seq:
6274
return None
75+
# Empty spec ("{}") defaults to str().
6376
else:
6477
format_op = FormatOp.STR
6578
format_ops.append(format_op)
@@ -143,16 +156,23 @@ def convert_format_expr_to_str(
143156
for x, format_op in zip(exprs, format_ops):
144157
node_type = builder.node_type(x)
145158
if format_op == FormatOp.STR:
146-
if is_str_rprimitive(node_type) or isinstance(
147-
x, StrExpr
148-
): # NOTE: why does mypyc think our fake StrExprs are not str rprimitives?
159+
if isinstance(folded := constant_fold_expr(builder, x), str):
160+
var_str = builder.load_literal_value(folded)
161+
elif is_str_rprimitive(node_type):
149162
var_str = builder.accept(x)
150163
elif is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type):
151164
var_str = builder.primitive_op(int_to_str_op, [builder.accept(x)], line)
152165
else:
153166
var_str = builder.primitive_op(str_op, [builder.accept(x)], line)
167+
elif format_op == FormatOp.ASCII:
168+
if (folded := constant_fold_expr(builder, x)) is not None:
169+
var_str = builder.load_literal_value(ascii(folded))
170+
else:
171+
var_str = builder.primitive_op(ascii_op, [builder.accept(x)], line)
154172
elif format_op == FormatOp.INT:
155-
if is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type):
173+
if isinstance(folded := constant_fold_expr(builder, x), int):
174+
var_str = builder.load_literal_value(str(folded))
175+
elif is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type):
156176
var_str = builder.primitive_op(int_to_str_op, [builder.accept(x)], line)
157177
else:
158178
return None

mypyc/primitives/str_ops.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@
5151
error_kind=ERR_MAGIC,
5252
)
5353

54+
# ascii(obj)
55+
ascii_op = function_op(
56+
name="builtins.ascii",
57+
arg_types=[object_rprimitive],
58+
return_type=str_rprimitive,
59+
c_function_name="PyObject_ASCII",
60+
error_kind=ERR_MAGIC,
61+
)
62+
5463
# translate isinstance(obj, str)
5564
isinstance_str = function_op(
5665
name="builtins.isinstance",
@@ -180,7 +189,7 @@
180189
name="rfind",
181190
arg_types=str_find_types[0 : i + 2],
182191
return_type=int_rprimitive,
183-
c_function_name=str_find_functions[i],
192+
c_function_name=str_rfind_functions[i],
184193
extra_int_constants=str_rfind_constants[i] + [(-1, c_int_rprimitive)],
185194
error_kind=ERR_MAGIC,
186195
)

mypyc/test-data/irbuild-constant-fold.test

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,42 @@ L0:
186186
big5 = r4
187187
return 1
188188

189+
[case testConstantFoldFormatArgs]
190+
# This only tests that the callee and args are constant folded,
191+
# it is not intended to test the result.
192+
from typing import Any, Final
193+
194+
FMT: Final = "{} {}"
195+
FMT_A: Final = "{!a}"
196+
197+
def f() -> str:
198+
return FMT.format(400 + 20, "roll" + "up")
199+
def g() -> str:
200+
return FMT_A.format("\u00e9")
201+
def g2() -> str:
202+
return FMT_A.format("\u2603")
203+
[out]
204+
def f():
205+
r0, r1, r2, r3 :: str
206+
L0:
207+
r0 = CPyTagged_Str(840)
208+
r1 = 'rollup'
209+
r2 = ' '
210+
r3 = CPyStr_Build(3, r0, r2, r1)
211+
return r3
212+
def g():
213+
r0, r1 :: str
214+
L0:
215+
r0 = "'\\xe9'"
216+
r1 = CPyStr_Build(1, r0)
217+
return r1
218+
def g2():
219+
r0, r1 :: str
220+
L0:
221+
r0 = "'\\u2603'"
222+
r1 = CPyStr_Build(1, r0)
223+
return r1
224+
189225
[case testIntConstantFoldingFinal]
190226
from typing import Final
191227
X: Final = 5

mypyc/test-data/irbuild-str.test

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,14 +291,15 @@ def f(var: Union[str, NewStr], num: int) -> None:
291291
s2 = "I am %d years old." % num
292292
s3 = "Hi! I'm %s. I am %d years old." % (var, num)
293293
s4 = "Float: %f" % num
294+
s5 = "Ascii: %a" % var
294295
[typing fixtures/typing-full.pyi]
295296
[out]
296297
def f(var, num):
297298
var :: str
298299
num :: int
299300
r0, r1, r2, s1, r3, r4, r5, r6, s2, r7, r8, r9, r10, r11, s3, r12 :: str
300301
r13, r14 :: object
301-
r15, s4 :: str
302+
r15, s4, r16, r17, r18, s5 :: str
302303
L0:
303304
r0 = "Hi! I'm "
304305
r1 = '.'
@@ -320,6 +321,10 @@ L0:
320321
r14 = PyNumber_Remainder(r12, r13)
321322
r15 = cast(str, r14)
322323
s4 = r15
324+
r16 = PyObject_ASCII(var)
325+
r17 = 'Ascii: '
326+
r18 = CPyStr_Build(2, r17, r16)
327+
s5 = r18
323328
return 1
324329

325330
[case testDecode]

0 commit comments

Comments
 (0)