Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
fb5859f
initial s-mode trap impl
qbojj Apr 8, 2026
045657b
add S-mode trap asm tests
qbojj Apr 12, 2026
e5d8da4
add test case for wfi
qbojj Apr 13, 2026
4b5835b
fixup after rebase
qbojj Apr 17, 2026
a9a4b65
cleanup
qbojj Apr 17, 2026
97f3091
cleanup decoder hacks
qbojj Apr 17, 2026
4bd36fe
cleanup interrupt controller
qbojj Apr 17, 2026
2c54fc3
remove unneeded assignment
qbojj Apr 17, 2026
023b479
Merge remote-tracking branch 'upstream/master' into feat-s-mode-traps
qbojj Apr 22, 2026
99bba0c
some PR comment resolutions
qbojj Apr 23, 2026
9c1c41d
Merge branch 'master' into feat-s-mode-traps
qbojj Apr 23, 2026
68973b4
post merge
qbojj Apr 23, 2026
0842d46
add S and U mode to riscof
qbojj Apr 23, 2026
83f49fd
add S and U mode to riscof (spike)
qbojj Apr 23, 2026
8a04485
don't filter privileged tests from riscof
qbojj Apr 24, 2026
b2c197b
improve trap level delegation logic
qbojj Apr 24, 2026
50761c1
Merge remote-tracking branch 'upstream/master' into feat-s-mode-traps
qbojj Apr 24, 2026
0b1939a
restore riscof
qbojj Apr 24, 2026
b0d0bdb
add pmp handling for S-mode tests
qbojj Apr 24, 2026
5c1ad1e
clean test
qbojj Apr 25, 2026
95174f8
format
qbojj Apr 25, 2026
9bc4360
add supervisor traps to default config tests
qbojj Apr 26, 2026
694d036
add supervisor_mode assertion to priv unit
qbojj Apr 28, 2026
96c6fcc
Merge branch 'master' into feat-s-mode-traps
qbojj Apr 28, 2026
7b066c7
add supervisor enable assertion for real
qbojj Apr 28, 2026
48a112e
Merge branch 'master' into feat-s-mode-traps
qbojj Apr 28, 2026
49dddd9
Merge branch 'master' into feat-s-mode-traps
qbojj Apr 28, 2026
f0d116c
Merge branch 'master' into feat-s-mode-traps
tilk Apr 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions coreblocks/backend/retirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
from transactron.lib.metrics import *

from coreblocks.params.genparams import GenParams
from coreblocks.arch import ExceptionCause
from coreblocks.arch import ExceptionCause, PrivilegeLevel
from coreblocks.arch.csr_address import CounterEnableFieldOffsets
from coreblocks.interface.keys import CoreStateKey, CSRInstancesKey, InstructionPrecommitKey
from coreblocks.interface.keys import (
CoreStateKey,
CSRInstancesKey,
InstructionPrecommitKey,
)
from coreblocks.priv.csr.csr_instances import CSRAddress, DoubleCounterCSR, counteren_access_filter
from coreblocks.arch.isa_consts import TrapVectorMode

Expand All @@ -45,8 +49,9 @@ def __init__(
i=gen_params.get(CoreInstructionCounterLayouts).decrement_in,
o=gen_params.get(CoreInstructionCounterLayouts).decrement_out,
)
self.trap_entry = Method()
self.async_interrupt_cause = Method(o=gen_params.get(InternalInterruptControllerLayouts).interrupt_cause)
self.trap_entry = Method(i=[("cause", gen_params.isa.xlen)], o=[("target_priv", PrivilegeLevel)])
interrupt_controller_layouts = gen_params.get(InternalInterruptControllerLayouts)
self.async_interrupt_cause = Method(o=interrupt_controller_layouts.interrupt_cause)
self.checkpoint_tag_free = Method()
self.checkpoint_get_active_tags = Method(o=gen_params.get(RATLayouts).get_active_tags_out)

Expand Down Expand Up @@ -79,7 +84,9 @@ def elaborate(self, platform):

m.submodules += [self.perf_instr_ret, self.perf_trap_latency]

m_csr = self.dependency_manager.get_dependency(CSRInstancesKey()).m_mode
csr_instances = self.dependency_manager.get_dependency(CSRInstancesKey())
m_csr = csr_instances.m_mode
s_csr = csr_instances.s_mode if self.gen_params.supervisor_mode else None
m.submodules.instret_csr = self.instret_csr

