diff --git a/README.md b/README.md index 87838cea..022f511d 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ this project. ## Features - Connections over HTTP or HTTPS -- Supports HTTP Basic Authentication +- Supports Basic and OAuth2 authentication types - Per-query user information for access control ## Requirements @@ -65,6 +65,17 @@ const data: QueryData[] = await iter .fold([], (row, acc) => [...acc, ...row]); ``` +### Using OAuth2 Authentication + +```typescript +const trino: Trino = Trino.create({ + server: 'http://localhost:8080', + catalog: 'tpcds', + schema: 'sf100000', + auth: new OAuth2Auth('token', 'clientId', 'clientSecret', 'refreshToken', 'tokenEndpoint'), +}); +``` + ## Examples More usage examples can be found in the diff --git a/package.json b/package.json index 4b3a806b..66a35b4e 100644 --- a/package.json +++ b/package.json @@ -48,7 +48,9 @@ }, "scripts": { "build": "tsc --project tsconfig.build.json", - "test:it": "jest --testPathPatterns tests/it", + "test:it": "jest --testPathPattern tests/it", + "test:unit": "jest --testPathPattern tests/unit", + "test": "jest", "test:lint": "eslint . --flag unstable_ts_config", "publish": "yarn build && yarn npm publish" } diff --git a/src/index.ts b/src/index.ts index 1928aaf3..498d076d 100644 --- a/src/index.ts +++ b/src/index.ts @@ -35,6 +35,22 @@ export class BasicAuth implements Auth { constructor(readonly username: string, readonly password?: string) {} } +export class OAuth2Auth implements Auth { + readonly type: AuthType = 'oauth2'; + constructor( + readonly token: string, + readonly clientId?: string, + readonly clientSecret?: string, + readonly refreshToken?: string, + readonly tokenEndpoint?: string, + readonly scopes?: string[], + readonly tokenType?: string, + readonly expiresIn?: number, + readonly redirectUri?: string, + readonly grantType?: string + ) {} +} + export type Session = {[key: string]: string}; export type ExtraCredential = {[key: string]: string}; @@ -196,14 +212,52 @@ class Client { ...(options.extraHeaders ?? {}), }; - if (options.auth && options.auth.type === 'basic') { - const basic: BasicAuth = options.auth; - clientConfig.auth = { - username: basic.username, - password: basic.password ?? '', - }; - - headers[TRINO_USER_HEADER] = basic.username; + if (options.auth) { + switch (options.auth.type) { + case 'basic': { + const basic: BasicAuth = options.auth; + clientConfig.auth = { + username: basic.username, + password: basic.password ?? '', + }; + headers[TRINO_USER_HEADER] = basic.username; + break; + } + case 'oauth2': { + const oauth2: OAuth2Auth = options.auth; + headers['Authorization'] = `Bearer ${oauth2.token}`; + if (oauth2.clientId) { + headers['Client-Id'] = oauth2.clientId; + } + if (oauth2.clientSecret) { + headers['Client-Secret'] = oauth2.clientSecret; + } + if (oauth2.refreshToken) { + headers['Refresh-Token'] = oauth2.refreshToken; + } + if (oauth2.tokenEndpoint) { + headers['Token-Endpoint'] = oauth2.tokenEndpoint; + } + if (oauth2.scopes) { + headers['Scopes'] = oauth2.scopes.join(' '); + } + if (oauth2.tokenType) { + headers['Token-Type'] = oauth2.tokenType; + } + if (oauth2.expiresIn) { + headers['Expires-In'] = oauth2.expiresIn.toString(); + } + if (oauth2.redirectUri) { + headers['Redirect-Uri'] = oauth2.redirectUri; + } + if (oauth2.grantType) { + headers['Grant-Type'] = oauth2.grantType; + } + break; + } + default: + throw new Error(`Unsupported auth type: ${options.auth.type}`); + } } clientConfig.headers = cleanHeaders(headers); diff --git a/tests/it/client.spec.ts b/tests/it/client.spec.ts index 1510d032..cef7b1c1 100644 --- a/tests/it/client.spec.ts +++ b/tests/it/client.spec.ts @@ -1,4 +1,4 @@ -import {BasicAuth, QueryData, Trino} from '../../src'; +import {BasicAuth, OAuth2Auth, QueryData, Trino} from '../../src'; const allCustomerQuery = 'select * from customer'; const limit = 1; @@ -175,4 +175,19 @@ describe('trino', () => { ]); expect(sales).toHaveLength(limit); }); + + test.concurrent('oauth2 auth', async () => { + const trino = Trino.create({ + catalog: 'tpcds', + schema: 'sf100000', + auth: new OAuth2Auth('token'), + }); + + const iter = await trino.query(singleCustomerQuery); + const data = await iter + .map(r => r.data ?? []) + .fold([], (row, acc) => [...acc, ...row]); + + expect(data).toHaveLength(limit); + }); }); diff --git a/tests/unit/auth.spec.ts b/tests/unit/auth.spec.ts new file mode 100644 index 00000000..5272721e --- /dev/null +++ b/tests/unit/auth.spec.ts @@ -0,0 +1,81 @@ +import {BasicAuth, OAuth2Auth} from '../../src'; + +describe('Auth Classes', () => { + describe('BasicAuth', () => { + test('should create BasicAuth with username only', () => { + const auth = new BasicAuth('testuser'); + expect(auth.username).toBe('testuser'); + expect(auth.password).toBeUndefined(); + expect(auth.type).toBe('basic'); + }); + + test('should create BasicAuth with username and password', () => { + const auth = new BasicAuth('testuser', 'testpass'); + expect(auth.username).toBe('testuser'); + expect(auth.password).toBe('testpass'); + expect(auth.type).toBe('basic'); + }); + }); + + describe('OAuth2Auth', () => { + test('should create OAuth2Auth with token only', () => { + const auth = new OAuth2Auth('test-token'); + expect(auth.token).toBe('test-token'); + expect(auth.type).toBe('oauth2'); + expect(auth.clientId).toBeUndefined(); + expect(auth.clientSecret).toBeUndefined(); + expect(auth.refreshToken).toBeUndefined(); + expect(auth.tokenEndpoint).toBeUndefined(); + expect(auth.scopes).toBeUndefined(); + expect(auth.tokenType).toBeUndefined(); + expect(auth.expiresIn).toBeUndefined(); + expect(auth.redirectUri).toBeUndefined(); + expect(auth.grantType).toBeUndefined(); + }); + + test('should create OAuth2Auth with all optional parameters', () => { + const auth = new OAuth2Auth( + 'test-token', + 'client-id', + 'client-secret', + 'refresh-token', + 'https://example.com/oauth2/token', + ['read', 'write'], + 'Bearer', + 3600, + 'https://example.com/callback', + 'authorization_code' + ); + + expect(auth.token).toBe('test-token'); + expect(auth.clientId).toBe('client-id'); + expect(auth.clientSecret).toBe('client-secret'); + expect(auth.refreshToken).toBe('refresh-token'); + expect(auth.tokenEndpoint).toBe('https://example.com/oauth2/token'); + expect(auth.scopes).toEqual(['read', 'write']); + expect(auth.tokenType).toBe('Bearer'); + expect(auth.expiresIn).toBe(3600); + expect(auth.redirectUri).toBe('https://example.com/callback'); + expect(auth.grantType).toBe('authorization_code'); + expect(auth.type).toBe('oauth2'); + }); + + test('should create OAuth2Auth with some optional parameters', () => { + const auth = new OAuth2Auth( + 'test-token', + 'client-id', + undefined, + 'refresh-token', + undefined, + ['read'] + ); + + expect(auth.token).toBe('test-token'); + expect(auth.clientId).toBe('client-id'); + expect(auth.clientSecret).toBeUndefined(); + expect(auth.refreshToken).toBe('refresh-token'); + expect(auth.tokenEndpoint).toBeUndefined(); + expect(auth.scopes).toEqual(['read']); + }); + }); +}); \ No newline at end of file