Skip to content

Commit e761f27

Browse files
authored
feat: provide ctx.signal (#7878)
1 parent 69ac92c commit e761f27

File tree

20 files changed

+333
-42
lines changed

20 files changed

+333
-42
lines changed

docs/guide/test-context.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,21 @@ it('math is hard', ({ skip, mind }) => {
7979
})
8080
```
8181

82+
#### `context.signal` <Version>3.2.0</Version> {#context-signal}
83+
84+
An [`AbortSignal`](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal) that can be aborted by Vitest. The signal is aborted in these situations:
85+
86+
- Test times out
87+
- User manually cancelled the test run with Ctrl+C
88+
- [`vitest.cancelCurrentRun`](/advanced/api/vitest#cancelcurrentrun) was called programmatically
89+
- Another test failed in parallel and the [`bail`](/config/#bail) flag is set
90+
91+
```ts
92+
it('stop request when test times out', async ({ signal }) => {
93+
await fetch('/resource', { signal })
94+
}, 2000)
95+
```
96+
8297
#### `onTestFailed`
8398

8499
The [`onTestFailed`](/api/#ontestfailed) hook bound to the current test. This API is useful if you are running tests concurrently and need to have a special handling only for this specific test.

packages/browser/src/client/tester/runner.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ export function createBrowserRunner(
6262
const currentFailures = 1 + previousFailures
6363

6464
if (currentFailures >= this.config.bail) {
65-
rpc().onCancel('test-failure')
66-
this.onCancel('test-failure')
65+
rpc().cancelCurrentRun('test-failure')
66+
this.cancel('test-failure')
6767
}
6868
}
6969
}
@@ -81,8 +81,8 @@ export function createBrowserRunner(
8181
}
8282
}
8383

84-
onCancel = (reason: CancelReason) => {
85-
super.onCancel?.(reason)
84+
cancel = (reason: CancelReason) => {
85+
super.cancel?.(reason)
8686
globalChannel.postMessage({ type: 'cancel', reason })
8787
}
8888

@@ -196,7 +196,7 @@ export async function initiateRunner(
196196
cachedRunner = runner
197197

198198
onCancel.then((reason) => {
199-
runner.onCancel?.(reason)
199+
runner.cancel?.(reason)
200200
})
201201

202202
const [diffOptions] = await Promise.all([

packages/browser/src/node/rpc.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ export function setupBrowserRpc(globalServer: ParentBrowserProject, defaultMocke
202202
const mod = globalServer.vite.moduleGraph.getModuleById(id)
203203
return mod?.transformResult?.map
204204
},
205-
onCancel(reason) {
205+
cancelCurrentRun(reason) {
206206
vitest.cancelCurrentRun(reason)
207207
},
208208
async resolveId(id, importer) {

packages/browser/src/node/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ export interface WebSocketBrowserHandlers {
2121
onCollected: (method: TestExecutionMethod, files: RunnerTestFile[]) => Promise<void>
2222
onTaskUpdate: (method: TestExecutionMethod, packs: TaskResultPack[], events: TaskEventPack[]) => void
2323
onAfterSuiteRun: (meta: AfterSuiteRunMeta) => void
24-
onCancel: (reason: CancelReason) => void
24+
cancelCurrentRun: (reason: CancelReason) => void
2525
getCountOfFailedTests: () => number
2626
readSnapshotFile: (id: string) => Promise<string | null>
2727
saveSnapshotFile: (id: string, content: string) => Promise<void>

packages/runner/src/context.ts

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import type {
55
SuiteCollector,
66
Test,
77
TestContext,
8+
WriteableTestContext,
89
} from './types/tasks'
910
import { getSafeTimers } from '@vitest/utils'
1011
import { PendingError } from './errors'
@@ -36,6 +37,7 @@ export function withTimeout<T extends (...args: any[]) => any>(
3637
timeout: number,
3738
isHook = false,
3839
stackTraceError?: Error,
40+
onTimeout?: (args: T extends (...args: infer A) => any ? A : never, error: Error) => void,
3941
): T {
4042
if (timeout <= 0 || timeout === Number.POSITIVE_INFINITY) {
4143
return fn
@@ -58,7 +60,9 @@ export function withTimeout<T extends (...args: any[]) => any>(
5860
timer.unref?.()
5961

6062
function rejectTimeoutError() {
61-
reject_(makeTimeoutError(isHook, timeout, stackTraceError))
63+
const error = makeTimeoutError(isHook, timeout, stackTraceError)
64+
onTimeout?.(args, error)
65+
reject_(error)
6266
}
6367

6468
function resolve(result: unknown) {
@@ -102,14 +106,35 @@ export function withTimeout<T extends (...args: any[]) => any>(
102106
}) as T
103107
}
104108

109+
const abortControllers = new WeakMap<TestContext, AbortController>()
110+
111+
export function abortIfTimeout([context]: [TestContext?], error: Error): void {
112+
if (context) {
113+
abortContextSignal(context, error)
114+
}
115+
}
116+
117+
export function abortContextSignal(context: TestContext, error: Error): void {
118+
const abortController = abortControllers.get(context)
119+
abortController?.abort(error)
120+
}
121+
105122
export function createTestContext(
106123
test: Test,
107124
runner: VitestRunner,
108125
): TestContext {
109126
const context = function () {
110127
throw new Error('done() callback is deprecated, use promise instead')
111-
} as unknown as TestContext
128+
} as unknown as WriteableTestContext
129+
130+
let abortController = abortControllers.get(context)
131+
132+
if (!abortController) {
133+
abortController = new AbortController()
134+
abortControllers.set(context, abortController)
135+
}
112136

137+
context.signal = abortController.signal
113138
context.task = test
114139

115140
context.skip = (condition?: boolean | string, note?: string): never => {
@@ -129,14 +154,26 @@ export function createTestContext(
129154
context.onTestFailed = (handler, timeout) => {
130155
test.onFailed ||= []
131156
test.onFailed.push(
132-
withTimeout(handler, timeout ?? runner.config.hookTimeout, true, new Error('STACK_TRACE_ERROR')),
157+
withTimeout(
158+
handler,
159+
timeout ?? runner.config.hookTimeout,
160+
true,
161+
new Error('STACK_TRACE_ERROR'),
162+
(_, error) => abortController.abort(error),
163+
),
133164
)
134165
}
135166

136167
context.onTestFinished = (handler, timeout) => {
137168
test.onFinished ||= []
138169
test.onFinished.push(
139-
withTimeout(handler, timeout ?? runner.config.hookTimeout, true, new Error('STACK_TRACE_ERROR')),
170+
withTimeout(
171+
handler,
172+
timeout ?? runner.config.hookTimeout,
173+
true,
174+
new Error('STACK_TRACE_ERROR'),
175+
(_, error) => abortController.abort(error),
176+
),
140177
)
141178
}
142179

packages/runner/src/errors.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { CancelReason } from './types/runner'
12
import type { TaskBase } from './types/tasks'
23

34
export class PendingError extends Error {
@@ -9,3 +10,12 @@ export class PendingError extends Error {
910
this.taskId = task.id
1011
}
1112
}
13+
14+
export class TestRunAbortError extends Error {
15+
public name = 'TestRunAbortError'
16+
public reason: CancelReason
17+
constructor(message: string, reason: CancelReason) {
18+
super(message)
19+
this.reason = reason
20+
}
21+
}

packages/runner/src/hooks.ts

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ import type {
77
OnTestFinishedHandler,
88
TaskHook,
99
TaskPopulated,
10+
TestContext,
1011
} from './types/tasks'
1112
import { assertTypes } from '@vitest/utils'
12-
import { withTimeout } from './context'
13+
import { abortContextSignal, abortIfTimeout, withTimeout } from './context'
1314
import { withFixtures } from './fixture'
1415
import { getCurrentSuite, getRunner } from './suite'
1516
import { getCurrentTest } from './test-state'
@@ -21,7 +22,7 @@ function getDefaultHookTimeout() {
2122
const CLEANUP_TIMEOUT_KEY = Symbol.for('VITEST_CLEANUP_TIMEOUT')
2223
const CLEANUP_STACK_TRACE_KEY = Symbol.for('VITEST_CLEANUP_STACK_TRACE')
2324

24-
export function getBeforeHookCleanupCallback(hook: Function, result: any): Function | undefined {
25+
export function getBeforeHookCleanupCallback(hook: Function, result: any, context?: TestContext): Function | undefined {
2526
if (typeof result === 'function') {
2627
const timeout
2728
= CLEANUP_TIMEOUT_KEY in hook && typeof hook[CLEANUP_TIMEOUT_KEY] === 'number'
@@ -31,7 +32,17 @@ export function getBeforeHookCleanupCallback(hook: Function, result: any): Funct
3132
= CLEANUP_STACK_TRACE_KEY in hook && hook[CLEANUP_STACK_TRACE_KEY] instanceof Error
3233
? hook[CLEANUP_STACK_TRACE_KEY]
3334
: undefined
34-
return withTimeout(result, timeout, true, stackTraceError)
35+
return withTimeout(
36+
result,
37+
timeout,
38+
true,
39+
stackTraceError,
40+
(_, error) => {
41+
if (context) {
42+
abortContextSignal(context, error)
43+
}
44+
},
45+
)
3546
}
3647
}
3748

@@ -136,6 +147,7 @@ export function beforeEach<ExtraContext = object>(
136147
timeout ?? getDefaultHookTimeout(),
137148
true,
138149
stackTraceError,
150+
abortIfTimeout,
139151
),
140152
{
141153
[CLEANUP_TIMEOUT_KEY]: timeout,
@@ -174,6 +186,7 @@ export function afterEach<ExtraContext = object>(
174186
timeout ?? getDefaultHookTimeout(),
175187
true,
176188
new Error('STACK_TRACE_ERROR'),
189+
abortIfTimeout,
177190
),
178191
)
179192
}
@@ -206,6 +219,7 @@ export const onTestFailed: TaskHook<OnTestFailedHandler> = createTestHook(
206219
timeout ?? getDefaultHookTimeout(),
207220
true,
208221
new Error('STACK_TRACE_ERROR'),
222+
abortIfTimeout,
209223
),
210224
)
211225
},
@@ -244,6 +258,7 @@ export const onTestFinished: TaskHook<OnTestFinishedHandler> = createTestHook(
244258
timeout ?? getDefaultHookTimeout(),
245259
true,
246260
new Error('STACK_TRACE_ERROR'),
261+
abortIfTimeout,
247262
),
248263
)
249264
},

packages/runner/src/run.ts

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ import type {
1515
TaskUpdateEvent,
1616
Test,
1717
TestContext,
18+
WriteableTestContext,
1819
} from './types/tasks'
1920
import { shuffle } from '@vitest/utils'
2021
import { processError } from '@vitest/utils/error'
2122
import { collectTests } from './collect'
22-
import { PendingError } from './errors'
23+
import { abortContextSignal } from './context'
24+
import { PendingError, TestRunAbortError } from './errors'
2325
import { callFixtureCleanup } from './fixture'
2426
import { getBeforeHookCleanupCallback } from './hooks'
2527
import { getFn, getHooks } from './map'
26-
import { setCurrentTest } from './test-state'
28+
import { addRunningTest, getRunningTests, setCurrentTest } from './test-state'
2729
import { limitConcurrency } from './utils/limit-concurrency'
2830
import { partitionSuiteChildren } from './utils/suite'
2931
import { hasFailed, hasTests } from './utils/tasks'
@@ -87,12 +89,14 @@ async function callTestHooks(
8789
return
8890
}
8991

92+
const context = test.context as WriteableTestContext
93+
9094
const onTestFailed = test.context.onTestFailed
9195
const onTestFinished = test.context.onTestFinished
92-
test.context.onTestFailed = () => {
96+
context.onTestFailed = () => {
9397
throw new Error(`Cannot call "onTestFailed" inside a test hook.`)
9498
}
95-
test.context.onTestFinished = () => {
99+
context.onTestFinished = () => {
96100
throw new Error(`Cannot call "onTestFinished" inside a test hook.`)
97101
}
98102

@@ -115,8 +119,8 @@ async function callTestHooks(
115119
}
116120
}
117121

118-
test.context.onTestFailed = onTestFailed
119-
test.context.onTestFinished = onTestFinished
122+
context.onTestFailed = onTestFailed
123+
context.onTestFinished = onTestFinished
120124
}
121125

122126
export async function callSuiteHook<T extends keyof SuiteHooks>(
@@ -145,7 +149,11 @@ export async function callSuiteHook<T extends keyof SuiteHooks>(
145149
}
146150

147151
async function runHook(hook: Function) {
148-
return getBeforeHookCleanupCallback(hook, await hook(...args))
152+
return getBeforeHookCleanupCallback(
153+
hook,
154+
await hook(...args),
155+
name === 'beforeEach' ? args[0] : undefined,
156+
)
149157
}
150158

151159
if (sequence === 'parallel') {
@@ -274,6 +282,7 @@ export async function runTest(test: Test, runner: VitestRunner): Promise<void> {
274282
}
275283
updateTask('test-prepare', test, runner)
276284

285+
const cleanupRunningTest = addRunningTest(test)
277286
setCurrentTest(test)
278287

279288
const suite = test.suite || test.file
@@ -374,6 +383,7 @@ export async function runTest(test: Test, runner: VitestRunner): Promise<void> {
374383
}
375384
updateTask('test-finished', test, runner)
376385
setCurrentTest(undefined)
386+
cleanupRunningTest()
377387
return
378388
}
379389

@@ -405,6 +415,7 @@ export async function runTest(test: Test, runner: VitestRunner): Promise<void> {
405415
}
406416
}
407417

418+
cleanupRunningTest()
408419
setCurrentTest(undefined)
409420

410421
test.result.duration = now() - start
@@ -588,21 +599,38 @@ export async function runFiles(files: File[], runner: VitestRunner): Promise<voi
588599
}
589600

590601
export async function startTests(specs: string[] | FileSpecification[], runner: VitestRunner): Promise<File[]> {
591-
const paths = specs.map(f => typeof f === 'string' ? f : f.filepath)
592-
await runner.onBeforeCollect?.(paths)
602+
const cancel = runner.cancel?.bind(runner)
603+
// Ideally, we need to have an event listener for this, but only have a runner here.
604+
// Adding another onCancel felt wrong (maybe it needs to be refactored)
605+
runner.cancel = (reason) => {
606+
// We intentionally create only one error since there is only one test run that can be cancelled
607+
const error = new TestRunAbortError('The test run was aborted by the user.', reason)
608+
getRunningTests().forEach(test =>
609+
abortContextSignal(test.context, error),
610+
)
611+
return cancel?.(reason)
612+
}
593613

594-
const files = await collectTests(specs, runner)
614+
try {
615+
const paths = specs.map(f => typeof f === 'string' ? f : f.filepath)
616+
await runner.onBeforeCollect?.(paths)
595617

596-
await runner.onCollected?.(files)
597-
await runner.onBeforeRunFiles?.(files)
618+
const files = await collectTests(specs, runner)
598619

599-
await runFiles(files, runner)
620+
await runner.onCollected?.(files)
621+
await runner.onBeforeRunFiles?.(files)
600622

601-
await runner.onAfterRunFiles?.(files)
623+
await runFiles(files, runner)
602624

603-
await finishSendTasksUpdate(runner)
625+
await runner.onAfterRunFiles?.(files)
604626

605-
return files
627+
await finishSendTasksUpdate(runner)
628+
629+
return files
630+
}
631+
finally {
632+
runner.cancel = cancel
633+
}
606634
}
607635

608636
async function publicCollect(specs: string[] | FileSpecification[], runner: VitestRunner): Promise<File[]> {

0 commit comments

Comments
 (0)