side_fx = Signal(init=1)
Expand Down Expand Up @@ -124,6 +131,7 @@ def flush_instr(rob_entry):
continue_pc_override = Signal()
continue_pc = Signal(self.gen_params.isa.xlen)
core_flushing = Signal()
trap_target_priv = Signal(PrivilegeLevel, init=PrivilegeLevel.MACHINE)

with m.FSM("NORMAL") as fsm:
with m.State("NORMAL"):
Expand Down Expand Up @@ -176,11 +184,23 @@ def flush_instr(rob_entry):
m.d.av_comb += cause_entry.eq(cause_register.cause)

with m.If(arch_trap):
# Register RISC-V architectural trap in CSRs
m_csr.mcause.write(m, cause_entry)
m_csr.mepc.write(m, cause_register.pc)
m_csr.mtval.write(m, cause_register.mtval)
self.trap_entry(m)
# Register RISC-V architectural trap in CSRs.
target_priv = self.trap_entry(m, cause=cause_entry).target_priv

def set_trap_csrs(cause_reg, epc_reg, tval_reg):
cause_reg.write(m, cause_entry)
epc_reg.write(m, cause_register.pc)
tval_reg.write(m, cause_register.mtval)

with m.Switch(target_priv):
if self.gen_params.supervisor_mode:
with m.Case(PrivilegeLevel.SUPERVISOR):
assert s_csr is not None
set_trap_csrs(s_csr.scause, s_csr.sepc, s_csr.stval)
with m.Case(PrivilegeLevel.MACHINE):
set_trap_csrs(m_csr.mcause, m_csr.mepc, m_csr.mtval)

m.d.sync += trap_target_priv.eq(target_priv)

# Fetch is already stalled by ExceptionCauseRegister
with m.If(core_empty):
Expand Down Expand Up @@ -229,17 +249,32 @@ def flush_instr(rob_entry):
self.perf_trap_latency.stop(m)

handler_pc = Signal(self.gen_params.isa.xlen)
mtvec_offset = Signal(self.gen_params.isa.xlen)
mtvec_base = m_csr.mtvec_base.read(m).data
mtvec_mode = m_csr.mtvec_mode.read(m).data
mcause = m_csr.mcause.read(m).data
tvec_offset = Signal(self.gen_params.isa.xlen)
tvec_base = Signal(self.gen_params.isa.xlen)
tvec_mode = Signal(TrapVectorMode)
tcause = Signal(self.gen_params.isa.xlen)

def set_vals(reg_base, reg_mode, reg_cause):
m.d.av_comb += [
tvec_base.eq(reg_base.read(m).data),
tvec_mode.eq(reg_mode.read(m).data),
tcause.eq(reg_cause.read(m).data),
]

with m.Switch(trap_target_priv):
if self.gen_params.supervisor_mode:
with m.Case(PrivilegeLevel.SUPERVISOR):
assert s_csr is not None
set_vals(s_csr.stvec_base, s_csr.stvec_mode, s_csr.scause)
with m.Case(PrivilegeLevel.MACHINE):
set_vals(m_csr.mtvec_base, m_csr.mtvec_mode, m_csr.mcause)

# When mode is Vectored, interrupts set pc to base + 4 * cause_number
with m.If(mcause[-1] & (mtvec_mode == TrapVectorMode.VECTORED)):
m.d.av_comb += mtvec_offset.eq(mcause << 2)
with m.If(tcause[-1] & (tvec_mode == TrapVectorMode.VECTORED)):
m.d.av_comb += tvec_offset.eq(tcause << 2)

# (mtvec_base stores base[MXLEN-1:2])
m.d.av_comb += handler_pc.eq((mtvec_base << 2) + mtvec_offset)
# (xtvec_base stores base[MXLEN-1:2])
m.d.av_comb += handler_pc.eq((tvec_base << 2) + tvec_offset)

resume_pc = Mux(continue_pc_override, continue_pc, handler_pc)
m.d.sync += continue_pc_override.eq(0)
Expand Down
16 changes: 11 additions & 5 deletions coreblocks/frontend/decoder/instr_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from operator import or_

from amaranth import *
from amaranth.lib import data

from coreblocks.params import *
from coreblocks.arch import *
from coreblocks.arch.optypes import optypes_by_extensions
from .instr_description import instructions_by_optype, Encoding
from coreblocks.interface.layouts import CSRUnitLayouts, PrivUnitLayouts

