Skip to content

Commit 6ee44fe

Browse files
committed
Support named match spread for records
We supported `| [...xs]` but by accident not `| {...rest}`. It's a little finicky but seems fine.
1 parent 6c576b7 commit 6ee44fe

File tree

3 files changed

+22
-0
lines changed

3 files changed

+22
-0
lines changed

compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,13 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En
202202
self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}")
203203
updates = {}
204204
for key, pattern_value in pattern.data.items():
205+
assert not isinstance(pattern_value, Spread), "record spread not yet supported"
205206
key_idx = self.record_key(key)
206207
record_value = self._mktemp(f"record_get({arg}, {key_idx})")
207208
self._emit(f"if ({record_value} == NULL) {{ goto {fallthrough}; }}")
208209
updates.update(self.try_match(env, record_value, pattern_value, fallthrough))
210+
# TODO(max): Check that there are no other fields in the record,
211+
# perhaps by length check
209212
return updates
210213
raise NotImplementedError("try_match", pattern)
211214

compiler_tests.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,16 @@ def test_match_int(self) -> None:
4949
def test_match_list(self) -> None:
5050
self.assertEqual(self._run("f [4, 5] . f = | [1, 2] -> 3 | [4, 5] -> 6"), "6\n")
5151

52+
def test_match_list_spread(self) -> None:
53+
self.assertEqual(self._run("f [4, 5] . f = | [_, ...xs] -> xs"), "[5]\n")
54+
5255
def test_match_record(self) -> None:
5356
self.assertEqual(self._run("f {a = 4, b = 5} . f = | {a = 1, b = 2} -> 3 | {a = 4, b = 5} -> 6"), "6\n")
5457

58+
@unittest.skip("TODO")
59+
def test_match_record_spread(self) -> None:
60+
self.assertEqual(self._run("f {a=1, b=2, c=3} . f = | {a=1, ...rest} -> rest"), "[5]\n")
61+
5562
def test_match_hole(self) -> None:
5663
self.assertEqual(self._run("f () . f = | 1 -> 3 | () -> 4"), "4\n")
5764

scrapscript.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,10 +1114,16 @@ def match(obj: Object, pattern: Object) -> Optional[Env]:
11141114
return None
11151115
result: Env = {}
11161116
use_spread = False
1117+
seen_keys: set[str] = set()
11171118
for key, pattern_item in pattern.data.items():
11181119
if isinstance(pattern_item, Spread):
11191120
use_spread = True
1121+
if pattern_item.name is not None:
1122+
assert isinstance(result, dict) # for .update()
1123+
rest_keys = set(obj.data.keys()) - seen_keys
1124+
result.update({pattern_item.name: Record({key: obj.data[key] for key in rest_keys})})
11201125
break
1126+
seen_keys.add(key)
11211127
obj_item = obj.data.get(key)
11221128
if obj_item is None:
11231129
return None
@@ -3300,6 +3306,9 @@ def test_match_record_doubly_binds_vars(self) -> None:
33003306
Int(3),
33013307
)
33023308

3309+
def test_match_record_spread_binds_spread(self) -> None:
3310+
self.assertEqual(self._run("(| { a=1, ...rest } -> rest) {a=1, b=2, c=3}"), Record({"b": Int(2), "c": Int(3)}))
3311+
33033312
def test_match_list_binds_vars(self) -> None:
33043313
self.assertEqual(
33053314
self._run(
@@ -3364,6 +3373,9 @@ def test_match_list_doubly_binds_vars(self) -> None:
33643373
Int(2),
33653374
)
33663375

3376+
def test_match_list_spread_binds_spread(self) -> None:
3377+
self.assertEqual(self._run("(| [x, ...xs] -> xs) [1, 2]"), List([Int(2)]))
3378+
33673379
def test_pipe(self) -> None:
33683380
self.assertEqual(self._run("1 |> (a -> a + 2)"), Int(3))
33693381

0 commit comments

Comments
 (0)