-
Notifications
You must be signed in to change notification settings - Fork 25
Expand file tree
/
Copy pathcommon.py
More file actions
465 lines (369 loc) · 14.8 KB
/
common.py
File metadata and controls
465 lines (369 loc) · 14.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
import unittest
import os
import functools
import random
from contextlib import contextmanager, nullcontext
from typing import Callable, Generic, Mapping, Union, Generator, TypeVar, Optional, Any, cast, Type, TypeGuard
from amaranth import *
from amaranth.hdl.ast import Statement
from amaranth.sim import *
from amaranth.sim.core import Command
from transactron.core import SignalBundle, Method, TransactionModule
from transactron.lib import AdapterBase, AdapterTrans
from transactron._utils import def_helper
from coreblocks.utils import ValueLike, HasElaborate, HasDebugSignals, auto_debug_signals, LayoutLike, ModuleConnector
from .gtkw_extension import write_vcd_ext
# Generic type variable shared by the aliases and helpers below.
T = TypeVar("T")
# Nested mapping from record field names to values (or to further nested mappings);
# used to drive record inputs in simulation.
RecordValueDict = Mapping[str, Union[ValueLike, "RecordValueDict"]]
# Same nested shape as RecordValueDict, but with plain integer leaves.
RecordIntDict = Mapping[str, Union[int, "RecordIntDict"]]
RecordIntDictRet = Mapping[str, Any] # full typing hard to work with
# Type of simulator process generators: they yield commands, values or statements
# (a bare `yield`, i.e. None, waits for the next cycle), receive arbitrary values back
# from the simulator, and finally return a T.
TestGen = Generator[Command | Value | Statement | None, Any, T]
# Either a T, or a (possibly nested) list / str-keyed dict of such collections.
_T_nested_collection = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"]
def data_layout(val: int) -> LayoutLike:
    """Return a layout with a single field named ``data`` whose shape is `val`."""
    layout = [("data", val)]
    return layout
def set_inputs(values: RecordValueDict, field: Record) -> TestGen[None]:
    """
    Drive the signals of the record `field` from the (possibly nested) mapping `values`.

    Every key of `values` names an attribute of `field`. A `dict` value recurses into
    the corresponding sub-record; any other value is assigned with ``.eq()``.
    """
    for key, val in values.items():
        target = getattr(field, key)
        if not isinstance(val, dict):
            yield target.eq(val)
            continue
        yield from set_inputs(val, target)
def get_outputs(field: Record) -> TestGen[RecordIntDict]:
    """
    Read every signal of the record `field` into a nested dict of integers.

    Amaranth's simulator cannot read all values of a Record with a single ``yield``;
    it can only read Values (Signals). The fields are therefore read one by one,
    recursing into nested records.
    """
    outputs = {}
    for field_name, _, _ in field.layout:
        member = getattr(field, field_name)
        if not isinstance(member, Signal):
            # Not a plain signal, so it is a nested record.
            outputs[field_name] = yield from get_outputs(member)
        else:
            outputs[field_name] = yield member
    return outputs
def neg(x: int, xlen: int) -> int:
    """
    Negate a number represented in the U2 (two's complement) system.

    Parameters
    ----------
    x : int
        Number in the U2 system.
    xlen : int
        Bit width of x.

    Returns
    -------
    return : int
        The U2 representation of -x, truncated to `xlen` bits.
    """
    mask = 2**xlen - 1
    return mask & -x
def int_to_signed(x: int, xlen: int) -> int:
    """
    Convert a (possibly negative) Python integer into its U2 representation.

    Parameters
    ----------
    x : int
        Signed Python integer.
    xlen : int
        Bit width of the result.

    Returns
    -------
    return : int
        The low `xlen` bits of x, i.e. its representation in the U2 system.
    """
    mask = 2**xlen - 1
    return mask & x
def signed_to_int(x: int, xlen: int) -> int:
    """
    Interpret an `xlen`-bit U2 value as a signed Python integer.

    Parameters
    ----------
    x : int
        Number in the U2 system.
    xlen : int
        Bit width of x.

    Returns
    -------
    return : int
        The signed Python integer represented by x.
    """
    sign_bit = x & 2 ** (xlen - 1)
    # OR-ing with -sign_bit sets every bit from the sign bit upward, i.e. sign-extends x
    # when the sign bit is set, and is a no-op (x | 0) otherwise.
    return x | -sign_bit