__all__ = ["InstrDecoder"]

Expand Down Expand Up @@ -322,11 +324,15 @@ def elaborate(self, platform):

# HACK: pass logical registers in unused high bits of CSR instruction for `mtval` reconstruction
with m.If((self.optype == OpType.CSR_REG) | (self.optype == OpType.CSR_IMM)):
m.d.comb += self.imm[32 - self.gen_params.isa.reg_cnt_log : 32].eq(self.rd)
m.d.comb += self.imm[32 - self.gen_params.isa.reg_cnt_log * 2 : 32 - self.gen_params.isa.reg_cnt_log].eq(
self.rs1
)
assert 32 - self.gen_params.isa.reg_cnt_log * 2 >= 5
imm_view = data.View(self.gen_params.get(CSRUnitLayouts).imm_layout, self.imm)
m.d.comb += imm_view.rd.eq(self.rd)
m.d.comb += imm_view.rs1.eq(self.rs1)

# HACK: pass the logical register encoding of SFENCEVMA (same encoding as for CSR instructions for circuit size)
with m.If(self.optype == OpType.SFENCEVMA):
imm_view = data.View(self.gen_params.get(PrivUnitLayouts).sfencevma_imm_layout, self.imm)
m.d.comb += imm_view.rs1.eq(self.rs1)
m.d.comb += imm_view.rs2.eq(self.rs2)

# Instruction simplification

Expand Down
12 changes: 7 additions & 5 deletions coreblocks/func_blocks/csr/csr_unit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from amaranth import *
from amaranth.lib.data import StructLayout
from amaranth.lib.data import StructLayout, View

from dataclasses import dataclass

Expand Down Expand Up @@ -226,17 +226,19 @@ def _():
with m.If(exception):
mtval = Signal(self.gen_params.isa.xlen)
# re-encode the CSR instruction to speed-up missing CSR emulation (optional, otherwise mtval must be 0)
imm_view = View(self.csr_layouts.imm_layout, instr.imm)

m.d.av_comb += mtval[0:2].eq(0b11)
m.d.av_comb += mtval[2:7].eq(Opcode.SYSTEM)
m.d.av_comb += mtval[7:12].eq(instr.imm[32 - self.gen_params.isa.reg_cnt_log : 32]) # rl_rd
m.d.av_comb += mtval[7:12].eq(imm_view.rd)
m.d.av_comb += mtval[12:15].eq(instr.exec_fn.funct3)
m.d.av_comb += mtval[15:20].eq(
Mux(
instr.exec_fn.op_type == OpType.CSR_IMM,
instr.imm[0:5],
instr.imm[32 - self.gen_params.isa.reg_cnt_log * 2 : 32 - self.gen_params.isa.reg_cnt_log],
imm_view.imm,
imm_view.rs1,
)
) # rl_s1 or imm
)
m.d.av_comb += mtval[20:32].eq(instr.csr)
self.report(m, rob_id=instr.rob_id, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=instr.pc, mtval=mtval)
with m.Elif(interrupt):
Expand Down
2 changes: 2 additions & 0 deletions coreblocks/func_blocks/fu/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def _(arg):
with m.Switch(priv_level):
with m.Case(PrivilegeLevel.MACHINE):
m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_M)
with m.Case(PrivilegeLevel.SUPERVISOR):
m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_S)
with m.Case(PrivilegeLevel.USER):
m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_U)
m.d.av_comb += mtval.eq(0) # by SPEC
Expand Down
100 changes: 85 additions & 15 deletions coreblocks/func_blocks/fu/priv.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from dataclasses import dataclass, KW_ONLY, field
from amaranth import *
from amaranth.lib import data

from enum import IntFlag, auto, unique
from typing import Sequence
from coreblocks.arch.isa_consts import Funct12, Funct3, Opcode, PrivilegeLevel
from coreblocks.arch.isa_consts import Funct12, Funct3, Funct7, Opcode, PrivilegeLevel, SatpMode


