Skip to content

Commit 1483b7f

Browse files
authored
Port ContentAddressableMemory from kuznia-rdzeni/coreblocks#395 (kuznia-rdzeni/coreblocks#573)
1 parent 4168375 commit 1483b7f

8 files changed

Lines changed: 611 additions & 21 deletions

File tree

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from datetime import timedelta
2+
from hypothesis import given, settings, Phase
3+
from transactron.testing import *
4+
from transactron.lib.storage import ContentAddressableMemory
5+
6+
7+
class TestContentAddressableMemory(TestCaseWithSimulator):
8+
addr_width = 4
9+
content_width = 5
10+
test_number = 30
11+
nop_number = 3
12+
addr_layout = data_layout(addr_width)
13+
content_layout = data_layout(content_width)
14+
15+
def setUp(self):
16+
self.entries_count = 8
17+
18+
self.circ = SimpleTestCircuit(
19+
ContentAddressableMemory(self.addr_layout, self.content_layout, self.entries_count)
20+
)
21+
22+
self.memory = {}
23+
24+
def generic_process(
25+
self,
26+
method,
27+
input_lst,
28+
behaviour_check=None,
29+
state_change=None,
30+
input_verification=None,
31+
settle_count=0,
32+
name="",
33+
):
34+
def f():
35+
while input_lst:
36+
# wait till all processes will end the previous cycle
37+
yield from self.multi_settle(4)
38+
elem = input_lst.pop()
39+
if isinstance(elem, OpNOP):
40+
yield
41+
continue
42+
if input_verification is not None and not input_verification(elem):
43+
yield
44+
continue
45+
response = yield from method.call(**elem)
46+
yield from self.multi_settle(settle_count)
47+
if behaviour_check is not None:
48+
# Here accesses to circuit are allowed
49+
ret = behaviour_check(elem, response)
50+
if isinstance(ret, Generator):
51+
yield from ret
52+
if state_change is not None:
53+
# It is standard python function by purpose to don't allow accessing circuit
54+
state_change(elem, response)
55+
yield
56+
57+
return f
58+
59+
def push_process(self, in_push):
60+
def verify_in(elem):
61+
return not (frozenset(elem["addr"].items()) in self.memory)
62+
63+
def modify_state(elem, response):
64+
self.memory[frozenset(elem["addr"].items())] = elem["data"]
65+
66+
return self.generic_process(
67+
self.circ.push,
68+
in_push,
69+
state_change=modify_state,
70+
input_verification=verify_in,
71+
settle_count=3,
72+
name="push",
73+
)
74+
75+
def read_process(self, in_read):
76+
def check(elem, response):
77+
addr = elem["addr"]
78+
frozen_addr = frozenset(addr.items())
79+
if frozen_addr in self.memory:
80+
assert response["not_found"] == 0
81+
assert response["data"] == self.memory[frozen_addr]
82+
else:
83+
assert response["not_found"] == 1
84+
85+
return self.generic_process(self.circ.read, in_read, behaviour_check=check, settle_count=0, name="read")
86+
87+
def remove_process(self, in_remove):
88+
def modify_state(elem, response):
89+
if frozenset(elem["addr"].items()) in self.memory:
90+
del self.memory[frozenset(elem["addr"].items())]
91+
92+
return self.generic_process(self.circ.remove, in_remove, state_change=modify_state, settle_count=2, name="remv")
93+
94+
def write_process(self, in_write):
95+
def verify_in(elem):
96+
ret = frozenset(elem["addr"].items()) in self.memory
97+
return ret
98+
99+
def check(elem, response):
100+
assert response["not_found"] == int(frozenset(elem["addr"].items()) not in self.memory)
101+
102+
def modify_state(elem, response):
103+
if frozenset(elem["addr"].items()) in self.memory:
104+
self.memory[frozenset(elem["addr"].items())] = elem["data"]
105+
106+
return self.generic_process(
107+
self.circ.write,
108+
in_write,
109+
behaviour_check=check,
110+
state_change=modify_state,
111+
input_verification=None,
112+
settle_count=1,
113+
name="writ",
114+
)
115+
116+
@settings(
117+
max_examples=10,
118+
phases=(Phase.explicit, Phase.reuse, Phase.generate, Phase.shrink),
119+
derandomize=True,
120+
deadline=timedelta(milliseconds=500),
121+
)
122+
@given(
123+
generate_process_input(test_number, nop_number, [("addr", addr_layout), ("data", content_layout)]),
124+
generate_process_input(test_number, nop_number, [("addr", addr_layout), ("data", content_layout)]),
125+
generate_process_input(test_number, nop_number, [("addr", addr_layout)]),
126+
generate_process_input(test_number, nop_number, [("addr", addr_layout)]),
127+
)
128+
def test_random(self, in_push, in_write, in_read, in_remove):
129+
with self.reinitialize_fixtures():
130+
self.setUp()
131+
with self.run_simulation(self.circ, max_cycles=500) as sim:
132+
sim.add_sync_process(self.push_process(in_push))
133+
sim.add_sync_process(self.read_process(in_read))
134+
sim.add_sync_process(self.write_process(in_write))
135+
sim.add_sync_process(self.remove_process(in_remove))