def guard_nested_collection(cont: Any, t: Type[T]) -> TypeGuard[_T_nested_collection[T]]:
    """
    Check whether `cont` is a `t`, or a (possibly nested) list / dict whose leaves are all `t`s.

    Only dict values are checked, not keys. Lists and dicts are tested before `t`,
    and empty lists/dicts are accepted vacuously.

    Parameters
    ----------
    cont : Any
        Object to check.
    t : Type[T]
        Expected type of the leaves.

    Returns
    -------
    return : bool
        True if every leaf of `cont` is an instance of `t`.
    """
    if isinstance(cont, dict):
        cont = cont.values()
    elif not isinstance(cont, list):
        return isinstance(cont, t)
    # A generator (not a list) lets all() stop at the first non-matching element
    # instead of first walking the whole nested structure.
    return all(guard_nested_collection(elem, t) for elem in cont)
# Type of the design under test wrapped by SimpleTestCircuit.
_T_HasElaborate = TypeVar("_T_HasElaborate", bound=HasElaborate)
class SimpleTestCircuit(Elaboratable, Generic[_T_HasElaborate]):
    """
    Wrap a design under test and expose its `Method` attributes as `TestbenchIO`s.

    During elaboration, every attribute of the DUT that is a `Method`, or a non-empty
    (possibly nested) list / dict of `Method`s, is mirrored with `TestbenchIO`s backed
    by `AdapterTrans`. The result is reachable as an attribute of this circuit under
    the same name as on the DUT. These attributes exist only after `elaborate` has
    run; accessing them earlier raises AttributeError.
    """

    def __init__(self, dut: _T_HasElaborate):
        self._dut = dut
        self._io: dict[str, _T_nested_collection[TestbenchIO]] = {}

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails. `_io` is read via __dict__
        # so that a missing `_io` (e.g. before __init__ ran) cannot recurse back into
        # __getattr__. AttributeError (not KeyError) keeps hasattr(),
        # getattr(obj, name, default) and copy/pickle protocols working.
        io = self.__dict__.get("_io", {})
        if name in io:
            return io[name]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

    def elaborate(self, platform):
        def transform_methods_to_testbenchios(
            container: _T_nested_collection[Method],
        ) -> tuple[_T_nested_collection["TestbenchIO"], Union[ModuleConnector, "TestbenchIO"]]:
            # Mirror the structure of `container`, replacing each Method with a
            # TestbenchIO. Also return an elaboratable that holds all created
            # TestbenchIOs, so they can be added as a single submodule.
            if isinstance(container, list):
                tb_list = []
                mc_list = []
                for elem in container:
                    tb, mc = transform_methods_to_testbenchios(elem)
                    tb_list.append(tb)
                    mc_list.append(mc)
                return tb_list, ModuleConnector(*mc_list)
            elif isinstance(container, dict):
                tb_dict = {}
                mc_dict = {}
                for name, elem in container.items():
                    tb, mc = transform_methods_to_testbenchios(elem)
                    tb_dict[name] = tb
                    mc_dict[name] = mc
                # Pass the submodules by keyword: `*mc_dict` would unpack the dict's
                # keys (strings), which are not elaboratable.
                return tb_dict, ModuleConnector(**mc_dict)
            else:
                tb = TestbenchIO(AdapterTrans(container))
                return tb, tb

        m = Module()

        m.submodules.dut = self._dut

        for name, attr in vars(self._dut).items():
            # Skip empty collections: they contain no methods to wrap.
            if guard_nested_collection(attr, Method) and attr:
                tb_cont, mc = transform_methods_to_testbenchios(attr)
                self._io[name] = tb_cont
                m.submodules[name] = mc

        return m

    def debug_signals(self):
        """Return debug signals of the DUT and of every created TestbenchIO collection."""
        sigs = {"_dut": auto_debug_signals(self._dut)}
        for name, io in self._io.items():
            sigs[name] = auto_debug_signals(io)
        return sigs
