diff --git a/src/index.d.ts b/src/index.d.ts index 6c61e32b48..6c10e3bec1 100755 --- a/src/index.d.ts +++ b/src/index.d.ts @@ -5235,9 +5235,7 @@ declare namespace admin.machineLearning { displayName?: string; tags?: string[]; - tfLiteModel?: {gcsTFLiteUri: string;}; - - toJSON(forUpload?: boolean): object; + tfliteModel?: {gcsTfliteUri: string;}; } /** @@ -5247,8 +5245,8 @@ declare namespace admin.machineLearning { readonly modelId: string; readonly displayName: string; readonly tags?: string[]; - readonly createTime: number; - readonly updateTime: number; + readonly createTime: string; + readonly updateTime: string; readonly validationError?: string; readonly published: boolean; readonly etag: string; diff --git a/src/machine-learning/machine-learning-api-client.ts b/src/machine-learning/machine-learning-api-client.ts index 8331924371..295014fc01 100644 --- a/src/machine-learning/machine-learning-api-client.ts +++ b/src/machine-learning/machine-learning-api-client.ts @@ -52,6 +52,13 @@ export interface ModelResponse extends ModelContent { readonly modelHash?: string; } +export interface OperationResponse { + readonly name?: string; + readonly done: boolean; + readonly error?: StatusErrorResponse; + readonly response?: ModelResponse; +} + /** * Class that facilitates sending requests to the Firebase ML backend API. @@ -73,6 +80,24 @@ export class MachineLearningApiClient { this.httpClient = new AuthorizedHttpClient(app); } + public createModel(model: ModelContent): Promise { + if (!validator.isNonNullObject(model) || + !validator.isNonEmptyString(model.displayName)) { + const err = new FirebaseMachineLearningError('invalid-argument', 'Invalid model content.'); + return Promise.reject(err); + } + return this.getUrl() + .then((url) => { + const request: HttpRequestConfig = { + method: 'POST', + url: `${url}/models`, + data: model, + }; + return this.sendRequest(request); + }); + } + + public getModel(modelId: string): Promise { return Promise.resolve() .then(() => { diff --git a/src/machine-learning/machine-learning-utils.ts b/src/machine-learning/machine-learning-utils.ts index 0b4f84a64f..d0b881621a 100644 --- a/src/machine-learning/machine-learning-utils.ts +++ b/src/machine-learning/machine-learning-utils.ts @@ -25,9 +25,39 @@ export type MachineLearningErrorCode = | 'not-found' | 'resource-exhausted' | 'service-unavailable' - | 'unknown-error'; + | 'unknown-error' + | 'cancelled' + | 'deadline-exceeded' + | 'permission-denied' + | 'failed-precondition' + | 'aborted' + | 'out-of-range' + | 'data-loss' + | 'unauthenticated'; export class FirebaseMachineLearningError extends PrefixedFirebaseError { + public static fromOperationError(code: number, message: string): FirebaseMachineLearningError { + switch (code) { + case 1: return new FirebaseMachineLearningError('cancelled', message); + case 2: return new FirebaseMachineLearningError('unknown-error', message); + case 3: return new FirebaseMachineLearningError('invalid-argument', message); + case 4: return new FirebaseMachineLearningError('deadline-exceeded', message); + case 5: return new FirebaseMachineLearningError('not-found', message); + case 6: return new FirebaseMachineLearningError('already-exists', message); + case 7: return new FirebaseMachineLearningError('permission-denied', message); + case 8: return new FirebaseMachineLearningError('resource-exhausted', message); + case 9: return new FirebaseMachineLearningError('failed-precondition', message); + case 10: return new FirebaseMachineLearningError('aborted', message); + case 11: return new FirebaseMachineLearningError('out-of-range', message); + case 13: return new FirebaseMachineLearningError('internal-error', message); + case 14: return new FirebaseMachineLearningError('service-unavailable', message); + case 15: return new FirebaseMachineLearningError('data-loss', message); + case 16: return new FirebaseMachineLearningError('unauthenticated', message); + default: + return new FirebaseMachineLearningError('unknown-error', message); + } + } + constructor(code: MachineLearningErrorCode, message: string) { super('machine-learning', code, message); } diff --git a/src/machine-learning/machine-learning.ts b/src/machine-learning/machine-learning.ts index cb4ae0cf6a..c9e7b8b515 100644 --- a/src/machine-learning/machine-learning.ts +++ b/src/machine-learning/machine-learning.ts @@ -14,16 +14,14 @@ * limitations under the License. */ - import {FirebaseApp} from '../firebase-app'; import {FirebaseServiceInterface, FirebaseServiceInternalsInterface} from '../firebase-service'; -import {MachineLearningApiClient, ModelResponse} from './machine-learning-api-client'; +import {MachineLearningApiClient, ModelResponse, OperationResponse, ModelContent} from './machine-learning-api-client'; import {FirebaseError} from '../utils/error'; import * as validator from '../utils/validator'; import {FirebaseMachineLearningError} from './machine-learning-utils'; - -// const ML_HOST = 'mlkit.googleapis.com'; +import { deepCopy } from '../utils/deep-copy'; /** * Internals of an ML instance. @@ -97,7 +95,9 @@ export class MachineLearning implements FirebaseServiceInterface { * @return {Promise} A Promise fulfilled with the created model. */ public createModel(model: ModelOptions): Promise { - throw new Error('NotImplemented'); + return this.convertOptionstoContent(model, true) + .then((modelContent) => this.client.createModel(modelContent)) + .then((operation) => handleOperation(operation)); } /** @@ -170,10 +170,53 @@ export class MachineLearning implements FirebaseServiceInterface { public deleteModel(modelId: string): Promise { return this.client.deleteModel(modelId); } + + private convertOptionstoContent(options: ModelOptions, forUpload?: boolean): Promise { + const modelContent = deepCopy(options); + + if (forUpload && modelContent.tfliteModel?.gcsTfliteUri) { + return this.signUrl(modelContent.tfliteModel.gcsTfliteUri) + .then ((uri: string) => { + modelContent.tfliteModel!.gcsTfliteUri = uri; + return modelContent; + }) + .catch((err: Error) => { + throw new FirebaseMachineLearningError( + 'internal-error', + `Error during signing upload url: ${err.message}`); + }) as Promise; + } + + return Promise.resolve(modelContent) as Promise; + } + + private signUrl(unsignedUrl: string): Promise { + const MINUTES_IN_MILLIS = 60 * 1000; + const URL_VALID_DURATION = 10 * MINUTES_IN_MILLIS; + + const gcsRegex = /^gs:\/\/([a-z0-9_.-]{3,63})\/(.+)$/; + const matches = gcsRegex.exec(unsignedUrl); + if (!matches) { + throw new FirebaseMachineLearningError( + 'invalid-argument', + `Invalid unsigned url: ${unsignedUrl}`); + } + const bucketName = matches[1]; + const blobName = matches[2]; + const bucket = this.appInternal.storage().bucket(bucketName); + const blob = bucket.file(blobName); + return blob.getSignedUrl({ + action: 'read', + expires: Date.now() + URL_VALID_DURATION, + }).then((x) => { + return x[0]; + }); + } } + /** - * A Firebase ML Model output object + * A Firebase ML Model output object. */ export class Model { public readonly modelId: string; @@ -196,7 +239,7 @@ export class Model { !validator.isNonEmptyString(model.displayName) || !validator.isNonEmptyString(model.etag)) { throw new FirebaseMachineLearningError( - 'invalid-argument', + 'invalid-server-response', `Invalid Model response: ${JSON.stringify(model)}`); } @@ -252,13 +295,27 @@ export class ModelOptions { public displayName?: string; public tags?: string[]; - public tfliteModel?: { gcsTFLiteUri: string; }; - - protected toJSON(forUpload?: boolean): object { - throw new Error('NotImplemented'); - } + public tfliteModel?: { gcsTfliteUri: string; }; } + function extractModelId(resourceName: string): string { return resourceName.split('/').pop()!; } + + +function handleOperation(op: OperationResponse): Model { + // Backend currently does not return operations that are not done. + if (op.done) { + // Done operations must have either a response or an error. + if (op.response) { + return new Model(op.response); + } else if (op.error) { + throw FirebaseMachineLearningError.fromOperationError( + op.error.code, op.error.message); + } + } + throw new FirebaseMachineLearningError( + 'invalid-server-response', + `Invalid Operation response: ${JSON.stringify(op)}`); +} diff --git a/test/integration/machine-learning.spec.ts b/test/integration/machine-learning.spec.ts index 1da8a69729..cb56438ce7 100644 --- a/test/integration/machine-learning.spec.ts +++ b/test/integration/machine-learning.spec.ts @@ -14,9 +14,119 @@ * limitations under the License. */ + +import path = require('path'); +import * as chai from 'chai'; import * as admin from '../../lib/index'; +import {Bucket} from '@google-cloud/storage'; + +const expect = chai.expect; describe('admin.machineLearning', () => { + + const modelsToDelete: string[] = []; + + function scheduleForDelete(model: admin.machineLearning.Model) { + modelsToDelete.push(model.modelId); + } + + function unscheduleForDelete(model: admin.machineLearning.Model) { + modelsToDelete.splice(modelsToDelete.indexOf(model.modelId), 1); + } + + function deleteTempModels(): Promise { + const promises: Array> = []; + modelsToDelete.forEach((modelId) => { + promises.push(admin.machineLearning().deleteModel(modelId)); + }); + modelsToDelete.splice(0, modelsToDelete.length); // Clear out the array. + return Promise.all(promises); + } + + function createTemporaryModel(options?: admin.machineLearning.ModelOptions) + : Promise { + let modelOptions: admin.machineLearning.ModelOptions = { + displayName: 'nodejs_integration_temp_model', + }; + if (options) { + modelOptions = options; + } + return admin.machineLearning().createModel(modelOptions) + .then((model) => { + scheduleForDelete(model); + return model; + }); + } + + function uploadModelToGcs(localFileName: string, gcsFileName: string): Promise { + const bucket: Bucket = admin.storage().bucket(); + const tfliteFileName = path.join(__dirname, `../resources/${localFileName}`); + return bucket.upload(tfliteFileName, {destination: gcsFileName}) + .then(() => { + return `gs://${bucket.name}/${gcsFileName}`; + }); + } + + afterEach(() => { + return deleteTempModels(); + }); + + describe('createModel()', () => { + it('creates a new Model without ModelFormat', () => { + const modelOptions: admin.machineLearning.ModelOptions = { + displayName: 'node-integration-test-create-1', + tags: ['tag123', 'tag345']}; + return admin.machineLearning().createModel(modelOptions) + .then((model) => { + scheduleForDelete(model); + verifyModel(model, modelOptions); + }); + }); + + it('creates a new Model with valid ModelFormat', () => { + const modelOptions: admin.machineLearning.ModelOptions = { + displayName: 'node-integration-test-create-2', + tags: ['tag234', 'tag456'], + tfliteModel: {gcsTfliteUri: 'this will be replaced below'}, + }; + return uploadModelToGcs('model1.tflite', 'valid_model.tflite') + .then((fileName: string) => { + modelOptions.tfliteModel!.gcsTfliteUri = fileName; + return admin.machineLearning().createModel(modelOptions) + .then((model) => { + scheduleForDelete(model); + verifyModel(model, modelOptions); + }); + }); + }); + + it('creates a new Model with invalid ModelFormat', () => { + // Upload a file to default gcs bucket + const modelOptions: admin.machineLearning.ModelOptions = { + displayName: 'node-integration-test-create-3', + tags: ['tag234', 'tag456'], + tfliteModel: {gcsTfliteUri: 'this will be replaced below'}, + }; + return uploadModelToGcs('invalid_model.tflite', 'invalid_model.tflite') + .then((fileName: string) => { + modelOptions.tfliteModel!.gcsTfliteUri = fileName; + return admin.machineLearning().createModel(modelOptions) + .then((model) => { + scheduleForDelete(model); + verifyModel(model, modelOptions); + }); + }); + }); + + it ('rejects with invalid-argument when modelOptions are invalid', () => { + const modelOptions: admin.machineLearning.ModelOptions = { + displayName: 'Invalid Name#*^!', + }; + return admin.machineLearning().createModel(modelOptions) + .should.eventually.be.rejected.and.have.property('code', 'machine-learning/invalid-argument'); + }); + }); + describe('getModel()', () => { it('rejects with not-found when the Model does not exist', () => { const nonExistingName = '00000000'; @@ -30,6 +140,16 @@ describe('admin.machineLearning', () => { .should.eventually.be.rejected.and.have.property( 'code', 'machine-learning/invalid-argument'); }); + + it('resolves with existing Model', () => { + return createTemporaryModel() + .then((expectedModel) => + admin.machineLearning().getModel(expectedModel.modelId) + .then((actualModel) => { + expect(actualModel).to.deep.equal(expectedModel); + }), + ); + }); }); describe('deleteModel()', () => { @@ -45,5 +165,47 @@ describe('admin.machineLearning', () => { .should.eventually.be.rejected.and.have.property( 'code', 'machine-learning/invalid-argument'); }); + + it('deletes existing Model', () => { + return createTemporaryModel().then((model) => { + return admin.machineLearning().deleteModel(model.modelId) + .then(() => { + return admin.machineLearning().getModel(model.modelId) + .should.eventually.be.rejected.and.have.property('code', 'machine-learning/not-found'); + }) + .then(() => { + unscheduleForDelete(model); // Already deleted. + }); + }); + }); }); + + function verifyModel(model: admin.machineLearning.Model, expectedOptions: admin.machineLearning.ModelOptions) { + expect(model.displayName).to.equal(expectedOptions.displayName); + expect(model.createTime).to.not.be.empty; + expect(model.updateTime).to.not.be.empty; + expect(model.etag).to.not.be.empty; + if (expectedOptions.tags) { + expect(model.tags).to.deep.equal(expectedOptions.tags); + } else { + expect(model.tags).to.be.empty; + } + if (expectedOptions.tfliteModel) { + verifyTfliteModel(model, expectedOptions.tfliteModel.gcsTfliteUri); + } else { + expect(model.validationError).to.equal('No model file has been uploaded.'); + } + expect(model.locked).to.be.false; + } }); + +function verifyTfliteModel(model: admin.machineLearning.Model, expectedGcsTfliteUri: string) { + expect(model.tfliteModel!.gcsTfliteUri).to.equal(expectedGcsTfliteUri); + if (expectedGcsTfliteUri.endsWith('invalid_model.tflite')) { + expect(model.modelHash).to.be.empty; + expect(model.validationError).to.equal('Invalid flatbuffer format'); + } else { + expect(model.modelHash).to.not.be.empty; + expect(model.validationError).to.be.empty; + } +} diff --git a/test/resources/invalid_model.tflite b/test/resources/invalid_model.tflite new file mode 100644 index 0000000000..d8482f4362 --- /dev/null +++ b/test/resources/invalid_model.tflite @@ -0,0 +1 @@ +This is not a tflite file. diff --git a/test/resources/model1.tflite b/test/resources/model1.tflite new file mode 100644 index 0000000000..c4b71b7a22 Binary files /dev/null and b/test/resources/model1.tflite differ diff --git a/test/unit/machine-learning/machine-learning-api-client.spec.ts b/test/unit/machine-learning/machine-learning-api-client.spec.ts index a3b3935fc6..b4348b5b7c 100644 --- a/test/unit/machine-learning/machine-learning-api-client.spec.ts +++ b/test/unit/machine-learning/machine-learning-api-client.spec.ts @@ -19,7 +19,7 @@ import * as _ from 'lodash'; import * as chai from 'chai'; import * as sinon from 'sinon'; -import { MachineLearningApiClient } from '../../../src/machine-learning/machine-learning-api-client'; +import { MachineLearningApiClient, ModelContent } from '../../../src/machine-learning/machine-learning-api-client'; import { FirebaseMachineLearningError } from '../../../src/machine-learning/machine-learning-utils'; import { HttpClient } from '../../../src/utils/api-request'; import * as utils from '../utils'; @@ -78,6 +78,117 @@ describe('MachineLearningApiClient', () => { }); }); + describe('createModel', () => { + const NAME_ONLY_CONTENT: ModelContent = {displayName: 'name1'}; + const MODEL_RESPONSE = { + name: 'projects/test-project/models/1234567', + createTime: '2020-02-07T23:45:23.288047Z', + updateTime: '2020-02-08T23:45:23.288047Z', + etag: 'etag123', + modelHash: 'modelHash123', + displayName: 'model_1', + tags: ['tag_1', 'tag_2'], + state: {published: true}, + tfliteModel: { + gcsTfliteUri: 'gs://test-project-bucket/Firebase/ML/Models/model1.tflite', + sizeBytes: 16900988, + }, + }; + const STATUS_ERROR_RESPONSE = { + code: 3, + message: 'Invalid Argument message', + }; + const OPERATION_SUCCESS_RESPONSE = { + done: true, + response: MODEL_RESPONSE, + }; + const OPERATION_ERROR_RESPONSE = { + done: true, + error: STATUS_ERROR_RESPONSE, + }; + + const invalidContent: any[] = [null, undefined, {}, { tags: []}]; + invalidContent.forEach((content) => { + it(`should reject when called with: ${JSON.stringify(content)}`, () => { + return apiClient.createModel(content) + .should.eventually.be.rejected.and.have.property( + 'message', 'Invalid model content.'); + }); + }); + + it('should reject when project id is not available', () => { + return clientWithoutProjectId.createModel(NAME_ONLY_CONTENT) + .should.eventually.be.rejectedWith(noProjectId); + }); + + it('should throw when an error response is received', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .rejects(utils.errorFrom(ERROR_RESPONSE, 404)); + stubs.push(stub); + const expected = new FirebaseMachineLearningError('not-found', 'Requested entity not found'); + return apiClient.createModel(NAME_ONLY_CONTENT) + .should.eventually.be.rejected.and.deep.equal(expected); + }); + + it('should resolve with the created resource on success', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .resolves(utils.responseFrom(OPERATION_SUCCESS_RESPONSE)); + stubs.push(stub); + return apiClient.createModel(NAME_ONLY_CONTENT) + .then((resp) => { + expect(resp.done).to.be.true; + expect(resp.name).to.be.empty; + expect(resp.response).to.deep.equal(MODEL_RESPONSE); + }); + }); + + it('should resolve with error when the operation fails', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .resolves(utils.responseFrom(OPERATION_ERROR_RESPONSE)); + stubs.push(stub); + return apiClient.createModel(NAME_ONLY_CONTENT) + .then((resp) => { + expect(resp.done).to.be.true; + expect(resp.name).to.be.empty; + expect(resp.error).to.deep.equal(STATUS_ERROR_RESPONSE); + }); + }); + + it('should reject with unknown-error when error code is not present', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .rejects(utils.errorFrom({}, 404)); + stubs.push(stub); + const expected = new FirebaseMachineLearningError('unknown-error', 'Unknown server error: {}'); + return apiClient.createModel(NAME_ONLY_CONTENT) + .should.eventually.be.rejected.and.deep.equal(expected); + }); + + it('should reject with unknown-error for non-json response', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .rejects(utils.errorFrom('not json', 404)); + stubs.push(stub); + const expected = new FirebaseMachineLearningError( + 'unknown-error', 'Unexpected response with status: 404 and body: not json'); + return apiClient.createModel(NAME_ONLY_CONTENT) + .should.eventually.be.rejected.and.deep.equal(expected); + }); + + it('should reject with when failed with a FirebaseAppError', () => { + const expected = new FirebaseAppError('network-error', 'socket hang up'); + const stub = sinon + .stub(HttpClient.prototype, 'send') + .rejects(expected); + stubs.push(stub); + return apiClient.createModel(NAME_ONLY_CONTENT) + .should.eventually.be.rejected.and.deep.equal(expected); + }); + }); + describe('getModel', () => { const INVALID_NAMES: any[] = [null, undefined, '', 1, true, {}, []]; INVALID_NAMES.forEach((invalidName) => { @@ -146,7 +257,7 @@ describe('MachineLearningApiClient', () => { .should.eventually.be.rejected.and.deep.equal(expected); }); - it('should reject when rejected with a FirebaseAppError', () => { + it('should reject when failed with a FirebaseAppError', () => { const expected = new FirebaseAppError('network-error', 'socket hang up'); const stub = sinon .stub(HttpClient.prototype, 'send') diff --git a/test/unit/machine-learning/machine-learning.spec.ts b/test/unit/machine-learning/machine-learning.spec.ts index bb0c5988f1..beddfd74d9 100644 --- a/test/unit/machine-learning/machine-learning.spec.ts +++ b/test/unit/machine-learning/machine-learning.spec.ts @@ -19,10 +19,11 @@ import * as _ from 'lodash'; import * as chai from 'chai'; import * as sinon from 'sinon'; -import { MachineLearning } from '../../../src/machine-learning/machine-learning'; +import { MachineLearning, ModelOptions } from '../../../src/machine-learning/machine-learning'; import { FirebaseApp } from '../../../src/firebase-app'; import * as mocks from '../../resources/mocks'; -import { MachineLearningApiClient } from '../../../src/machine-learning/machine-learning-api-client'; +import { MachineLearningApiClient, + StatusErrorResponse, ModelResponse } from '../../../src/machine-learning/machine-learning-api-client'; import { FirebaseMachineLearningError } from '../../../src/machine-learning/machine-learning-utils'; import { deepCopy } from '../../../src/utils/deep-copy'; @@ -65,6 +66,56 @@ describe('MachineLearning', () => { }, }; + const STATUS_ERROR_RESPONSE: { + code: number; + message: string; + } = { + code: 3, + message: 'Invalid Argument message', + }; + + const OPERATION_RESPONSE: { + name?: string; + done: boolean; + error?: StatusErrorResponse; + response?: { + name: string; + createTime: string; + updateTime: string; + etag: string; + modelHash: string; + displayName?: string; + tags?: string[]; + state?: { + validationError?: { + code: number; + message: string; + }; + published?: boolean; + }; + tfliteModel?: { + gcsTfliteUri: string; + sizeBytes: number; + }; + } + } = { + done: true, + response: MODEL_RESPONSE, + }; + + const OPERATION_RESPONSE_ERROR: { + name?: string; + done: boolean; + error?: { + code: number; + message: string; + } + response?: ModelResponse; + } = { + done: true, + error: STATUS_ERROR_RESPONSE, + }; + const CREATE_TIME_UTC = 'Fri, 07 Feb 2020 23:45:23 GMT'; const UPDATE_TIME_UTC = 'Sat, 08 Feb 2020 23:45:23 GMT'; @@ -110,16 +161,18 @@ describe('MachineLearning', () => { + 'instance.'); }); - it('should reject when initialized without project ID', () => { - // Project ID not set in the environment. - delete process.env.GOOGLE_CLOUD_PROJECT; - delete process.env.GCLOUD_PROJECT; - const noProjectId = 'Failed to determine project ID. Initialize the SDK with service ' - + 'account credentials, or set project ID as an app option. Alternatively, set the ' - + 'GOOGLE_CLOUD_PROJECT environment variable.'; - const rulesWithoutProjectId = new MachineLearning(mockCredentialApp); - return rulesWithoutProjectId.getModel('test') - .should.eventually.rejectedWith(noProjectId); + it('should throw given invalid credential', () => { + const expectedError = 'Failed to initialize Google Cloud Storage client with ' + + 'the available credential. Must initialize the SDK with a certificate credential ' + + 'or application default credentials to use Cloud Storage API.'; + expect(() => { + const machineLearningAny: any = MachineLearning; + return new machineLearningAny(mockCredentialApp).createModel({ + displayName: 'foo', + tfliteModel: { + gcsTfliteUri: 'gs://some-bucket/model.tflite', + }}); + }).to.throw(expectedError); }); it('should not throw given a valid app', () => { @@ -261,4 +314,134 @@ describe('MachineLearning', () => { return machineLearning.deleteModel('1234567'); }); }); + + describe('createModel', () => { + const GCS_TFLITE_URI = 'gs://test-bucket/Firebase/ML/Models/model1.tflite'; + const MODEL_OPTIONS_NO_GCS: ModelOptions = { + displayName: 'display_name', + tags: ['tag1', 'tag2'], + }; + const MODEL_OPTIONS_WITH_GCS: ModelOptions = { + displayName: 'display_name_2', + tags: ['tag3', 'tag4'], + tfliteModel: { + gcsTfliteUri: GCS_TFLITE_URI, + }, + }; + + it('should propagate API errors', () => { + const stub = sinon + .stub(MachineLearningApiClient.prototype, 'createModel') + .rejects(EXPECTED_ERROR); + stubs.push(stub); + return machineLearning.createModel(MODEL_OPTIONS_NO_GCS) + .should.eventually.be.rejected.and.deep.equal(EXPECTED_ERROR); + }); + + it('should reject when API response is invalid', () => { + const stub = sinon + .stub(MachineLearningApiClient.prototype, 'createModel') + .resolves(null); + stubs.push(stub); + return machineLearning.createModel(MODEL_OPTIONS_WITH_GCS) + .should.eventually.be.rejected.and.have.property( + 'message', 'Cannot read property \'done\' of null'); + }); + + it('should reject when API response does not contain a name', () => { + const op = deepCopy(OPERATION_RESPONSE); + op.response!.name = ''; + const stub = sinon + .stub(MachineLearningApiClient.prototype, 'createModel') + .resolves(op); + stubs.push(stub); + return machineLearning.createModel(MODEL_OPTIONS_NO_GCS) + .should.eventually.be.rejected.and.have.property( + 'message', `Invalid Model response: ${JSON.stringify(op.response)}`); + }); + + it('should reject when API response does not contain a createTime', () => { + const op = deepCopy(OPERATION_RESPONSE); + op.response!.createTime = ''; + const stub = sinon + .stub(MachineLearningApiClient.prototype, 'createModel') + .resolves(op); + stubs.push(stub); + return machineLearning.createModel(MODEL_OPTIONS_NO_GCS) + .should.eventually.be.rejected.and.have.property( + 'message', `Invalid Model response: ${JSON.stringify(op.response)}`); + }); + + it('should reject when API response does not contain a updateTime', () => { + const op = deepCopy(OPERATION_RESPONSE); + op.response!.updateTime = ''; + const stub = sinon + .stub(MachineLearningApiClient.prototype, 'createModel') + .resolves(op); + stubs.push(stub); + return machineLearning.createModel(MODEL_OPTIONS_NO_GCS) + .should.eventually.be.rejected.and.have.property( + 'message', `Invalid Model response: ${JSON.stringify(op.response)}`); + }); + + it('should reject when API response does not contain a displayName', () => { + const op = deepCopy(OPERATION_RESPONSE); + op.response!.displayName = ''; + const stub = sinon + .stub(MachineLearningApiClient.prototype, 'createModel') + .resolves(op); + stubs.push(stub); + return machineLearning.createModel(MODEL_OPTIONS_NO_GCS) + .should.eventually.be.rejected.and.have.property( + 'message', `Invalid Model response: ${JSON.stringify(op.response)}`); + }); + + it('should reject when API response does not contain an etag', () => { + const op = deepCopy(OPERATION_RESPONSE); + op.response!.etag = ''; + const stub = sinon + .stub(MachineLearningApiClient.prototype, 'createModel') + .resolves(op); + stubs.push(stub); + return machineLearning.createModel(MODEL_OPTIONS_NO_GCS) + .should.eventually.be.rejected.and.have.property( + 'message', `Invalid Model response: ${JSON.stringify(op.response)}`); + }); + + it('should resolve with Model on success', () => { + const stub = sinon + .stub(MachineLearningApiClient.prototype, 'createModel') + .resolves(OPERATION_RESPONSE); + stubs.push(stub); + + return machineLearning.createModel(MODEL_OPTIONS_WITH_GCS) + .then((model) => { + expect(model.modelId).to.equal('1234567'); + expect(model.displayName).to.equal('model_1'); + expect(model.tags).to.deep.equal(['tag_1', 'tag_2']); + expect(model.createTime).to.equal(CREATE_TIME_UTC); + expect(model.updateTime).to.equal(UPDATE_TIME_UTC); + expect(model.validationError).to.be.empty; + expect(model.published).to.be.true; + expect(model.etag).to.equal('etag123'); + expect(model.modelHash).to.equal('modelHash123'); + + const tflite = model.tfliteModel!; + expect(tflite.gcsTfliteUri).to.be.equal( + 'gs://test-project-bucket/Firebase/ML/Models/model1.tflite'); + expect(tflite.sizeBytes).to.be.equal(16900988); + }); + }); + + it('should resolve with Error on operation error', () => { + const stub = sinon + .stub(MachineLearningApiClient.prototype, 'createModel') + .resolves(OPERATION_RESPONSE_ERROR); + stubs.push(stub); + + return machineLearning.createModel(MODEL_OPTIONS_WITH_GCS) + .should.eventually.be.rejected.and.have.property( + 'message', 'Invalid Argument message'); + }); + }); });