test/utils/test_amaranth_ext.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from transactron.testing import *
2+
import random
3+
from transactron.utils.amaranth_ext import MultiPriorityEncoder
4+
5+
6+
class TestMultiPriorityEncoder(TestCaseWithSimulator):
7+
def get_expected(self, input):
8+
places = []
9+
for i in range(self.input_width):
10+
if input % 2:
11+
places.append(i)
12+
input //= 2
13+
places += [None] * self.output_count
14+
return places
15+
16+
def process(self):
17+
for _ in range(self.test_number):
18+
input = random.randrange(2**self.input_width)
19+
yield self.circ.input.eq(input)
20+
yield Settle()
21+
expected_output = self.get_expected(input)
22+
for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids):
23+
if ex is None:
24+
assert (yield valid) == 0
25+
else:
26+
assert (yield valid) == 1
27+
assert (yield real) == ex
28+
yield Delay(1e-7)
29+
30+
@pytest.mark.parametrize("input_width", [1, 5, 16, 23, 24])
31+
@pytest.mark.parametrize("output_count", [1, 3, 4])
32+
def test_random(self, input_width, output_count):
33+
random.seed(input_width + output_count)
34+
self.test_number = 50
35+
self.input_width = input_width
36+
self.output_count = output_count
37+
self.circ = MultiPriorityEncoder(self.input_width, self.output_count)
38+
39+
with self.run_simulation(self.circ) as sim:
40+
sim.add_process(self.process)
41+
42+
@pytest.mark.parametrize("name", ["prio_encoder", None])
43+
def test_static_create_simple(self, name):
44+
random.seed(14)
45+
self.test_number = 50
46+
self.input_width = 7
47+
self.output_count = 1
48+
49+
class DUT(Elaboratable):
50+
def __init__(self, input_width, output_count, name):
51+
self.input = Signal(input_width)
52+
self.output_count = output_count
53+
self.input_width = input_width
54+
self.name = name
55+
56+
def elaborate(self, platform):
57+
m = Module()
58+
out, val = MultiPriorityEncoder.create_simple(m, self.input_width, self.input, name=self.name)
59+
# Save as a list to use common interface in testing
60+
self.outputs = [out]
61+
self.valids = [val]
62+
return m
63+
64+
self.circ = DUT(self.input_width, self.output_count, name)
65+
66+
with self.run_simulation(self.circ) as sim:
67+
sim.add_process(self.process)
68+
69+
@pytest.mark.parametrize("name", ["prio_encoder", None])
70+
def test_static_create(self, name):
71+
random.seed(14)
72+
self.test_number = 50
73+
self.input_width = 7
74+
self.output_count = 2
75+
76+
class DUT(Elaboratable):
77+
def __init__(self, input_width, output_count, name):
78+
self.input = Signal(input_width)
79+
self.output_count = output_count
80+
self.input_width = input_width
81+
self.name = name
82+
83+
def elaborate(self, platform):
84+
m = Module()
85+
out = MultiPriorityEncoder.create(m, self.input_width, self.input, self.output_count, name=self.name)
86+
self.outputs, self.valids = list(zip(*out))
87+
return m
88+
89+
self.circ = DUT(self.input_width, self.output_count, name)
90+
91+
with self.run_simulation(self.circ) as sim:
92+
sim.add_process(self.process)

transactron/lib/storage.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
from transactron.utils.transactron_helpers import from_method_layout, make_layout
55
from ..core import *
6-
from ..utils import SrcLoc, get_src_loc
6+
from ..utils import SrcLoc, get_src_loc, MultiPriorityEncoder
77
from typing import Optional
8-
from transactron.utils import assign, AssignType, LayoutList
8+
from transactron.utils import assign, AssignType, LayoutList, MethodLayout
99
from .reqres import ArgumentsToResultsZipper
1010

11-
__all__ = ["MemoryBank", "AsyncMemoryBank"]
11+
__all__ = ["MemoryBank", "ContentAddressableMemory", "AsyncMemoryBank"]
1212

1313