class TestModule(Elaboratable):
    """
    Top-level wrapper placed around the module under test.

    When `add_transaction_module` is truthy, the tested module is additionally
    wrapped in a `TransactionModule`. The wrapper also drives a dummy signal from
    the `sync` domain.
    """

    def __init__(self, tested_module: HasElaborate, add_transaction_module):
        if add_transaction_module:
            wrapped = TransactionModule(tested_module)
        else:
            wrapped = tested_module
        self.tested_module = wrapped
        self.add_transaction_module = add_transaction_module

    def elaborate(self, platform) -> HasElaborate:
        m = Module()

        # Driving a signal from `sync` makes the domain exist, so that Amaranth
        # allows us to use add_clock.
        _dummy = Signal()
        m.d.sync += _dummy.eq(1)

        m.submodules.tested_module = self.tested_module

        return m
class PysimSimulator(Simulator):
    """
    Amaranth `Simulator` preconfigured for coreblocks tests.

    The module under test is placed in a `TestModule`, a clock with a period of
    1e-6 is added, and `run` simulates up to `max_cycles` clock periods. When
    `traces_file` is given, ``<traces_file>.vcd`` and ``<traces_file>.gtkw`` are
    written to ``test/__traces__`` while `run` executes.
    """

    def __init__(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True, traces_file=None):
        wrapper = TestModule(module, add_transaction_module)
        tested_module = wrapper.tested_module
        super().__init__(wrapper)

        clk_period = 1e-6
        self.add_clock(clk_period)

        # Signals to dump besides the clocks: the module's own selection if it
        # provides one, otherwise whatever auto_debug_signals finds.
        if not isinstance(tested_module, HasDebugSignals):
            extra_signals = functools.partial(auto_debug_signals, tested_module)
        else:
            extra_signals = tested_module.debug_signals

        if not traces_file:
            self.ctx = nullcontext()
        else:
            traces_dir = "test/__traces__"
            os.makedirs(traces_dir, exist_ok=True)
            # Signal handling is hacky and accesses Simulator internals.
            # TODO: try to merge with Amaranth.
            signals = extra_signals() if isinstance(extra_signals, Callable) else extra_signals
            clocks = [domain.clk for domain in cast(Any, self)._fragment.domains.values()]
            trace_base = f"{traces_dir}/{traces_file}"
            self.ctx = write_vcd_ext(
                cast(Any, self)._engine,
                f"{trace_base}.vcd",
                f"{trace_base}.gtkw",
                traces=[clocks, signals],
            )

        self.deadline = clk_period * max_cycles

    def run(self) -> bool:
        """
        Simulate until the deadline (inside the tracing context, if any).

        Returns ``not self.advance()`` evaluated after reaching the deadline; the
        callers treat False as "simulation time limit exceeded".
        """
        with self.ctx:
            self.run_until(self.deadline)
            return not self.advance()
class TestCaseWithSimulator(unittest.TestCase):
    """`unittest.TestCase` with helpers for running Amaranth simulations."""

    @contextmanager
    def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True):
        """
        Create a `PysimSimulator` for `module`, let the caller set it up, then run it.

        Usage::

            with self.run_simulation(m) as sim:
                ...  # add processes to `sim` here

        The simulation runs when the ``with`` block exits normally; if the block
        raises, the simulation is not run. The test fails with "Simulation time
        limit exceeded" when `PysimSimulator.run` reports that the processes did
        not finish within `max_cycles` clock cycles.

        If the ``__COREBLOCKS_DUMP_TRACES`` environment variable is set, traces are
        dumped to files named after this test's id.
        """
        traces_file = None
        if "__COREBLOCKS_DUMP_TRACES" in os.environ:
            traces_file = unittest.TestCase.id(self)

        sim = PysimSimulator(
            module, max_cycles=max_cycles, add_transaction_module=add_transaction_module, traces_file=traces_file
        )
        yield sim
        res = sim.run()

        self.assertTrue(res, "Simulation time limit exceeded")

    def tick(self, cycle_cnt=1):
        """
        Yields for the given number of cycles.
        """

        for _ in range(cycle_cnt):
            yield

    def random_wait(self, max_cycle_cnt):
        """
        Wait for a random amount of cycles in range [0, max_cycle_cnt)

        The number of cycles is drawn uniformly with `random.randrange`, so it may be
        zero (no wait at all). `max_cycle_cnt` must be positive; otherwise
        `random.randrange` raises ValueError.
        """
        yield from self.tick(random.randrange(max_cycle_cnt))
def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T:
    """
    Invoke the mock handler `func` with the method argument `arg` through `def_helper`.

    `arg` is offered to `def_helper` both as a whole mapping and unpacked as
    keyword arguments, so `func` may take either form. `tb` is only used to
    describe the mock in `def_helper`'s description string.
    """
    description = f"mock definition for {tb}"
    return def_helper(description, func, Mapping[str, Any], arg, **arg)
