Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions py/objarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,39 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) {
STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in);
switch (op) {
case MP_BINARY_OP_MULTIPLY:
case MP_BINARY_OP_INPLACE_MULTIPLY: {
if (!MP_OBJ_IS_INT(rhs_in)) {
return MP_OBJ_NULL; // op not supported
}
mp_uint_t repeat = mp_obj_get_int(rhs_in);
bool inplace = (op == MP_BINARY_OP_INPLACE_MULTIPLY);
mp_buffer_info_t lhs_bufinfo;
array_get_buffer(lhs_in, &lhs_bufinfo, MP_BUFFER_READ);
mp_obj_array_t *res;
byte *ptr;
size_t orig_lhs_bufinfo_len = lhs_bufinfo.len;
if(inplace) {
res = lhs;
size_t item_sz = mp_binary_get_size('@', lhs->typecode, NULL);
lhs->items = m_renew(byte, lhs->items, (lhs->len + lhs->free) * item_sz, lhs->len * repeat * item_sz);
lhs->len = lhs->len * repeat;
lhs->free = 0;
if (!repeat)
return MP_OBJ_FROM_PTR(res);
repeat--;
ptr = (byte*)res->items + orig_lhs_bufinfo_len;
} else {
res = array_new(lhs_bufinfo.typecode, lhs->len * repeat);
ptr = (byte*)res->items;
}
if(orig_lhs_bufinfo_len) {
for(;repeat--; ptr += orig_lhs_bufinfo_len) {
memcpy(ptr, lhs_bufinfo.buf, orig_lhs_bufinfo_len);
}
}
return MP_OBJ_FROM_PTR(res);
}
case MP_BINARY_OP_ADD: {
// allow to add anything that has the buffer protocol (extension to CPython)
mp_buffer_info_t lhs_bufinfo;
Expand Down
28 changes: 28 additions & 0 deletions tests/basics/array_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
try:
import array
except ImportError:
print("SKIP")
raise SystemExit

a1 = array.array('I', [1])
a2 = array.array('I', [2]) * 2
a3 = (a1 + a2)
print(a3)

a3 *= 5
print(a3)

a3 *= 0
print(a3)

a4 = a2 * 0
print(a4)

a4 *= 0
print(a4)

a4 = a4 * 2
print(a4)

a4 *= 2
print(a4)