1414
class MemoryBank(Elaboratable):
@@ -37,7 +37,7 @@ def __init__(
3737
elem_count: int,
3838
granularity: Optional[int] = None,
3939
safe_writes: bool = True,
40-
src_loc: int | SrcLoc = 0
40+
src_loc: int | SrcLoc = 0,
4141
):
4242
"""
4343
Parameters
@@ -138,6 +138,103 @@ def _(arg):
138138
return m
139139

140140

141+
class ContentAddressableMemory(Elaboratable):
142+
"""Content addresable memory
143+
144+
This module implements a content-addressable memory (in short CAM) with Transactron interface.
145+
CAM is a type of memory where instead of predefined indexes there are used values fed in runtime
146+
as keys (similar as in python dictionary). To insert new entry a pair `(key, value)` has to be
147+
provided. Such pair takes an free slot which depends on internal implementation. To read value
148+
a `key` has to be provided. It is compared with every valid key stored in CAM. If there is a hit,
149+
a value is read. There can be many instances of the same key in CAM. In such case it is undefined
150+
which value will be read.
151+
152+
153+
.. warning::
154+
Pushing the value with index already present in CAM is an undefined behaviour.
155+
156+
Attributes
157+
----------
158+
read : Method
159+
Nondestructive read
160+
write : Method
161+
If index present - do update
162+
remove : Method
163+
Remove
164+
push : Method
165+
Inserts new data.
166+
"""
167+
168+
def __init__(self, address_layout: MethodLayout, data_layout: MethodLayout, entries_number: int):
169+
"""
170+
Parameters
171+
----------
172+
address_layout : LayoutLike
173+
The layout of the address records.
174+
data_layout : LayoutLike
175+
The layout of the data.
176+
entries_number : int
177+
The number of slots to create in memory.
178+
"""
179+
self.address_layout = from_method_layout(address_layout)
180+
self.data_layout = from_method_layout(data_layout)
181+
self.entries_number = entries_number
182+
183+
self.read = Method(i=[("addr", self.address_layout)], o=[("data", self.data_layout), ("not_found", 1)])
184+
self.remove = Method(i=[("addr", self.address_layout)])
185+
self.push = Method(i=[("addr", self.address_layout), ("data", self.data_layout)])
186+
self.write = Method(i=[("addr", self.address_layout), ("data", self.data_layout)], o=[("not_found", 1)])
187+
188+
def elaborate(self, platform) -> TModule:
189+
m = TModule()
190+
191+
address_array = Array(
192+
[Signal(self.address_layout, name=f"address_array_{i}") for i in range(self.entries_number)]
193+
)
194+
data_array = Array([Signal(self.data_layout, name=f"data_array_{i}") for i in range(self.entries_number)])
195+
valids = Signal(self.entries_number, name="valids")
196+
197+
m.submodules.encoder_read = encoder_read = MultiPriorityEncoder(self.entries_number, 1)
198+
m.submodules.encoder_write = encoder_write = MultiPriorityEncoder(self.entries_number, 1)
199+
m.submodules.encoder_push = encoder_push = MultiPriorityEncoder(self.entries_number, 1)
200+
m.submodules.encoder_remove = encoder_remove = MultiPriorityEncoder(self.entries_number, 1)
201+
m.d.top_comb += encoder_push.input.eq(~valids)
202+
203+
@def_method(m, self.push, ready=~valids.all())
204+
def _(addr, data):
205+
id = Signal(range(self.entries_number), name="id_push")
206+
m.d.top_comb += id.eq(encoder_push.outputs[0])
207+
m.d.sync += address_array[id].eq(addr)
208+
m.d.sync += data_array[id].eq(data)
209+
m.d.sync += valids.bit_select(id, 1).eq(1)
210+
211+
@def_method(m, self.write)
212+
def _(addr, data):
213+
write_mask = Signal(self.entries_number, name="write_mask")
214+
m.d.top_comb += write_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids)
215+
m.d.top_comb += encoder_write.input.eq(write_mask)
216+
with m.If(write_mask.any()):
217+
m.d.sync += data_array[encoder_write.outputs[0]].eq(data)
218+
return {"not_found": ~write_mask.any()}
219+
220+
@def_method(m, self.read)
221+
def _(addr):
222+
read_mask = Signal(self.entries_number, name="read_mask")
223+
m.d.top_comb += read_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids)
224+
m.d.top_comb += encoder_read.input.eq(read_mask)
225+
return {"data": data_array[encoder_read.outputs[0]], "not_found": ~read_mask.any()}
226+
227+
@def_method(m, self.remove)
228+
def _(addr):
229+
rm_mask = Signal(self.entries_number, name="rm_mask")
230+
m.d.top_comb += rm_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids)
231+
m.d.top_comb += encoder_remove.input.eq(rm_mask)
232+
with m.If(rm_mask.any()):
233+
m.d.sync += valids.bit_select(encoder_remove.outputs[0], 1).eq(0)
234+
235+
return m
236+
237+
141238
class AsyncMemoryBank(Elaboratable):
142239
"""AsyncMemoryBank module.
143240

transactron/testing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .input_generation import * # noqa: F401
12
from .functions import * # noqa: F401
23
from .infrastructure import * # noqa: F401
34
from .sugar import * # noqa: F401

0 commit comments

Comments
 (0)