Skip to content

Commit 3b9c01a

Browse files
committed
compiler: Support record spread
Thsi feels even messier than in the interpreter but I am not really sure how else to do this. Maybe with some better static analysis that computes record layouts at compile-time.
1 parent 6ee44fe commit 3b9c01a

File tree

3 files changed

+50
-3
lines changed

3 files changed

+50
-3
lines changed

compiler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,19 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En
201201
if isinstance(pattern, Record):
202202
self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}")
203203
updates = {}
204+
seen_key_indices: list[int] = []
204205
for key, pattern_value in pattern.data.items():
205-
assert not isinstance(pattern_value, Spread), "record spread not yet supported"
206+
if isinstance(pattern_value, Spread):
207+
use_spread = True
208+
if pattern_value.name:
209+
num_seen_keys = len(seen_key_indices)
210+
self._emit(
211+
f"size_t seen_keys[{num_seen_keys}] = {{ {', '.join(map(str, seen_key_indices))} }};"
212+
)
213+
updates[pattern_value.name] = self._mktemp(f"record_rest({arg}, seen_keys, {num_seen_keys})")
214+
break
206215
key_idx = self.record_key(key)
216+
seen_key_indices.append(key_idx)
207217
record_value = self._mktemp(f"record_get({arg}, {key_idx})")
208218
self._emit(f"if ({record_value} == NULL) {{ goto {fallthrough}; }}")
209219
updates.update(self.try_match(env, record_value, pattern_value, fallthrough))

compiler_tests.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,17 @@ def test_match_list(self) -> None:
5252
def test_match_list_spread(self) -> None:
5353
self.assertEqual(self._run("f [4, 5] . f = | [_, ...xs] -> xs"), "[5]\n")
5454

55+
def test_match_list_spread_empty(self) -> None:
56+
self.assertEqual(self._run("f [4] . f = | [_, ...xs] -> xs"), "[]\n")
57+
5558
def test_match_record(self) -> None:
5659
self.assertEqual(self._run("f {a = 4, b = 5} . f = | {a = 1, b = 2} -> 3 | {a = 4, b = 5} -> 6"), "6\n")
5760

58-
@unittest.skip("TODO")
5961
def test_match_record_spread(self) -> None:
60-
self.assertEqual(self._run("f {a=1, b=2, c=3} . f = | {a=1, ...rest} -> rest"), "[5]\n")
62+
self.assertEqual(self._run("f {a=1, b=2, c=3} . f = | {a=1, ...rest} -> rest"), "{b = 2, c = 3}\n")
63+
64+
def test_match_record_spread_empty(self) -> None:
65+
self.assertEqual(self._run("f {a=1} . f = | {a=1, ...rest} -> rest"), "{}\n")
6166

6267
def test_match_hole(self) -> None:
6368
self.assertEqual(self._run("f () . f = | 1 -> 3 | () -> 4"), "4\n")

runtime.c

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,10 @@ struct object* record_get(struct object* record, size_t key) {
471471
return NULL;
472472
}
473473

474+
size_t record_num_fields(struct object* record) {
475+
return as_record(record)->size;
476+
}
477+
474478
bool is_string(struct object* obj) {
475479
if (is_small_string(obj)) {
476480
return true;
@@ -603,6 +607,34 @@ struct object* list_cons(struct object* item, struct object* list) {
603607
return result;
604608
}
605609

610+
bool array_contains(size_t* haystack, size_t size, size_t needle) {
611+
for (size_t i = 0; i < size; i++) {
612+
if (haystack[i] == needle) {
613+
return true;
614+
}
615+
}
616+
return false;
617+
}
618+
619+
struct object* record_rest(struct object* record, size_t* exclude,
620+
size_t num_excluded) {
621+
// NB: This is used in a match expression so it is assumed that all of the
622+
// key indices in the exclude array are present in the record and that there
623+
// are no duplicates in either the record or the exclude array.
624+
HANDLES();
625+
GC_PROTECT(record);
626+
size_t num_keys = record_num_fields(record);
627+
size_t num_result_keys = num_keys - num_excluded;
628+
struct object* result = mkrecord(heap, num_result_keys);
629+
for (size_t src = 0, dst = 0; dst < num_result_keys; src++) {
630+
struct record_field field = as_record(record)->fields[src];
631+
if (!array_contains(exclude, num_excluded, field.key)) {
632+
record_set(result, dst++, field);
633+
}
634+
}
635+
return result;
636+
}
637+
606638
struct object* heap_string_concat(struct object* a, struct object* b) {
607639
uword a_size = string_length(a);
608640
uword b_size = string_length(b);

0 commit comments

Comments
 (0)