Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mypyc/doc/list_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Operators
* ``lst[n]`` (get item by integer index)
* ``lst[n:m]``, ``lst[n:]``, ``lst[:m]``, ``lst[:]`` (slicing)
* ``lst1 + lst2``, ``lst += iter``
* ``lst * n``, ``n * lst``
* ``lst * n``, ``n * lst``, ``lst *= n``
* ``obj in lst``

Statements
Expand Down
1 change: 1 addition & 0 deletions mypyc/doc/tuple_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Operators
* ``tup[n]`` (integer index)
* ``tup[n:m]``, ``tup[n:]``, ``tup[:m]`` (slicing)
* ``tup1 + tup2``
* ``tup * n``, ``n * tup``

Statements
----------
Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ int CPyList_Remove(PyObject *list, PyObject *obj);
CPyTagged CPyList_Index(PyObject *list, PyObject *obj);
PyObject *CPySequence_Multiply(PyObject *seq, CPyTagged t_size);
PyObject *CPySequence_RMultiply(CPyTagged t_size, PyObject *seq);
PyObject *CPySequence_InPlaceMultiply(PyObject *seq, CPyTagged t_size);
PyObject *CPyList_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
PyObject *CPyList_Copy(PyObject *list);
int CPySequence_Check(PyObject *obj);
Expand Down
8 changes: 8 additions & 0 deletions mypyc/lib-rt/list_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,14 @@ PyObject *CPySequence_RMultiply(CPyTagged t_size, PyObject *seq) {
return CPySequence_Multiply(seq, t_size);
}

PyObject *CPySequence_InPlaceMultiply(PyObject *seq, CPyTagged t_size) {
Py_ssize_t size = CPyTagged_AsSsize_t(t_size);
if (size == -1 && PyErr_Occurred()) {
return NULL;
}
return PySequence_InPlaceRepeat(seq, size);
}

PyObject *CPyList_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
if (likely(PyList_CheckExact(obj)
&& CPyTagged_CheckShort(start) && CPyTagged_CheckShort(end))) {
Expand Down
9 changes: 9 additions & 0 deletions mypyc/primitives/list_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,15 @@
error_kind=ERR_MAGIC,
)

# list *= int
binary_op(
name="*=",
arg_types=[list_rprimitive, int_rprimitive],
return_type=list_rprimitive,
c_function_name="CPySequence_InPlaceMultiply",
error_kind=ERR_MAGIC,
)

# list[begin:end]
list_slice_op = custom_op(
arg_types=[list_rprimitive, int_rprimitive, int_rprimitive],
Expand Down
18 changes: 18 additions & 0 deletions mypyc/primitives/tuple_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,24 @@
error_kind=ERR_MAGIC,
)

# tuple * int
binary_op(
name="*",
arg_types=[tuple_rprimitive, int_rprimitive],
return_type=tuple_rprimitive,
c_function_name="CPySequence_Multiply",
error_kind=ERR_MAGIC,
)

# int * tuple
binary_op(
name="*",
arg_types=[int_rprimitive, tuple_rprimitive],
return_type=tuple_rprimitive,
c_function_name="CPySequence_RMultiply",
error_kind=ERR_MAGIC,
)

# tuple[begin:end]
tuple_slice_op = custom_op(
arg_types=[tuple_rprimitive, int_rprimitive, int_rprimitive],
Expand Down
3 changes: 3 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ def __contains__(self, item: object) -> int: ...
def __add__(self, value: Tuple[T_co, ...], /) -> Tuple[T_co, ...]: ...
@overload
def __add__(self, value: Tuple[_T, ...], /) -> Tuple[T_co | _T, ...]: ...
def __mul__(self, value: int, /) -> Tuple[T_co, ...]: ...
def __rmul__(self, value: int, /) -> Tuple[T_co, ...]: ...

class function: pass

Expand All @@ -225,6 +227,7 @@ def __setitem__(self, i: int, o: _T) -> None: pass
def __delitem__(self, i: int) -> None: pass
def __mul__(self, i: int) -> List[_T]: pass
def __rmul__(self, i: int) -> List[_T]: pass
def __imul__(self, i: int) -> List[_T]: ...
def __iter__(self) -> Iterator[_T]: pass
def __len__(self) -> int: pass
def __contains__(self, item: object) -> int: ...
Expand Down
12 changes: 12 additions & 0 deletions mypyc/test-data/irbuild-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,18 @@ L0:
b = r4
return 1

[case testListIMultiply]
from typing import List
def f(a: List[int]) -> None:
a *= 2
[out]
def f(a):
a, r0 :: list
L0:
r0 = CPySequence_InPlaceMultiply(a, 4)
a = r0
return 1

[case testListLen]
from typing import List
def f(a: List[int]) -> int:
Expand Down
35 changes: 35 additions & 0 deletions mypyc/test-data/irbuild-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,38 @@ L0:
r3 = unbox(tuple[int, int, int, int], r2)
c = r3
return 1

[case testTupleMultiply]
from typing import Tuple
def f(a: Tuple[int]) -> None:
b = a * 2
c = 3 * (2,)
def g(a: Tuple[int, ...]) -> None:
b = a * 2
[out]
def f(a):
a :: tuple[int]
r0 :: object
r1 :: tuple
r2, b :: tuple[int, int]
r3 :: tuple[int]
r4 :: object
r5 :: tuple
r6, c :: tuple[int, int, int]
L0:
r0 = box(tuple[int], a)
r1 = CPySequence_Multiply(r0, 4)
r2 = unbox(tuple[int, int], r1)
b = r2
r3 = (4)
r4 = box(tuple[int], r3)
r5 = CPySequence_RMultiply(6, r4)
r6 = unbox(tuple[int, int, int], r5)
c = r6
return 1
def g(a):
a, r0, b :: tuple
L0:
r0 = CPySequence_Multiply(a, 4)
b = r0
return 1
7 changes: 7 additions & 0 deletions mypyc/test-data/run-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,13 @@ def test_add() -> None:
assert in_place_add({3: "", 4: ""}) == res
assert in_place_add(range(3, 5)) == res

def test_multiply() -> None:
l1 = [1]
assert l1 * 3 == [1, 1, 1]
assert 3 * l1 == [1, 1, 1]
l1 *= 3
assert l1 == [1, 1, 1]

[case testOperatorInExpression]

def tuple_in_int0(i: int) -> bool:
Expand Down
9 changes: 9 additions & 0 deletions mypyc/test-data/run-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,12 @@ def test_add() -> None:
assert (1, 2) + (3, 4) == res
with assertRaises(TypeError, 'can only concatenate tuple (not "list") to tuple'):
assert (1, 2) + cast(Any, [3, 4]) == res

def multiply(a: Tuple[Any, ...], b: int) -> Tuple[Any, ...]:
return a * b

def test_multiply() -> None:
res = (1, 1, 1)
assert (1,) * 3 == res
assert 3 * (1,) == res
assert multiply((1,), 3) == res