diff --git a/src/execution/__tests__/subscribe-test.ts b/src/execution/__tests__/subscribe-test.ts index d943ef4006..03e3da2839 100644 --- a/src/execution/__tests__/subscribe-test.ts +++ b/src/execution/__tests__/subscribe-test.ts @@ -5,6 +5,8 @@ import { expectJSON } from '../../__testUtils__/expectJSON'; import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick'; import { isAsyncIterable } from '../../jsutils/isAsyncIterable'; +import { isPromise } from '../../jsutils/isPromise'; +import type { PromiseOrValue } from '../../jsutils/PromiseOrValue'; import { parse } from '../../language/parser'; @@ -123,28 +125,49 @@ function createSubscription(pubsub: SimplePubSub) { return subscribe({ schema: emailSchema, document, rootValue: data }); } -async function expectPromise(promise: Promise) { - let caughtError: Error; - - try { - /* c8 ignore next 2 */ - await promise; - expect.fail('promise should have thrown but did not'); - } catch (error) { - caughtError = error; - } +// TODO: consider adding this method to testUtils (with tests) +function expectPromise(maybePromise: unknown) { + assert(isPromise(maybePromise)); return { - toReject() { - expect(caughtError).to.be.an.instanceOf(Error); + toResolve() { + return maybePromise; }, - toRejectWith(message: string) { + async toRejectWith(message: string) { + let caughtError: Error; + + try { + /* c8 ignore next 2 */ + await maybePromise; + expect.fail('promise should have thrown but did not'); + } catch (error) { + caughtError = error; + } + expect(caughtError).to.be.an.instanceOf(Error); expect(caughtError).to.have.property('message', message); }, }; } +// TODO: consider adding this method to testUtils (with tests) +function expectEqualPromisesOrValues( + value1: PromiseOrValue, + value2: PromiseOrValue, +): PromiseOrValue { + if (isPromise(value1)) { + assert(isPromise(value2)); + return Promise.all([value1, value2]).then((resolved) => { + expectJSON(resolved[1]).toDeepEqual(resolved[0]); + return resolved[0]; + }); + } + + assert(!isPromise(value2)); + expectJSON(value2).toDeepEqual(value1); + return value1; +} + const DummyQueryType = new GraphQLObjectType({ name: 'Query', fields: { @@ -152,9 +175,9 @@ const DummyQueryType = new GraphQLObjectType({ }, }); -async function subscribeWithBadFn( +function subscribeWithBadFn( subscribeFn: () => unknown, -): Promise { +): PromiseOrValue> { const schema = new GraphQLSchema({ query: DummyQueryType, subscription: new GraphQLObjectType({ @@ -165,13 +188,11 @@ async function subscribeWithBadFn( }), }); const document = parse('subscription { foo }'); - const result = await subscribe({ schema, document }); - assert(!isAsyncIterable(result)); - expectJSON(await createSourceEventStream(schema, document)).toDeepEqual( - result, + return expectEqualPromisesOrValues( + subscribe({ schema, document }), + createSourceEventStream(schema, document), ); - return result; } /* eslint-disable @typescript-eslint/require-await */ @@ -193,7 +214,7 @@ describe('Subscription Initialization Phase', () => { yield { foo: 'FooValue' }; } - const subscription = await subscribe({ + const subscription = subscribe({ schema, document: parse('subscription { foo }'), rootValue: { foo: fooGenerator }, @@ -229,7 +250,7 @@ describe('Subscription Initialization Phase', () => { }), }); - const subscription = await subscribe({ + const subscription = subscribe({ schema, document: parse('subscription { foo }'), }); @@ -267,10 +288,13 @@ describe('Subscription Initialization Phase', () => { }), }); - const subscription = await subscribe({ + const promise = subscribe({ schema, document: parse('subscription { foo }'), }); + assert(isPromise(promise)); + + const subscription = await promise; assert(isAsyncIterable(subscription)); expect(await subscription.next()).to.deep.equal({ @@ -299,7 +323,7 @@ describe('Subscription Initialization Phase', () => { yield { foo: 'FooValue' }; } - const subscription = await subscribe({ + const subscription = subscribe({ schema, document: parse('subscription { foo }'), rootValue: { customFoo: fooGenerator }, @@ -349,7 +373,7 @@ describe('Subscription Initialization Phase', () => { }), }); - const subscription = await subscribe({ + const subscription = subscribe({ schema, document: parse('subscription { foo bar }'), }); @@ -379,31 +403,29 @@ describe('Subscription Initialization Phase', () => { }); // @ts-expect-error (schema must not be null) - (await expectPromise(subscribe({ schema: null, document }))).toRejectWith( + expect(() => subscribe({ schema: null, document })).to.throw( 'Expected null to be a GraphQL schema.', ); // @ts-expect-error - (await expectPromise(subscribe({ document }))).toRejectWith( + expect(() => subscribe({ document })).to.throw( 'Expected undefined to be a GraphQL schema.', ); // @ts-expect-error (document must not be null) - (await expectPromise(subscribe({ schema, document: null }))).toRejectWith( + expect(() => subscribe({ schema, document: null })).to.throw( 'Must provide document.', ); // @ts-expect-error - (await expectPromise(subscribe({ schema }))).toRejectWith( - 'Must provide document.', - ); + expect(() => subscribe({ schema })).to.throw('Must provide document.'); }); it('resolves to an error if schema does not support subscriptions', async () => { const schema = new GraphQLSchema({ query: DummyQueryType }); const document = parse('subscription { unknownField }'); - const result = await subscribe({ schema, document }); + const result = subscribe({ schema, document }); expectJSON(result).toDeepEqual({ errors: [ { @@ -427,7 +449,7 @@ describe('Subscription Initialization Phase', () => { }); const document = parse('subscription { unknownField }'); - const result = await subscribe({ schema, document }); + const result = subscribe({ schema, document }); expectJSON(result).toDeepEqual({ errors: [ { @@ -450,11 +472,11 @@ describe('Subscription Initialization Phase', () => { }); // @ts-expect-error - (await expectPromise(subscribe({ schema, document: {} }))).toReject(); + expect(() => subscribe({ schema, document: {} })).to.throw(); }); it('throws an error if subscribe does not return an iterator', async () => { - expectJSON(await subscribeWithBadFn(() => 'test')).toDeepEqual({ + const expectedResult = { errors: [ { message: @@ -463,7 +485,15 @@ describe('Subscription Initialization Phase', () => { path: ['foo'], }, ], - }); + }; + + expectJSON(subscribeWithBadFn(() => 'test')).toDeepEqual(expectedResult); + + expectJSON( + await expectPromise( + subscribeWithBadFn(() => Promise.resolve('test')), + ).toResolve(), + ).toDeepEqual(expectedResult); }); it('resolves to an error for subscription resolver errors', async () => { @@ -479,24 +509,28 @@ describe('Subscription Initialization Phase', () => { expectJSON( // Returning an error - await subscribeWithBadFn(() => new Error('test error')), + subscribeWithBadFn(() => new Error('test error')), ).toDeepEqual(expectedResult); expectJSON( // Throwing an error - await subscribeWithBadFn(() => { + subscribeWithBadFn(() => { throw new Error('test error'); }), ).toDeepEqual(expectedResult); expectJSON( // Resolving to an error - await subscribeWithBadFn(() => Promise.resolve(new Error('test error'))), + await expectPromise( + subscribeWithBadFn(() => Promise.resolve(new Error('test error'))), + ).toResolve(), ).toDeepEqual(expectedResult); expectJSON( // Rejecting with an error - await subscribeWithBadFn(() => Promise.reject(new Error('test error'))), + await expectPromise( + subscribeWithBadFn(() => Promise.reject(new Error('test error'))), + ).toResolve(), ).toDeepEqual(expectedResult); }); @@ -523,7 +557,7 @@ describe('Subscription Initialization Phase', () => { // If we receive variables that cannot be coerced correctly, subscribe() will // resolve to an ExecutionResult that contains an informative error description. - const result = await subscribe({ schema, document, variableValues }); + const result = subscribe({ schema, document, variableValues }); expectJSON(result).toDeepEqual({ errors: [ { @@ -541,10 +575,10 @@ describe('Subscription Publish Phase', () => { it('produces a payload for multiple subscribe in same subscription', async () => { const pubsub = new SimplePubSub(); - const subscription = await createSubscription(pubsub); + const subscription = createSubscription(pubsub); assert(isAsyncIterable(subscription)); - const secondSubscription = await createSubscription(pubsub); + const secondSubscription = createSubscription(pubsub); assert(isAsyncIterable(secondSubscription)); const payload1 = subscription.next(); @@ -583,7 +617,7 @@ describe('Subscription Publish Phase', () => { it('produces a payload per subscription event', async () => { const pubsub = new SimplePubSub(); - const subscription = await createSubscription(pubsub); + const subscription = createSubscription(pubsub); assert(isAsyncIterable(subscription)); // Wait for the next subscription payload. @@ -672,7 +706,7 @@ describe('Subscription Publish Phase', () => { it('produces a payload when there are multiple events', async () => { const pubsub = new SimplePubSub(); - const subscription = await createSubscription(pubsub); + const subscription = createSubscription(pubsub); assert(isAsyncIterable(subscription)); let payload = subscription.next(); @@ -738,7 +772,7 @@ describe('Subscription Publish Phase', () => { it('should not trigger when subscription is already done', async () => { const pubsub = new SimplePubSub(); - const subscription = await createSubscription(pubsub); + const subscription = createSubscription(pubsub); assert(isAsyncIterable(subscription)); let payload = subscription.next(); @@ -792,7 +826,7 @@ describe('Subscription Publish Phase', () => { it('should not trigger when subscription is thrown', async () => { const pubsub = new SimplePubSub(); - const subscription = await createSubscription(pubsub); + const subscription = createSubscription(pubsub); assert(isAsyncIterable(subscription)); let payload = subscription.next(); @@ -845,7 +879,7 @@ describe('Subscription Publish Phase', () => { it('event order is correct for multiple publishes', async () => { const pubsub = new SimplePubSub(); - const subscription = await createSubscription(pubsub); + const subscription = createSubscription(pubsub); assert(isAsyncIterable(subscription)); let payload = subscription.next(); @@ -936,7 +970,7 @@ describe('Subscription Publish Phase', () => { }); const document = parse('subscription { newMessage }'); - const subscription = await subscribe({ schema, document }); + const subscription = subscribe({ schema, document }); assert(isAsyncIterable(subscription)); expect(await subscription.next()).to.deep.equal({ @@ -997,7 +1031,7 @@ describe('Subscription Publish Phase', () => { }); const document = parse('subscription { newMessage }'); - const subscription = await subscribe({ schema, document }); + const subscription = subscribe({ schema, document }); assert(isAsyncIterable(subscription)); expect(await subscription.next()).to.deep.equal({ @@ -1007,7 +1041,7 @@ describe('Subscription Publish Phase', () => { }, }); - (await expectPromise(subscription.next())).toRejectWith('test error'); + await expectPromise(subscription.next()).toRejectWith('test error'); expect(await subscription.next()).to.deep.equal({ done: true, diff --git a/src/execution/subscribe.ts b/src/execution/subscribe.ts index e54949830c..9ff7cd2112 100644 --- a/src/execution/subscribe.ts +++ b/src/execution/subscribe.ts @@ -1,7 +1,9 @@ import { inspect } from '../jsutils/inspect'; import { isAsyncIterable } from '../jsutils/isAsyncIterable'; +import { isPromise } from '../jsutils/isPromise'; import type { Maybe } from '../jsutils/Maybe'; import { addPath, pathToArray } from '../jsutils/Path'; +import type { PromiseOrValue } from '../jsutils/PromiseOrValue'; import { GraphQLError } from '../error/GraphQLError'; import { locatedError } from '../error/locatedError'; @@ -47,9 +49,11 @@ import { getArgumentValues } from './values'; * * Accepts either an object with named arguments, or individual arguments. */ -export async function subscribe( +export function subscribe( args: ExecutionArgs, -): Promise | ExecutionResult> { +): PromiseOrValue< + AsyncGenerator | ExecutionResult +> { const { schema, document, @@ -61,7 +65,7 @@ export async function subscribe( subscribeFieldResolver, } = args; - const resultOrStream = await createSourceEventStream( + const resultOrStream = createSourceEventStream( schema, document, rootValue, @@ -71,6 +75,42 @@ export async function subscribe( subscribeFieldResolver, ); + if (isPromise(resultOrStream)) { + return resultOrStream.then((resolvedResultOrStream) => + mapSourceToResponse( + schema, + document, + resolvedResultOrStream, + contextValue, + variableValues, + operationName, + fieldResolver, + ), + ); + } + + return mapSourceToResponse( + schema, + document, + resultOrStream, + contextValue, + variableValues, + operationName, + fieldResolver, + ); +} + +function mapSourceToResponse( + schema: GraphQLSchema, + document: DocumentNode, + resultOrStream: ExecutionResult | AsyncIterable, + contextValue?: unknown, + variableValues?: Maybe<{ readonly [variable: string]: unknown }>, + operationName?: Maybe, + fieldResolver?: Maybe>, +): PromiseOrValue< + AsyncGenerator | ExecutionResult +> { if (!isAsyncIterable(resultOrStream)) { return resultOrStream; } @@ -81,7 +121,7 @@ export async function subscribe( // the GraphQL specification. The `execute` function provides the // "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the // "ExecuteQuery" algorithm, for which `execute` is also used. - const mapSourceToResponse = (payload: unknown) => + return mapAsyncIterator(resultOrStream, (payload: unknown) => execute({ schema, document, @@ -90,10 +130,8 @@ export async function subscribe( variableValues, operationName, fieldResolver, - }); - - // Map every source value to a ExecutionResult value as described above. - return mapAsyncIterator(resultOrStream, mapSourceToResponse); + }), + ); } /** @@ -124,7 +162,7 @@ export async function subscribe( * or otherwise separating these two steps. For more on this, see the * "Supporting Subscriptions at Scale" information in the GraphQL specification. */ -export async function createSourceEventStream( +export function createSourceEventStream( schema: GraphQLSchema, document: DocumentNode, rootValue?: unknown, @@ -132,7 +170,7 @@ export async function createSourceEventStream( variableValues?: Maybe<{ readonly [variable: string]: unknown }>, operationName?: Maybe, subscribeFieldResolver?: Maybe>, -): Promise | ExecutionResult> { +): PromiseOrValue | ExecutionResult> { // If arguments are missing or incorrectly typed, this is an internal // developer mistake which should throw an early error. assertValidExecutionArguments(schema, document, variableValues); @@ -155,7 +193,10 @@ export async function createSourceEventStream( } try { - const eventStream = await executeSubscription(exeContext); + const eventStream = executeSubscription(exeContext); + if (isPromise(eventStream)) { + return eventStream.then(undefined, (error) => ({ errors: [error] })); + } return eventStream; } catch (error) { @@ -163,9 +204,9 @@ export async function createSourceEventStream( } } -async function executeSubscription( +function executeSubscription( exeContext: ExecutionContext, -): Promise> { +): PromiseOrValue> { const { schema, fragments, operation, variableValues, rootValue } = exeContext; @@ -220,22 +261,32 @@ async function executeSubscription( // Call the `subscribe()` resolver or the default resolver to produce an // AsyncIterable yielding raw payloads. const resolveFn = fieldDef.subscribe ?? exeContext.subscribeFieldResolver; - const eventStream = await resolveFn(rootValue, args, contextValue, info); + const result = resolveFn(rootValue, args, contextValue, info); - if (eventStream instanceof Error) { - throw eventStream; + if (isPromise(result)) { + return result.then(assertEventStream).then(undefined, (error) => { + throw locatedError(error, fieldNodes, pathToArray(path)); + }); } - // Assert field returned an event stream, otherwise yield an error. - if (!isAsyncIterable(eventStream)) { - throw new GraphQLError( - 'Subscription field must return Async Iterable. ' + - `Received: ${inspect(eventStream)}.`, - ); - } - - return eventStream; + return assertEventStream(result); } catch (error) { throw locatedError(error, fieldNodes, pathToArray(path)); } } + +function assertEventStream(result: unknown): AsyncIterable { + if (result instanceof Error) { + throw result; + } + + // Assert field returned an event stream, otherwise yield an error. + if (!isAsyncIterable(result)) { + throw new GraphQLError( + 'Subscription field must return Async Iterable. ' + + `Received: ${inspect(result)}.`, + ); + } + + return result; +}