from transactron import *
Expand All @@ -15,8 +16,10 @@
from coreblocks.params import *
from coreblocks.params import GenParams, FunctionalComponentParams
from coreblocks.arch import OpType, ExceptionCause
from coreblocks.interface.layouts import PrivUnitLayouts
from coreblocks.interface.keys import (
MretKey,
SretKey,
AsyncInterruptInsertSignalKey,
ExceptionReportKey,
CSRInstancesKey,
Expand All @@ -34,14 +37,26 @@


class PrivilegedFn(DecoderManager):
def __init__(self, supervisor_enable=False) -> None:
self.supervisor_enable = supervisor_enable

@unique
class Fn(IntFlag):
MRET = auto()
FENCEI = auto()
WFI = auto()
SRET = auto()
SFENCEVMA = auto()

def get_instructions(self) -> Sequence[tuple]:
return [(self.Fn.MRET, OpType.MRET), (self.Fn.FENCEI, OpType.FENCEI), (self.Fn.WFI, OpType.WFI)]
return [
(self.Fn.MRET, OpType.MRET),
(self.Fn.FENCEI, OpType.FENCEI),
(self.Fn.WFI, OpType.WFI),
] + [
(self.Fn.SRET, OpType.SRET),
(self.Fn.SFENCEVMA, OpType.SFENCEVMA),
] * self.supervisor_enable


class PrivilegedFuncUnit(FuncUnitBase[PrivilegedFn]):
Expand Down Expand Up @@ -71,7 +86,12 @@ def elaborate(self, platform):
instr_pc = Signal(self.gen_params.isa.xlen)
instr_fn = self.fn.get_function()

instr_imm = Signal(self.gen_params.isa.xlen)
instr_s1_val = Signal(self.gen_params.isa.xlen)
instr_s2_val = Signal(self.gen_params.isa.xlen)

mret = self.dm.get_dependency(MretKey())
sret = self.dm.get_optional_dependency(SretKey())
async_interrupt_active = self.dm.get_dependency(AsyncInterruptInsertSignalKey())
wfi_resume = self.dm.get_dependency(WaitForInterruptResumeKey())
csr = self.dm.get_dependency(CSRInstancesKey())
Expand All @@ -86,6 +106,9 @@ def _(arg):
instr_rob.eq(arg.rob_id),
instr_pc.eq(arg.pc),
instr_fn.eq(arg.decode_fn),
instr_s1_val.eq(arg.s1_val),
instr_s2_val.eq(arg.s2_val),
instr_imm.eq(arg.imm),
]

with Transaction().body(m, ready=instr_valid & ~finished):
Expand All @@ -97,24 +120,47 @@ def _(arg):
priv_data = priv_mode.read(m).data

illegal_mret = (instr_fn == PrivilegedFn.Fn.MRET) & (priv_data != PrivilegeLevel.MACHINE)
# future todo: WFI should be illegal in U-Mode only if S-Mode is supported
illegal_wfi = (
(instr_fn == PrivilegedFn.Fn.WFI)
& (priv_data == PrivilegeLevel.USER)
& csr.m_mode.mstatus_tw.read(m).data

if self.fn.supervisor_enable:
illegal_sret = (instr_fn == PrivilegedFn.Fn.SRET) & (
(priv_data == PrivilegeLevel.USER)
| ((priv_data == PrivilegeLevel.SUPERVISOR) & csr.m_mode.mstatus_tsr.read(m).data)
)
else:
illegal_sret = 0

if self.fn.supervisor_enable:
illegal_sfencevma = (instr_fn == PrivilegedFn.Fn.SFENCEVMA) & (
(priv_data == PrivilegeLevel.USER)
| ((priv_data == PrivilegeLevel.SUPERVISOR) & csr.m_mode.mstatus_tvm.read(m).data)
)
else:
illegal_sfencevma = 0

illegal_wfi = (instr_fn == PrivilegedFn.Fn.WFI) & (
((priv_data == PrivilegeLevel.USER) if self.gen_params.supervisor_mode else 0)
| ((priv_data < PrivilegeLevel.MACHINE) & (csr.m_mode.mstatus_tw.read(m).data))
)

with condition(m, nonblocking=True) as branch:
with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.MRET) & ~illegal_mret):
mret(m)
if self.fn.supervisor_enable:
assert sret is not None
with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.SRET) & ~illegal_sret):
sret(m)

# TODO: implement proper SFENCE.VMA, for BARE only - NO-OP is ok
assert self.gen_params.vmem_params.supported_schemes == {SatpMode.BARE}

with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.FENCEI)):
flush_icache(m)
with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.WFI) & ~illegal_wfi):
# async_interrupt_active implies wfi_resume. WFI should continue normal execution
# when interrupt is enabled in xie, but disabled via global mstatus.xIE
m.d.sync += finished.eq(wfi_resume)

