Skip to content

Commit c604e74

Browse files
fix(NODE-6051): only provide expected allowed keys to libmongocrypt after fetching aws kms credentials (#4057)
1 parent 0e3d6ea commit c604e74

File tree

7 files changed

+110
-60
lines changed

7 files changed

+110
-60
lines changed
Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
1-
import { getAwsCredentialProvider } from '../../deps';
1+
import { AWSSDKCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
22
import { type KMSProviders } from '.';
33

44
/**
55
* @internal
66
*/
77
export async function loadAWSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
8-
const credentialProvider = getAwsCredentialProvider();
8+
const credentialProvider = new AWSSDKCredentialProvider();
99

10-
if ('kModuleError' in credentialProvider) {
11-
return kmsProviders;
12-
}
10+
// We shouldn't ever receive a response from the AWS SDK that doesn't have a `SecretAccessKey`
11+
// or `AccessKeyId`. However, TS says these fields are optional. We provide empty strings
12+
// and let libmongocrypt error if we're unable to fetch the required keys.
13+
const {
14+
SecretAccessKey = '',
15+
AccessKeyId = '',
16+
Token
17+
} = await credentialProvider.getCredentials();
18+
const aws: NonNullable<KMSProviders['aws']> = {
19+
secretAccessKey: SecretAccessKey,
20+
accessKeyId: AccessKeyId
21+
};
22+
// the AWS session token is only required for temporary credentials so only attach it to the
23+
// result if it's present in the response from the aws sdk
24+
Token != null && (aws.sessionToken = Token);
1325

14-
const { fromNodeProviderChain } = credentialProvider;
15-
const provider = fromNodeProviderChain();
16-
// The state machine is the only place calling this so it will
17-
// catch if there is a rejection here.
18-
const aws = await provider();
1926
return { ...kmsProviders, aws };
2027
}

test/integration/auth/mongodb_aws.test.ts

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,26 @@ import * as http from 'http';
55
import { performance } from 'perf_hooks';
66
import * as sinon from 'sinon';
77

8+
// eslint-disable-next-line @typescript-eslint/no-restricted-imports
9+
import { refreshKMSCredentials } from '../../../src/client-side-encryption/providers';
810
import {
911
AWSTemporaryCredentialProvider,
1012
MongoAWSError,
1113
type MongoClient,
1214
MongoDBAWS,
1315
MongoMissingCredentialsError,
14-
MongoServerError
16+
MongoServerError,
17+
setDifference
1518
} from '../../mongodb';
1619

17-
function awsSdk() {
18-
try {
19-
return require('@aws-sdk/credential-providers');
20-
} catch {
21-
return null;
22-
}
23-
}
20+
const isMongoDBAWSAuthEnvironment = (process.env.MONGODB_URI ?? '').includes('MONGODB-AWS');
2421

2522
describe('MONGODB-AWS', function () {
2623
let awsSdkPresent;
2724
let client: MongoClient;
2825

2926
beforeEach(function () {
30-
const MONGODB_URI = process.env.MONGODB_URI;
31-
if (!MONGODB_URI || MONGODB_URI.indexOf('MONGODB-AWS') === -1) {
27+
if (!isMongoDBAWSAuthEnvironment) {
3228
this.currentTest.skipReason = 'requires MONGODB_URI to contain MONGODB-AWS auth mechanism';
3329
return this.skip();
3430
}
@@ -39,7 +35,7 @@ describe('MONGODB-AWS', function () {
3935
`Always inform the AWS tests if they run with or without the SDK (MONGODB_AWS_SDK=${MONGODB_AWS_SDK})`
4036
).to.include(MONGODB_AWS_SDK);
4137

42-
awsSdkPresent = !!awsSdk();
38+
awsSdkPresent = AWSTemporaryCredentialProvider.isAWSSDKInstalled;
4339
expect(
4440
awsSdkPresent,
4541
MONGODB_AWS_SDK === 'true'
@@ -244,8 +240,10 @@ describe('MONGODB-AWS', function () {
244240

245241
const envCheck = () => {
246242
const { AWS_WEB_IDENTITY_TOKEN_FILE = '' } = process.env;
247-
credentialProvider = awsSdk();
248-
return AWS_WEB_IDENTITY_TOKEN_FILE.length === 0 || credentialProvider == null;
243+
return (
244+
AWS_WEB_IDENTITY_TOKEN_FILE.length === 0 ||
245+
!AWSTemporaryCredentialProvider.isAWSSDKInstalled
246+
);
249247
};
250248

251249
beforeEach(function () {
@@ -255,6 +253,9 @@ describe('MONGODB-AWS', function () {
255253
return this.skip();
256254
}
257255

256+
// @ts-expect-error We intentionally access a protected variable.
257+
credentialProvider = AWSTemporaryCredentialProvider.awsSDK;
258+
258259
storedEnv = process.env;
259260
if (test.env.AWS_STS_REGIONAL_ENDPOINTS === undefined) {
260261
delete process.env.AWS_STS_REGIONAL_ENDPOINTS;
@@ -324,3 +325,49 @@ describe('MONGODB-AWS', function () {
324325
}
325326
});
326327
});
328+
329+
describe('AWS KMS Credential Fetching', function () {
330+
context('when the AWS SDK is not installed', function () {
331+
beforeEach(function () {
332+
this.currentTest.skipReason = !isMongoDBAWSAuthEnvironment
333+
? 'Test must run in an AWS auth testing environment'
334+
: AWSTemporaryCredentialProvider.isAWSSDKInstalled
335+
? 'This test must run in an environment where the AWS SDK is not installed.'
336+
: undefined;
337+
this.currentTest?.skipReason && this.skip();
338+
});
339+
it('fetching AWS KMS credentials throws an error', async function () {
340+
const error = await refreshKMSCredentials({ aws: {} }).catch(e => e);
341+
expect(error).to.be.instanceOf(MongoAWSError);
342+
});
343+
});
344+
345+
context('when the AWS SDK is installed', function () {
346+
beforeEach(function () {
347+
this.currentTest.skipReason = !isMongoDBAWSAuthEnvironment
348+
? 'Test must run in an AWS auth testing environment'
349+
: !AWSTemporaryCredentialProvider.isAWSSDKInstalled
350+
? 'This test must run in an environment where the AWS SDK is installed.'
351+
: undefined;
352+
this.currentTest?.skipReason && this.skip();
353+
});
354+
it('KMS credentials are successfully fetched.', async function () {
355+
const { aws } = await refreshKMSCredentials({ aws: {} });
356+
357+
expect(aws).to.have.property('accessKeyId');
358+
expect(aws).to.have.property('secretAccessKey');
359+
});
360+
361+
it('does not return any extra keys for the `aws` credential provider', async function () {
362+
const { aws } = await refreshKMSCredentials({ aws: {} });
363+
364+
const keys = new Set(Object.keys(aws ?? {}));
365+
const allowedKeys = ['accessKeyId', 'secretAccessKey', 'sessionToken'];
366+
367+
expect(
368+
Array.from(setDifference(keys, allowedKeys)),
369+
'received an unexpected key in the response refreshing KMS credentials'
370+
).to.deep.equal([]);
371+
});
372+
});
373+
});

test/unit/assorted/polling_srv_records_for_mongos_discovery.prose.test.ts

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { expect } from 'chai';
22
import * as dns from 'dns';
33
import { once } from 'events';
4-
import { coerce } from 'semver';
4+
import { satisfies } from 'semver';
55
import * as sinon from 'sinon';
66

77
import {
@@ -51,11 +51,9 @@ describe('Polling Srv Records for Mongos Discovery', () => {
5151
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
5252
const test = this.currentTest!;
5353

54-
const { major } = coerce(process.version);
55-
test.skipReason =
56-
major === 18 || major === 20
57-
? 'TODO(NODE-5666): fix failing unit tests on Node18'
58-
: undefined;
54+
test.skipReason = satisfies(process.version, '>=18.0.0')
55+
? `TODO(NODE-5666): fix failing unit tests on Node18 (Running with Nodejs ${process.version})`
56+
: undefined;
5957

6058
if (test.skipReason) this.skip();
6159
});

test/unit/client-side-encryption/providers/credentialsProvider.test.ts

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import {
2020
} from '../../../../src/client-side-encryption/providers/azure';
2121
// eslint-disable-next-line @typescript-eslint/no-restricted-imports
2222
import * as utils from '../../../../src/client-side-encryption/providers/utils';
23+
// eslint-disable-next-line @typescript-eslint/no-restricted-imports
24+
import { AWSSDKCredentialProvider } from '../../../../src/cmap/auth/aws_temporary_credentials';
2325
import * as requirements from '../requirements.helper';
2426

2527
const originalAccessKeyId = process.env.AWS_ACCESS_KEY_ID;
@@ -154,25 +156,25 @@ describe('#refreshKMSCredentials', function () {
154156
});
155157
});
156158

157-
context('when the sdk is not installed', function () {
158-
const kmsProviders = {
159-
local: {
160-
key: Buffer.alloc(96)
161-
},
162-
aws: {}
163-
};
164-
165-
before(function () {
166-
if (requirements.credentialProvidersInstalled.aws && this.currentTest) {
167-
this.currentTest.skipReason = 'Credentials will be loaded when sdk present';
168-
this.currentTest.skip();
169-
return;
170-
}
159+
context('when the AWS SDK returns unknown fields', function () {
160+
beforeEach(() => {
161+
sinon.stub(AWSSDKCredentialProvider.prototype, 'getCredentials').resolves({
162+
Token: 'example',
163+
SecretAccessKey: 'example',
164+
AccessKeyId: 'example',
165+
Expiration: new Date()
166+
});
171167
});
172-
173-
it('does not refresh credentials', async function () {
174-
const providers = await refreshKMSCredentials(kmsProviders);
175-
expect(providers).to.deep.equal(kmsProviders);
168+
afterEach(() => sinon.restore());
169+
it('only returns fields libmongocrypt expects', async function () {
170+
const credentials = await refreshKMSCredentials({ aws: {} });
171+
expect(credentials).to.deep.equal({
172+
aws: {
173+
accessKeyId: accessKey,
174+
secretAccessKey: secretKey,
175+
sessionToken: sessionToken
176+
}
177+
});
176178
});
177179
});
178180
});

test/unit/connection_string.spec.test.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { coerce } from 'semver';
1+
import { satisfies } from 'semver';
22

33
import { loadSpecTests } from '../spec';
44
import { executeUriValidationTest } from '../tools/uri_spec_runner';
@@ -15,14 +15,13 @@ describe('Connection String spec tests', function () {
1515
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
1616
const test = this.currentTest!;
1717

18-
const { major } = coerce(process.version);
1918
const skippedTests = [
2019
'Invalid port (zero) with IP literal',
2120
'Invalid port (zero) with hostname'
2221
];
2322
test.skipReason =
24-
major === 20 && skippedTests.includes(test.title)
25-
? 'TODO(NODE-5666): fix failing unit tests on Node18'
23+
satisfies(process.version, '>=20.0.0') && skippedTests.includes(test.title)
24+
? 'TODO(NODE-5666): fix failing unit tests on Node20+'
2625
: undefined;
2726

2827
if (test.skipReason) this.skip();

test/unit/sdam/monitor.test.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { once } from 'node:events';
22
import * as net from 'node:net';
33

44
import { expect } from 'chai';
5-
import { coerce } from 'semver';
5+
import { satisfies } from 'semver';
66
import * as sinon from 'sinon';
77
import { setTimeout } from 'timers';
88
import { setTimeout as setTimeoutPromise } from 'timers/promises';
@@ -57,7 +57,6 @@ describe('monitoring', function () {
5757
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
5858
const test = this.currentTest!;
5959

60-
const { major } = coerce(process.version);
6160
const failingTests = [
6261
'should connect and issue an initial server check',
6362
'should ignore attempts to connect when not already closed',
@@ -67,7 +66,7 @@ describe('monitoring', function () {
6766
'correctly returns the mean of the heartbeat durations'
6867
];
6968
test.skipReason =
70-
(major === 18 || major === 20) && failingTests.includes(test.title)
69+
satisfies(process.version, '>=18.0.0') && failingTests.includes(test.title)
7170
? 'TODO(NODE-5666): fix failing unit tests on Node18'
7271
: undefined;
7372

test/unit/sdam/topology.test.ts

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { expect } from 'chai';
22
import { once } from 'events';
33
import * as net from 'net';
44
import { type AddressInfo } from 'net';
5-
import { coerce, type SemVer } from 'semver';
5+
import { satisfies } from 'semver';
66
import * as sinon from 'sinon';
77
import { clearTimeout } from 'timers';
88

@@ -284,11 +284,9 @@ describe('Topology (unit)', function () {
284284
it('should encounter a server selection timeout on garbled server responses', function () {
285285
const test = this.test;
286286

287-
const { major } = coerce(process.version) as SemVer;
288-
test.skipReason =
289-
major === 18 || major === 20
290-
? 'TODO(NODE-5666): fix failing unit tests on Node18'
291-
: undefined;
287+
test.skipReason = satisfies(process.version, '>=18.0.0')
288+
? 'TODO(NODE-5666): fix failing unit tests on Node18'
289+
: undefined;
292290

293291
if (test.skipReason) this.skip();
294292

0 commit comments

Comments
 (0)