Skip to content

Commit 9aca1fa

Browse files
authored
add merge function for NamedDataStore
Differential Revision: D70409078 Pull Request resolved: #8850
1 parent a3bc2f1 commit 9aca1fa

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

exir/_serialize/_named_data_store.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,30 @@ def get_named_data_store_output(self) -> NamedDataStoreOutput:
181181
# Clean up empty maps inside self.external_data
182182
self.external_data = {k: v for k, v in self.external_data.items() if len(v) > 0}
183183
return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)
184+
185+
def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
186+
"""
187+
Merge another NamedDataStore into this one.
188+
Args:
189+
other (NamedDataStore): the other NamedDataStore to merge.
190+
Raises:
191+
ValueError: when the key exists in both stores, and corresponding
192+
data is different between them.
193+
"""
194+
# Merge the pte_data.
195+
for key, buffer_idx in other.pte_data.items():
196+
self.add_named_data(
197+
key,
198+
other.buffers[buffer_idx].buffer,
199+
other.buffers[buffer_idx].alignment,
200+
)
201+
202+
# Merge the external_data.
203+
for filename, key_to_buffer_idx in other.external_data.items():
204+
for key, buffer_idx in key_to_buffer_idx.items():
205+
self.add_named_data(
206+
key,
207+
other.buffers[buffer_idx].buffer,
208+
other.buffers[buffer_idx].alignment,
209+
external_tag=filename,
210+
)

exir/_serialize/test/test_named_data_store.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,62 @@ def test_add_duplicate_key_fail(self) -> None:
8383
self.assertEqual(len(output.pte_data), 1)
8484
self.assertEqual(output.pte_data["key"], 0)
8585
self.assertEqual(len(output.external_data), 0)
86+
87+
def test_merge(self) -> None:
88+
store1 = NamedDataStore()
89+
store1.add_named_data("key1", b"data1", None, None)
90+
store1.add_named_data("key2", b"data2", 16, "file1")
91+
92+
# Check items in the store1.
93+
output = store1.get_named_data_store_output()
94+
self.assertEqual(len(output.buffers), 2)
95+
self.assertEqual(len(output.pte_data), 1)
96+
self.assertEqual(len(output.external_data), 1)
97+
self.assertEqual(len(output.external_data["file1"]), 1)
98+
99+
store2 = NamedDataStore()
100+
store2.add_named_data("key1", b"data1", None, None)
101+
store2.add_named_data("key3", b"data3", None, None)
102+
store2.add_named_data("key4", b"data4", 16, "file1")
103+
store2.add_named_data("key5", b"data5", 16, "file2")
104+
105+
# Check items in store2.
106+
output2 = store2.get_named_data_store_output()
107+
self.assertEqual(len(output2.buffers), 4)
108+
self.assertEqual(len(output2.pte_data), 2)
109+
self.assertEqual(len(output2.external_data), 2)
110+
self.assertEqual(len(output2.external_data["file1"]), 1)
111+
self.assertEqual(len(output2.external_data["file2"]), 1)
112+
113+
# Merge store2 into store1.
114+
store1.merge_named_data_store(output2)
115+
116+
# Check items in store2 are merged into store1.
117+
output = store1.get_named_data_store_output()
118+
# key1, data1 exist in both store1 and store2, so we only have one copy of it.
119+
self.assertEqual(len(output.buffers), 5)
120+
self.assertEqual(len(output.pte_data), 2)
121+
self.assertEqual(len(output.external_data), 2)
122+
self.assertEqual(len(output.external_data["file1"]), 2)
123+
self.assertEqual(len(output.external_data["file2"]), 1)
124+
125+
def test_merge_duplicate_error(self) -> None:
126+
store1 = NamedDataStore()
127+
store1.add_named_data("key1", b"data1", None, None)
128+
129+
# Check items in the store1.
130+
output = store1.get_named_data_store_output()
131+
self.assertEqual(len(output.buffers), 1)
132+
self.assertEqual(len(output.pte_data), 1)
133+
134+
store2 = NamedDataStore()
135+
store2.add_named_data("key1", b"data2", None, None)
136+
137+
# Check items in store2.
138+
output2 = store2.get_named_data_store_output()
139+
self.assertEqual(len(output2.buffers), 1)
140+
self.assertEqual(len(output2.pte_data), 1)
141+
142+
# Merge store2 into store1 raises error as key1 is already in store1
143+
# with different data.
144+
self.assertRaises(ValueError, store1.merge_named_data_store, output2)

0 commit comments

Comments
 (0)