Skip to content

Commit ba9fa20

Browse files
authored
feat: improve test function classification (#8235)
1 parent 7074d20 commit ba9fa20

File tree

7 files changed

+212
-126
lines changed

7 files changed

+212
-126
lines changed

crates/cli/src/utils/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ impl<T: AsRef<Path>> FoundryPathExt for T {
6767
}
6868

6969
/// Initializes a tracing Subscriber for logging
70-
#[allow(dead_code)]
7170
pub fn subscriber() {
7271
tracing_subscriber::Registry::default()
7372
.with(tracing_subscriber::EnvFilter::from_default_env())

crates/common/src/compile.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,16 +237,16 @@ impl ProjectCompiler {
237237
for (name, artifact) in artifacts {
238238
let size = deployed_contract_size(artifact).unwrap_or_default();
239239

240-
let dev_functions =
241-
artifact.abi.as_ref().map(|abi| abi.functions()).into_iter().flatten().filter(
242-
|func| {
243-
func.name.is_test() ||
244-
func.name.eq("IS_TEST") ||
245-
func.name.eq("IS_SCRIPT")
246-
},
247-
);
248-
249-
let is_dev_contract = dev_functions.count() > 0;
240+
let is_dev_contract = artifact
241+
.abi
242+
.as_ref()
243+
.map(|abi| {
244+
abi.functions().any(|f| {
245+
f.test_function_kind().is_known() ||
246+
matches!(f.name.as_str(), "IS_TEST" | "IS_SCRIPT")
247+
})
248+
})
249+
.unwrap_or(false);
250250
size_report.contracts.insert(name, ContractInfo { size, is_dev_contract });
251251
}
252252

crates/common/src/traits.rs

Lines changed: 155 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use alloy_json_abi::Function;
44
use alloy_primitives::Bytes;
55
use alloy_sol_types::SolError;
6-
use std::path::Path;
6+
use std::{fmt, path::Path};
77

88
/// Test filter.
99
pub trait TestFilter: Send + Sync {
@@ -19,116 +19,208 @@ pub trait TestFilter: Send + Sync {
1919

2020
/// Extension trait for `Function`.
2121
pub trait TestFunctionExt {
22-
/// Returns whether this function should be executed as invariant test.
23-
fn is_invariant_test(&self) -> bool;
24-
25-
/// Returns whether this function should be executed as fuzz test.
26-
fn is_fuzz_test(&self) -> bool;
22+
/// Returns the kind of test function.
23+
fn test_function_kind(&self) -> TestFunctionKind {
24+
TestFunctionKind::classify(self.tfe_as_str(), self.tfe_has_inputs())
25+
}
2726

28-
/// Returns whether this function is a test.
29-
fn is_test(&self) -> bool;
27+
/// Returns `true` if this function is a `setUp` function.
28+
fn is_setup(&self) -> bool {
29+
self.test_function_kind().is_setup()
30+
}
3031

31-
/// Returns whether this function is a test that should fail.
32-
fn is_test_fail(&self) -> bool;
32+
/// Returns `true` if this function is a unit, fuzz, or invariant test.
33+
fn is_any_test(&self) -> bool {
34+
self.test_function_kind().is_any_test()
35+
}
3336

34-
/// Returns whether this function is a `setUp` function.
35-
fn is_setup(&self) -> bool;
37+
/// Returns `true` if this function is a test that should fail.
38+
fn is_any_test_fail(&self) -> bool {
39+
self.test_function_kind().is_any_test_fail()
40+
}
3641

37-
/// Returns whether this function is `afterInvariant` function.
38-
fn is_after_invariant(&self) -> bool;
42+
/// Returns `true` if this function is a unit test.
43+
fn is_unit_test(&self) -> bool {
44+
matches!(self.test_function_kind(), TestFunctionKind::UnitTest { .. })
45+
}
3946

40-
/// Returns whether this function is a fixture function.
41-
fn is_fixture(&self) -> bool;
42-
}
47+
/// Returns `true` if this function is a fuzz test.
48+
fn is_fuzz_test(&self) -> bool {
49+
self.test_function_kind().is_fuzz_test()
50+
}
4351

44-
impl TestFunctionExt for Function {
52+
/// Returns `true` if this function is an invariant test.
4553
fn is_invariant_test(&self) -> bool {
46-
self.name.is_invariant_test()
54+
self.test_function_kind().is_invariant_test()
4755
}
4856

49-
fn is_fuzz_test(&self) -> bool {
50-
// test functions that have inputs are considered fuzz tests as those inputs will be fuzzed
51-
!self.inputs.is_empty()
57+
/// Returns `true` if this function is an `afterInvariant` function.
58+
fn is_after_invariant(&self) -> bool {
59+
self.test_function_kind().is_after_invariant()
5260
}
5361

54-
fn is_test(&self) -> bool {
55-
self.name.is_test()
62+
/// Returns `true` if this function is a `fixture` function.
63+
fn is_fixture(&self) -> bool {
64+
self.test_function_kind().is_fixture()
5665
}
5766

58-
fn is_test_fail(&self) -> bool {
59-
self.name.is_test_fail()
67+
#[doc(hidden)]
68+
fn tfe_as_str(&self) -> &str;
69+
#[doc(hidden)]
70+
fn tfe_has_inputs(&self) -> bool;
71+
}
72+
73+
impl TestFunctionExt for Function {
74+
fn tfe_as_str(&self) -> &str {
75+
self.name.as_str()
6076
}
6177

62-
fn is_setup(&self) -> bool {
63-
self.name.is_setup()
78+
fn tfe_has_inputs(&self) -> bool {
79+
!self.inputs.is_empty()
6480
}
81+
}
6582

66-
fn is_after_invariant(&self) -> bool {
67-
self.name.is_after_invariant()
83+
impl TestFunctionExt for String {
84+
fn tfe_as_str(&self) -> &str {
85+
self
6886
}
6987

70-
fn is_fixture(&self) -> bool {
71-
self.name.is_fixture()
88+
fn tfe_has_inputs(&self) -> bool {
89+
false
7290
}
7391
}
7492

75-
impl TestFunctionExt for String {
76-
fn is_invariant_test(&self) -> bool {
77-
self.as_str().is_invariant_test()
93+
impl TestFunctionExt for str {
94+
fn tfe_as_str(&self) -> &str {
95+
self
7896
}
7997

80-
fn is_fuzz_test(&self) -> bool {
81-
self.as_str().is_fuzz_test()
98+
fn tfe_has_inputs(&self) -> bool {
99+
false
82100
}
101+
}
102+
103+
/// Test function kind.
104+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
105+
pub enum TestFunctionKind {
106+
/// `setUp`.
107+
Setup,
108+
/// `test*`. `should_fail` is `true` for `testFail*`.
109+
UnitTest { should_fail: bool },
110+
/// `test*`, with arguments. `should_fail` is `true` for `testFail*`.
111+
FuzzTest { should_fail: bool },
112+
/// `invariant*` or `statefulFuzz*`.
113+
InvariantTest,
114+
/// `afterInvariant`.
115+
AfterInvariant,
116+
/// `fixture*`.
117+
Fixture,
118+
/// Unknown kind.
119+
Unknown,
120+
}
83121

84-
fn is_test(&self) -> bool {
85-
self.as_str().is_test()
122+
impl TestFunctionKind {
123+
/// Classify a function.
124+
#[inline]
125+
pub fn classify(name: &str, has_inputs: bool) -> Self {
126+
match () {
127+
_ if name.starts_with("test") => {
128+
let should_fail = name.starts_with("testFail");
129+
if has_inputs {
130+
Self::FuzzTest { should_fail }
131+
} else {
132+
Self::UnitTest { should_fail }
133+
}
134+
}
135+
_ if name.starts_with("invariant") || name.starts_with("statefulFuzz") => {
136+
Self::InvariantTest
137+
}
138+
_ if name.eq_ignore_ascii_case("setup") => Self::Setup,
139+
_ if name.eq_ignore_ascii_case("afterinvariant") => Self::AfterInvariant,
140+
_ if name.starts_with("fixture") => Self::Fixture,
141+
_ => Self::Unknown,
142+
}
86143
}
87144

88-
fn is_test_fail(&self) -> bool {
89-
self.as_str().is_test_fail()
145+
/// Returns the name of the function kind.
146+
pub const fn name(&self) -> &'static str {
147+
match self {
148+
Self::Setup => "setUp",
149+
Self::UnitTest { should_fail: false } => "test",
150+
Self::UnitTest { should_fail: true } => "testFail",
151+
Self::FuzzTest { should_fail: false } => "fuzz",
152+
Self::FuzzTest { should_fail: true } => "fuzz fail",
153+
Self::InvariantTest => "invariant",
154+
Self::AfterInvariant => "afterInvariant",
155+
Self::Fixture => "fixture",
156+
Self::Unknown => "unknown",
157+
}
90158
}
91159

92-
fn is_setup(&self) -> bool {
93-
self.as_str().is_setup()
160+
/// Returns `true` if this function is a `setUp` function.
161+
#[inline]
162+
pub const fn is_setup(&self) -> bool {
163+
matches!(self, Self::Setup)
94164
}
95165

96-
fn is_after_invariant(&self) -> bool {
97-
self.as_str().is_after_invariant()
166+
/// Returns `true` if this function is a unit, fuzz, or invariant test.
167+
#[inline]
168+
pub const fn is_any_test(&self) -> bool {
169+
matches!(self, Self::UnitTest { .. } | Self::FuzzTest { .. } | Self::InvariantTest)
98170
}
99171

100-
fn is_fixture(&self) -> bool {
101-
self.as_str().is_fixture()
172+
/// Returns `true` if this function is a test that should fail.
173+
#[inline]
174+
pub const fn is_any_test_fail(&self) -> bool {
175+
matches!(self, Self::UnitTest { should_fail: true } | Self::FuzzTest { should_fail: true })
102176
}
103-
}
104177

105-
impl TestFunctionExt for str {
106-
fn is_invariant_test(&self) -> bool {
107-
self.starts_with("invariant") || self.starts_with("statefulFuzz")
178+
/// Returns `true` if this function is a unit test.
179+
#[inline]
180+
pub fn is_unit_test(&self) -> bool {
181+
matches!(self, Self::UnitTest { .. })
108182
}
109183

110-
fn is_fuzz_test(&self) -> bool {
111-
unimplemented!("no naming convention for fuzz tests")
184+
/// Returns `true` if this function is a fuzz test.
185+
#[inline]
186+
pub const fn is_fuzz_test(&self) -> bool {
187+
matches!(self, Self::FuzzTest { .. })
112188
}
113189

114-
fn is_test(&self) -> bool {
115-
self.starts_with("test")
190+
/// Returns `true` if this function is an invariant test.
191+
#[inline]
192+
pub const fn is_invariant_test(&self) -> bool {
193+
matches!(self, Self::InvariantTest)
116194
}
117195

118-
fn is_test_fail(&self) -> bool {
119-
self.starts_with("testFail")
196+
/// Returns `true` if this function is an `afterInvariant` function.
197+
#[inline]
198+
pub const fn is_after_invariant(&self) -> bool {
199+
matches!(self, Self::AfterInvariant)
120200
}
121201

122-
fn is_setup(&self) -> bool {
123-
self.eq_ignore_ascii_case("setup")
202+
/// Returns `true` if this function is a `fixture` function.
203+
#[inline]
204+
pub const fn is_fixture(&self) -> bool {
205+
matches!(self, Self::Fixture)
124206
}
125207

126-
fn is_after_invariant(&self) -> bool {
127-
self.eq_ignore_ascii_case("afterinvariant")
208+
/// Returns `true` if this function kind is known.
209+
#[inline]
210+
pub const fn is_known(&self) -> bool {
211+
!matches!(self, Self::Unknown)
128212
}
129213

130-
fn is_fixture(&self) -> bool {
131-
self.starts_with("fixture")
214+
/// Returns `true` if this function kind is unknown.
215+
#[inline]
216+
pub const fn is_unknown(&self) -> bool {
217+
matches!(self, Self::Unknown)
218+
}
219+
}
220+
221+
impl fmt::Display for TestFunctionKind {
222+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223+
self.name().fmt(f)
132224
}
133225
}
134226

crates/evm/coverage/src/analysis.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ impl<'a> SourceAnalyzer<'a> {
493493

494494
let is_test = items.iter().any(|item| {
495495
if let CoverageItemKind::Function { name } = &item.kind {
496-
name.is_test()
496+
name.is_any_test()
497497
} else {
498498
false
499499
}

crates/forge/src/gas_report.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ impl GasReport {
106106
} else if let Some(DecodedCallData { signature, .. }) = decoded.func {
107107
let name = signature.split('(').next().unwrap();
108108
// ignore any test/setup functions
109-
let should_include = !(name.is_test() || name.is_invariant_test() || name.is_setup());
110-
if should_include {
109+
if !name.test_function_kind().is_known() {
111110
trace!(contract_name, signature, "adding gas info");
112111
let gas_info = contract_info
113112
.functions

crates/forge/src/multi_runner.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ impl MultiContractRunner {
102102
.iter()
103103
.filter(|(id, _)| filter.matches_path(&id.source) && filter.matches_contract(&id.name))
104104
.flat_map(|(_, TestContract { abi, .. })| abi.functions())
105-
.filter(|func| func.is_test() || func.is_invariant_test())
105+
.filter(|func| func.is_any_test())
106106
}
107107

108108
/// Returns all matching tests grouped by contract grouped by file (file -> (contract -> tests))
@@ -392,7 +392,7 @@ impl MultiContractRunnerBuilder {
392392

393393
// if it's a test, link it and add to deployable contracts
394394
if abi.constructor.as_ref().map(|c| c.inputs.is_empty()).unwrap_or(true) &&
395-
abi.functions().any(|func| func.name.is_test() || func.name.is_invariant_test())
395+
abi.functions().any(|func| func.name.is_any_test())
396396
{
397397
let Some(bytecode) =
398398
contract.get_bytecode_bytes().map(|b| b.into_owned()).filter(|b| !b.is_empty())
@@ -434,5 +434,5 @@ pub fn matches_contract(id: &ArtifactId, abi: &JsonAbi, filter: &dyn TestFilter)
434434

435435
/// Returns `true` if the function is a test function that matches the given filter.
436436
pub(crate) fn is_matching_test(func: &Function, filter: &dyn TestFilter) -> bool {
437-
(func.is_test() || func.is_invariant_test()) && filter.matches_test(&func.signature())
437+
func.is_any_test() && filter.matches_test(&func.signature())
438438
}

0 commit comments

Comments
 (0)