class TestbenchIO(Elaboratable):
    """
    Testbench-side driver for an adapter (`AdapterBase`).

    Provides simulator generators that drive the adapter's `en` and `data_in`
    signals and read its `done` and `data_out` signals. The "AdapterTrans"
    operations call a method from the testbench; the "Adapter" operations let the
    testbench serve (mock) a method's invocations.
    """

    def __init__(self, adapter: AdapterBase):
        self.adapter = adapter

    def elaborate(self, platform):
        m = Module()
        m.submodules += self.adapter
        return m

    # Low-level operations

    def set_enable(self, en) -> TestGen[None]:
        """Drive the adapter's `en` signal to 1 if `en` is truthy, else 0."""
        yield self.adapter.en.eq(1 if en else 0)

    def enable(self) -> TestGen[None]:
        yield from self.set_enable(True)

    def disable(self) -> TestGen[None]:
        yield from self.set_enable(False)

    def done(self) -> TestGen[int]:
        """Read the adapter's `done` signal."""
        return (yield self.adapter.done)

    def wait_until_done(self) -> TestGen[None]:
        """Wait, one cycle at a time, until `done` reads 1."""
        while (yield self.adapter.done) != 1:
            yield

    def set_inputs(self, data: Optional[RecordValueDict] = None) -> TestGen[None]:
        """Drive `data_in` from a nested mapping; None (the default) drives nothing."""
        # None sentinel instead of a shared mutable `{}` default.
        yield from set_inputs({} if data is None else data, self.adapter.data_in)

    def get_outputs(self) -> TestGen[RecordIntDictRet]:
        """Read all of `data_out` into a nested dict."""
        return (yield from get_outputs(self.adapter.data_out))

    # Operations for AdapterTrans

    @staticmethod
    def _merge_call_args(caller: str, data, kwdata):
        # Accept call data either as one positional dict or as keyword arguments,
        # but not both at once.
        if data and kwdata:
            raise TypeError(f"{caller}() takes either a single dict or keyword arguments")
        return data if data else kwdata

    def call_init(
        self, data: Optional[RecordValueDict] = None, /, **kwdata: ValueLike | RecordValueDict
    ) -> TestGen[None]:
        """Enable the adapter and set its inputs, without waiting for completion."""
        data = self._merge_call_args("call_init", data, kwdata)
        yield from self.enable()
        yield from self.set_inputs(data)

    def call_result(self) -> TestGen[Optional[RecordIntDictRet]]:
        """Return the outputs if `done` is set in the current cycle, otherwise None."""
        if (yield from self.done()):
            return (yield from self.get_outputs())
        return None

    def call_do(self) -> TestGen[RecordIntDictRet]:
        """Wait until the call completes, then disable the adapter and return the outputs."""
        while (outputs := (yield from self.call_result())) is None:
            yield
        yield from self.disable()
        return outputs

    def call_try(
        self, data: Optional[RecordIntDict] = None, /, **kwdata: int | RecordIntDict
    ) -> TestGen[Optional[RecordIntDictRet]]:
        """
        Start a call and advance one cycle.

        Returns the outputs if the call completed in that cycle, otherwise None.
        The adapter is disabled afterwards in both cases.
        """
        data = self._merge_call_args("call_try", data, kwdata)
        yield from self.call_init(data)
        yield
        outputs = yield from self.call_result()
        yield from self.disable()
        return outputs

    def call(self, data: Optional[RecordIntDict] = None, /, **kwdata: int | RecordIntDict) -> TestGen[RecordIntDictRet]:
        """Start a call, wait until it completes and return its outputs."""
        data = self._merge_call_args("call", data, kwdata)
        yield from self.call_init(data)
        yield
        return (yield from self.call_do())

    # Operations for Adapter

    def method_argument(self) -> TestGen[Optional[RecordIntDictRet]]:
        """Return the method's argument if it is invoked in the current cycle, otherwise None."""
        return (yield from self.call_result())

    def method_return(self, data: Optional[RecordValueDict] = None) -> TestGen[None]:
        """Drive the method's return value; None (the default) drives nothing."""
        yield from self.set_inputs(data)

    def method_handle(
        self,
        function: Callable[..., Optional[RecordIntDict]],
        *,
        enable: Optional[Callable[[], bool]] = None,
        extra_settle_count: int = 0,
    ) -> TestGen[None]:
        """
        Handle a single invocation of the mocked method.

        Sets the adapter's enable from `enable()` (always enabled when `enable` is
        None) and waits for an invocation, re-evaluating `enable()` every cycle. It
        then calls `function` with the argument, drives its result (an empty dict
        if it returned None) and advances one cycle. After every enable update,
        `extra_settle_count + 1` `Settle()` commands are issued.
        """
        enable = enable or (lambda: True)
        yield from self.set_enable(enable())

        # One extra Settle() required to propagate enable signal.
        for _ in range(extra_settle_count + 1):
            yield Settle()
        while (arg := (yield from self.method_argument())) is None:
            yield
            yield from self.set_enable(enable())
            for _ in range(extra_settle_count + 1):
                yield Settle()

        ret_out = mock_def_helper(self, function, arg)
        yield from self.method_return(ret_out or {})
        yield

    def method_handle_loop(
        self,
        function: Callable[..., Optional[RecordIntDict]],
        *,
        enable: Optional[Callable[[], bool]] = None,
        extra_settle_count: int = 0,
    ) -> TestGen[None]:
        """Mark the process passive and handle invocations with `method_handle` forever."""
        yield Passive()
        while True:
            yield from self.method_handle(function, enable=enable, extra_settle_count=extra_settle_count)

    # Debug signals

    def debug_signals(self) -> SignalBundle:
        return self.adapter.debug_signals()
