1+ import { AST_NODE_TYPES , TSESTree } from '@typescript-eslint/utils'
12import { createEslintRule , getAccessorValue } from '../utils'
2- import { parseVitestFnCall } from '../utils/parse-vitest-fn-call'
3+ import {
4+ ParsedExpectVitestFnCall ,
5+ parseVitestFnCall ,
6+ } from '../utils/parse-vitest-fn-call'
7+ import { SourceCode } from '@typescript-eslint/utils/ts-eslint'
38
49type MESSAGE_IDS = 'preferCalledExactlyOnceWith'
510export const RULE_NAME = 'prefer-called-exactly-once-with'
611type Options = [ ]
712
13+ const MATCHERS_TO_COMBINE = [
14+ 'toHaveBeenCalledOnce' ,
15+ 'toHaveBeenCalledWith' ,
16+ ] as const
17+
18+ type CombinedMatcher = ( typeof MATCHERS_TO_COMBINE ) [ number ]
19+
20+ type MatcherReference = {
21+ matcherName : CombinedMatcher
22+ callExpression : TSESTree . CallExpression
23+ }
24+
25+ const hasMatchersToCombine = ( target : string ) : target is CombinedMatcher =>
26+ MATCHERS_TO_COMBINE . some ( ( matcher ) => matcher === target )
27+
28+ const getExpectText = (
29+ expression : TSESTree . CallExpression ,
30+ source : Readonly < SourceCode > ,
31+ ) => {
32+ if ( expression . callee . type !== AST_NODE_TYPES . MemberExpression ) return null
33+
34+ const { range } = expression . callee . object
35+ return source . text . slice ( range [ 0 ] , range [ 1 ] )
36+ }
37+
38+ const getArgumentsText = (
39+ callExpression : TSESTree . CallExpression ,
40+ source : Readonly < SourceCode > ,
41+ ) => callExpression . arguments . map ( ( arg ) => source . getText ( arg ) ) . join ( ', ' )
42+
43+ const getValidExpectCall = (
44+ vitestFnCall : ReturnType < typeof parseVitestFnCall > ,
45+ ) : ParsedExpectVitestFnCall | null => {
46+ if ( vitestFnCall ?. type !== 'expect' ) return null
47+ if (
48+ vitestFnCall . modifiers . some (
49+ ( modifier ) => getAccessorValue ( modifier ) === 'not' ,
50+ )
51+ )
52+ return null
53+
54+ return vitestFnCall
55+ }
56+
57+ const getMatcherName = ( vitestFnCall : ReturnType < typeof parseVitestFnCall > ) => {
58+ const validExpectCall = getValidExpectCall ( vitestFnCall )
59+ return validExpectCall ? getAccessorValue ( validExpectCall . matcher ) : null
60+ }
61+
62+ const getMemberProperty = ( expression : TSESTree . CallExpression ) =>
63+ expression . callee . type === AST_NODE_TYPES . MemberExpression
64+ ? expression . callee . property
65+ : null
66+
867export default createEslintRule < Options , MESSAGE_IDS > ( {
968 name : RULE_NAME ,
1069 meta : {
@@ -22,36 +81,100 @@ export default createEslintRule<Options, MESSAGE_IDS>({
2281 } ,
2382 defaultOptions : [ ] ,
2483 create ( context ) {
25- return {
26- CallExpression ( node ) {
27- const vitestFnCall = parseVitestFnCall ( node , context )
84+ const { sourceCode } = context
2885
29- if ( vitestFnCall ?. type !== 'expect' ) return
86+ const getCallExpressions = (
87+ body : TSESTree . Statement [ ] ,
88+ ) : TSESTree . CallExpression [ ] =>
89+ body
90+ . filter ( ( node ) => node . type === AST_NODE_TYPES . ExpressionStatement )
91+ . flatMap ( ( node ) =>
92+ node . expression . type === AST_NODE_TYPES . CallExpression
93+ ? node . expression
94+ : [ ] ,
95+ )
96+
97+ const checkBlockBody = ( body : TSESTree . Statement [ ] ) => {
98+ const callExpressions = getCallExpressions ( body )
99+ const expectMatcherMap = new Map < string , Readonly < MatcherReference > [ ] > ( )
30100
31- if (
32- vitestFnCall . modifiers . some (
33- ( node ) => getAccessorValue ( node ) === 'not' ,
34- )
101+ for ( const callExpression of callExpressions ) {
102+ const matcherName = getMatcherName (
103+ parseVitestFnCall ( callExpression , context ) ,
35104 )
36- return
37-
38- const { matcher } = vitestFnCall
39- const matcherName = getAccessorValue ( matcher )
40-
41- if (
42- [ 'toHaveBeenCalledOnce' , 'toHaveBeenCalledWith' ] . includes ( matcherName )
43- ) {
44- context . report ( {
45- data : {
46- matcherName,
47- } ,
48- messageId : 'preferCalledExactlyOnceWith' ,
49- node : matcher ,
50- fix : ( fixer ) => [
51- fixer . replaceText ( matcher , `toHaveBeenCalledExactlyOnceWith` ) ,
52- ] ,
53- } )
54- }
105+ const expectedText = getExpectText ( callExpression , sourceCode )
106+ if ( ! matcherName || ! hasMatchersToCombine ( matcherName ) || ! expectedText )
107+ continue
108+
109+ const existingNodes = expectMatcherMap . get ( expectedText ) ?? [ ]
110+ const newTargetNodes = [
111+ ...existingNodes ,
112+ { matcherName, callExpression } ,
113+ ] as const satisfies MatcherReference [ ]
114+ expectMatcherMap . set ( expectedText , newTargetNodes )
115+ }
116+
117+ for ( const [
118+ expectedText ,
119+ matcherReferences ,
120+ ] of expectMatcherMap . entries ( ) ) {
121+ if ( matcherReferences . length !== 2 ) continue
122+
123+ const targetArgNode = matcherReferences . find (
124+ ( reference ) => reference . matcherName === 'toHaveBeenCalledWith' ,
125+ )
126+ if ( ! targetArgNode ) continue
127+
128+ const argsText = getArgumentsText (
129+ targetArgNode . callExpression ,
130+ sourceCode ,
131+ )
132+
133+ const [ firstMatcherReference , secondMatcherReference ] =
134+ matcherReferences
135+ const targetNode = getMemberProperty (
136+ secondMatcherReference . callExpression ,
137+ )
138+ if ( ! targetNode ) continue
139+
140+ const { callExpression : firstCallExpression } = firstMatcherReference
141+ const { callExpression : secondCallExpression , matcherName } =
142+ secondMatcherReference
143+
144+ context . report ( {
145+ messageId : 'preferCalledExactlyOnceWith' ,
146+ node : targetNode ,
147+ data : { matcherName } ,
148+ fix ( fixer ) {
149+ const indentation = sourceCode . text . slice (
150+ firstCallExpression . parent . range [ 0 ] ,
151+ firstCallExpression . range [ 0 ] ,
152+ )
153+ const replacement = `${ indentation } ${ expectedText } .toHaveBeenCalledExactlyOnceWith(${ argsText } )`
154+
155+ const lineStart = sourceCode . getIndexFromLoc ( {
156+ line : secondCallExpression . parent . loc . start . line ,
157+ column : 0 ,
158+ } )
159+ const lineEnd = sourceCode . getIndexFromLoc ( {
160+ line : secondCallExpression . parent . loc . end . line + 1 ,
161+ column : 0 ,
162+ } )
163+ return [
164+ fixer . replaceText ( firstCallExpression , replacement ) ,
165+ fixer . removeRange ( [ lineStart , lineEnd ] ) ,
166+ ]
167+ } ,
168+ } )
169+ }
170+ }
171+
172+ return {
173+ Program ( node ) {
174+ checkBlockBody ( node . body )
175+ } ,
176+ BlockStatement ( node ) {
177+ checkBlockBody ( node . body )
55178 } ,
56179 }
57180 } ,
0 commit comments