m.d.sync += illegal_instruction.eq(illegal_wfi | illegal_mret)
m.d.sync += illegal_instruction.eq(illegal_wfi | illegal_mret | illegal_sret | illegal_sfencevma)

with Transaction().body(m, ready=instr_valid & finished):
m.d.sync += instr_valid.eq(0)
Expand All @@ -125,7 +171,13 @@ def _(arg):
with OneHotSwitch(m, instr_fn) as OneHotCase:
with OneHotCase(PrivilegedFn.Fn.MRET):
m.d.av_comb += ret_pc.eq(csr.m_mode.mepc.read(m).data)
# FENCE.I and WFI can't be compressed, so the next instruction is always pc+4
if self.fn.supervisor_enable:
with OneHotCase(PrivilegedFn.Fn.SRET):
m.d.av_comb += ret_pc.eq(csr.s_mode.sepc.read(m).data)
# SFENCE.VMA, FENCE.I and WFI can't be compressed, so next PC is always pc+4
if self.fn.supervisor_enable:
with OneHotCase(PrivilegedFn.Fn.SFENCEVMA):
m.d.av_comb += ret_pc.eq(instr_pc + 4)
with OneHotCase(PrivilegedFn.Fn.FENCEI):
m.d.av_comb += ret_pc.eq(instr_pc + 4)
with OneHotCase(PrivilegedFn.Fn.WFI):
Expand All @@ -145,10 +197,21 @@ def _(arg):
m.d.av_comb += instr[7:12].eq(0)
m.d.av_comb += instr[12:15].eq(Funct3.PRIV)
m.d.av_comb += instr[15:20].eq(0)
m.d.av_comb += instr[20:32].eq(Mux(instr_fn == PrivilegedFn.Fn.MRET, Funct12.WFI, Funct12.MRET))
log.error(
m, (instr_fn != PrivilegedFn.Fn.MRET) & (instr_fn != PrivilegedFn.Fn.WFI), "missing Funct12 case"
)
with m.Switch(instr_fn):
with m.Case(PrivilegedFn.Fn.MRET):
m.d.av_comb += instr[20:32].eq(Funct12.MRET)
with m.Case(PrivilegedFn.Fn.WFI):
m.d.av_comb += instr[20:32].eq(Funct12.WFI)
if self.fn.supervisor_enable:
with m.Case(PrivilegedFn.Fn.SRET):
m.d.av_comb += instr[20:32].eq(Funct12.SRET)
with m.Case(PrivilegedFn.Fn.SFENCEVMA):
imm_view = data.View(self.gen_params.get(PrivUnitLayouts).sfencevma_imm_layout, instr_imm)
m.d.av_comb += instr[15:20].eq(imm_view.rs1)
m.d.av_comb += instr[20:25].eq(imm_view.rs2)
m.d.av_comb += instr[25:32].eq(Funct7.SFENCEVMA)
with m.Default():
log.error(m, True, "missing Funct12 case")

self.exception_report(
m, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=ret_pc, rob_id=instr_rob, mtval=instr
Expand Down Expand Up @@ -183,7 +246,14 @@ def _(arg):

@dataclass(frozen=True)
class PrivilegedUnitComponent(FunctionalComponentParams):
decoder_manager: PrivilegedFn = PrivilegedFn()
_: KW_ONLY
supervisor_enable: bool = False
decoder_manager: PrivilegedFn = field(init=False)

def get_decoder_manager(self):
return PrivilegedFn(supervisor_enable=self.supervisor_enable)

def get_module(self, gen_params: GenParams) -> FuncUnit:
assert self.supervisor_enable == gen_params.supervisor_mode

return PrivilegedFuncUnit(gen_params, self.decoder_manager)
7 changes: 7 additions & 0 deletions coreblocks/interface/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
"ExceptionReportKey",
"CSRInstancesKey",
"AsyncInterruptInsertSignalKey",
"WaitForInterruptResumeKey",
"MretKey",
"SretKey",
"CoreStateKey",
"CSRListKey",
"FlushICacheKey",
Expand Down Expand Up @@ -92,6 +94,11 @@ class MretKey(SimpleKey[Method]):
pass


@dataclass(frozen=True)
class SretKey(SimpleKey[Method]):
pass


@dataclass(frozen=True)
class CoreStateKey(SimpleKey[Method]):
pass
Expand Down
Loading
Loading