def def_method_mock(
    tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], sched_prio: int = 0, **kwargs
) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], TestGen[None]]]:
    """
    Decorator function to create method mock handlers. It should be applied on
    a function which describes functionality which we want to invoke on method call.
    Such function will be wrapped by `method_handle_loop` and called on each
    method invocation.

    The decorated function should take only one argument `arg` - data used in
    function invocation - and should return data to be sent as response to the
    method call.

    The decorated function can also be a method and take two arguments `self` and
    `arg`, the data to be passed on to invoke a method. It should return data to be
    sent as response to the method call.

    Instead of the `arg` argument, the data can be split into keyword arguments.

    Make sure to defer accessing state, since decorators are evaluated eagerly
    during function declaration.

    Parameters
    ----------
    tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO]
        Function to get the TestbenchIO providing appropriate `method_handle_loop`.
        When the decorated function is a method, `tb_getter` is bound to the
        instance and so receives `self`.
    sched_prio : int
        Passed to `method_handle_loop` as `extra_settle_count`, i.e. the number of
        additional `Settle()` commands issued each time the mock updates its enable
        signal, before it reads the method argument. Presumably higher values let
        the mock observe state settled by other processes first.
    **kwargs
        Arguments passed to `method_handle_loop`, e.g. `enable` (a callable
        returning bool). Must not contain `extra_settle_count`, which is set from
        `sched_prio`. When the decorated function is a method, every value that
        has a `__get__` (such as a lambda) is bound to the instance first.

    Example
    -------
    ```
    m = TestCircuit()
    def target_process(k: int):
        @def_method_mock(lambda: m.target[k])
        def process(arg):
            return {"data": arg["data"] + k}
        return process
    ```
    or equivalently
    ```
    m = TestCircuit()
    def target_process(k: int):
        @def_method_mock(lambda: m.target[k])
        def process(data):
            return {"data": data + k}
        return process
    ```
    or for class methods (here also with an `enable` condition bound to `self`)
    ```
    @def_method_mock(lambda self: self.target[k], sched_prio=1, enable=lambda self: self.accepting)
    def process(self, data):
        return {"data": data + k}
    ```
    """

    def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[], TestGen[None]]:
        @functools.wraps(func)
        def mock(func_self=None, /) -> TestGen[None]:
            f = func
            getter: Any = tb_getter
            kw = kwargs
            if func_self is not None:
                # The decorated function is used as a method: bind the getter, the
                # function and any bindable keyword arguments to the instance.
                getter = getter.__get__(func_self)
                f = f.__get__(func_self)
                kw = {}
                for k, v in kwargs.items():
                    bind = getattr(v, "__get__", None)
                    kw[k] = bind(func_self) if bind else v
            tb = getter()
            assert isinstance(tb, TestbenchIO)
            yield from tb.method_handle_loop(f, extra_settle_count=sched_prio, **kw)

        return mock

    return decorator