Skip to content

Added UpdateModel, publishModel,and unpublishModel functionality + tests #791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/machine-learning/machine-learning-api-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ export interface StatusErrorResponse {
readonly message: string;
}

/**
* A Firebase ML Model input object
*/
export interface ModelOptions {
displayName?: string;
tags?: string[];

tfliteModel?: { gcsTfliteUri: string; };
}

export interface ModelUpdateOptions extends ModelOptions {
state?: { published?: boolean; };
}

export interface ModelContent {
readonly displayName?: string;
readonly tags?: string[];
Expand Down Expand Up @@ -80,7 +94,7 @@ export class MachineLearningApiClient {
this.httpClient = new AuthorizedHttpClient(app);
}

public createModel(model: ModelContent): Promise<OperationResponse> {
public createModel(model: ModelOptions): Promise<OperationResponse> {
if (!validator.isNonNullObject(model) ||
!validator.isNonEmptyString(model.displayName)) {
const err = new FirebaseMachineLearningError('invalid-argument', 'Invalid model content.');
Expand All @@ -97,6 +111,24 @@ export class MachineLearningApiClient {
});
}

public updateModel(modelId: string, model: ModelUpdateOptions, updateMask: string[]): Promise<OperationResponse> {
if (!validator.isNonEmptyString(modelId) ||
!validator.isNonNullObject(model) ||
!validator.isNonEmptyArray(updateMask)) {
const err = new FirebaseMachineLearningError('invalid-argument', 'Invalid model or mask content.');
return Promise.reject(err);
}
return this.getUrl()
.then((url) => {
const request: HttpRequestConfig = {
method: 'PATCH',
url: `${url}/models/${modelId}?updateMask=${updateMask.join()}`,
data: model,
};
return this.sendRequest<OperationResponse>(request);
});
}


public getModel(modelId: string): Promise<ModelResponse> {
return Promise.resolve()
Expand Down
61 changes: 27 additions & 34 deletions src/machine-learning/machine-learning.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

import {FirebaseApp} from '../firebase-app';
import {FirebaseServiceInterface, FirebaseServiceInternalsInterface} from '../firebase-service';
import {MachineLearningApiClient, ModelResponse, OperationResponse, ModelContent} from './machine-learning-api-client';
import {MachineLearningApiClient, ModelResponse, OperationResponse,
ModelOptions, ModelUpdateOptions} from './machine-learning-api-client';
import {FirebaseError} from '../utils/error';

import * as validator from '../utils/validator';
import {FirebaseMachineLearningError} from './machine-learning-utils';
import { deepCopy } from '../utils/deep-copy';
import * as utils from '../utils';

/**
* Internals of an ML instance.
Expand Down Expand Up @@ -95,7 +97,7 @@ export class MachineLearning implements FirebaseServiceInterface {
* @return {Promise<Model>} A Promise fulfilled with the created model.
*/
public createModel(model: ModelOptions): Promise<Model> {
return this.convertOptionstoContent(model, true)
return this.signUrlIfPresent(model)
.then((modelContent) => this.client.createModel(modelContent))
.then((operation) => handleOperation(operation));
}
Expand All @@ -109,8 +111,11 @@ export class MachineLearning implements FirebaseServiceInterface {
* @return {Promise<Model>} A Promise fulfilled with the updated model.
*/
public updateModel(modelId: string, model: ModelOptions): Promise<Model> {
throw new Error('NotImplemented');
}
const updateMask = utils.generateUpdateMask(model);
return this.signUrlIfPresent(model)
.then((modelContent) => this.client.updateModel(modelId, modelContent, updateMask))
.then((operation) => handleOperation(operation));
}

/**
* Publishes a model in Firebase ML.
Expand All @@ -120,7 +125,7 @@ export class MachineLearning implements FirebaseServiceInterface {
* @return {Promise<Model>} A Promise fulfilled with the published model.
*/
public publishModel(modelId: string): Promise<Model> {
throw new Error('NotImplemented');
return this.setPublishStatus(modelId, true);
}

/**
Expand All @@ -131,7 +136,7 @@ export class MachineLearning implements FirebaseServiceInterface {
* @return {Promise<Model>} A Promise fulfilled with the unpublished model.
*/
public unpublishModel(modelId: string): Promise<Model> {
throw new Error('NotImplemented');
return this.setPublishStatus(modelId, false);
}

/**
Expand All @@ -143,9 +148,7 @@ export class MachineLearning implements FirebaseServiceInterface {
*/
public getModel(modelId: string): Promise<Model> {
return this.client.getModel(modelId)
.then((modelResponse) => {
return new Model(modelResponse);
});
.then((modelResponse) => new Model(modelResponse));
}

/**
Expand All @@ -171,23 +174,28 @@ export class MachineLearning implements FirebaseServiceInterface {
return this.client.deleteModel(modelId);
}

private convertOptionstoContent(options: ModelOptions, forUpload?: boolean): Promise<ModelContent> {
const modelContent = deepCopy(options);
private setPublishStatus(modelId: string, publish: boolean): Promise<Model> {
const updateMask = ['state.published'];
const options: ModelUpdateOptions = {state: {published: publish}};
return this.client.updateModel(modelId, options, updateMask)
.then((operation) => handleOperation(operation));
}

if (forUpload && modelContent.tfliteModel?.gcsTfliteUri) {
return this.signUrl(modelContent.tfliteModel.gcsTfliteUri)
private signUrlIfPresent(options: ModelOptions): Promise<ModelOptions> {
const modelOptions = deepCopy(options);
if (modelOptions.tfliteModel?.gcsTfliteUri) {
return this.signUrl(modelOptions.tfliteModel.gcsTfliteUri)
.then ((uri: string) => {
modelContent.tfliteModel!.gcsTfliteUri = uri;
return modelContent;
modelOptions.tfliteModel!.gcsTfliteUri = uri;
return modelOptions;
})
.catch((err: Error) => {
throw new FirebaseMachineLearningError(
'internal-error',
`Error during signing upload url: ${err.message}`);
}) as Promise<ModelContent>;
});
}

return Promise.resolve(modelContent) as Promise<ModelContent>;
return Promise.resolve(modelOptions);
}

private signUrl(unsignedUrl: string): Promise<string> {
Expand All @@ -208,9 +216,7 @@ export class MachineLearning implements FirebaseServiceInterface {
return blob.getSignedUrl({
action: 'read',
expires: Date.now() + URL_VALID_DURATION,
}).then((x) => {
return x[0];
});
}).then((signUrl) => signUrl[0]);
}
}

Expand Down Expand Up @@ -287,23 +293,10 @@ export interface TFLiteModel {
readonly gcsTfliteUri: string;
}


/**
* A Firebase ML Model input object
*/
export class ModelOptions {
public displayName?: string;
public tags?: string[];

public tfliteModel?: { gcsTfliteUri: string; };
}


function extractModelId(resourceName: string): string {
return resourceName.split('/').pop()!;
}


function handleOperation(op: OperationResponse): Model {
// Backend currently does not return operations that are not done.
if (op.done) {
Expand Down
178 changes: 177 additions & 1 deletion test/integration/machine-learning.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,178 @@ describe('admin.machineLearning', () => {
});
});

describe('updateModel()', () => {

const UPDATE_NAME: admin.machineLearning.ModelOptions = {
displayName: 'update-model-new-name',
};

it('rejects with not-found when the Model does not exist', () => {
const nonExistingId = '00000000';
return admin.machineLearning().updateModel(nonExistingId, UPDATE_NAME)
.should.eventually.be.rejected.and.have.property(
'code', 'machine-learning/not-found');
});

it('rejects with invalid-argument when the ModelId is invalid', () => {
return admin.machineLearning().updateModel('invalid-model-id', UPDATE_NAME)
.should.eventually.be.rejected.and.have.property(
'code', 'machine-learning/invalid-argument');
});

it ('rejects with invalid-argument when modelOptions are invalid', () => {
const modelOptions: admin.machineLearning.ModelOptions = {
displayName: 'Invalid Name#*^!',
};
return createTemporaryModel({displayName: 'node-integration-invalid-argument'})
.then((model) => admin.machineLearning().updateModel(model.modelId, modelOptions)
.should.eventually.be.rejected.and.have.property(
'code', 'machine-learning/invalid-argument'));
});

it('updates the displayName', () => {
const DISPLAY_NAME = 'node-integration-test-update-1b';
return createTemporaryModel({displayName: 'node-integration-test-update-1a'})
.then((model) => {
const modelOptions: admin.machineLearning.ModelOptions = {
displayName: DISPLAY_NAME,
};
return admin.machineLearning().updateModel(model.modelId, modelOptions)
.then((updatedModel) => {
verifyModel(updatedModel, modelOptions);
});
});
});

it('sets tags for a model', () => {
// TODO(ifielker): Uncomment & replace when BE change lands.
// const ORIGINAL_TAGS = ['tag-node-update-1'];
const ORIGINAL_TAGS: string[] = [];
const NEW_TAGS = ['tag-node-update-2', 'tag-node-update-3'];

return createTemporaryModel({
displayName: 'node-integration-test-update-2',
tags: ORIGINAL_TAGS,
}).then((expectedModel) => {
const modelOptions: admin.machineLearning.ModelOptions = {
tags: NEW_TAGS,
};
return admin.machineLearning().updateModel(expectedModel.modelId, modelOptions)
.then((actualModel) => {
expect(actualModel.tags!.length).to.equal(2);
expect(actualModel.tags).to.have.same.members(NEW_TAGS);
});
});
});

it('updates the tflite file', () => {
Promise.all([
createTemporaryModel(),
uploadModelToGcs('model1.tflite', 'valid_model.tflite')])
.then(([model, fileName]) => {
const modelOptions: admin.machineLearning.ModelOptions = {
tfliteModel: {gcsTfliteUri: fileName},
};
return admin.machineLearning().updateModel(model.modelId, modelOptions)
.then((updatedModel) => {
verifyModel(updatedModel, modelOptions);
});
});
});

it('can update more than 1 field', () => {
const DISPLAY_NAME = 'node-integration-test-update-3b';
const TAGS = ['node-integration-tag-1', 'node-integration-tag-2'];
return createTemporaryModel({displayName: 'node-integration-test-update-3a'})
.then((model) => {
const modelOptions: admin.machineLearning.ModelOptions = {
displayName: DISPLAY_NAME,
tags: TAGS,
};
return admin.machineLearning().updateModel(model.modelId, modelOptions)
.then((updatedModel) => {
expect(updatedModel.displayName).to.equal(DISPLAY_NAME);
expect(updatedModel.tags).to.have.same.members(TAGS);
});
});
});
});

describe('publishModel()', () => {
it('should reject when model does not exist', () => {
const nonExistingName = '00000000';
return admin.machineLearning().publishModel(nonExistingName)
.should.eventually.be.rejected.and.have.property(
'code', 'machine-learning/not-found');
});

it('rejects with invalid-argument when the ModelId is invalid', () => {
return admin.machineLearning().publishModel('invalid-model-id')
.should.eventually.be.rejected.and.have.property(
'code', 'machine-learning/invalid-argument');
});

it('publishes the model successfully', () => {
const modelOptions: admin.machineLearning.ModelOptions = {
displayName: 'node-integration-test-publish-1',
tfliteModel: {gcsTfliteUri: 'this will be replaced below'},
};
return uploadModelToGcs('model1.tflite', 'valid_model.tflite')
.then((fileName: string) => {
modelOptions.tfliteModel!.gcsTfliteUri = fileName;
createTemporaryModel(modelOptions)
.then((createdModel) => {
expect(createdModel.validationError).to.be.empty;
expect(createdModel.published).to.be.false;
admin.machineLearning().publishModel(createdModel.modelId)
.then((publishedModel) => {
expect(publishedModel.published).to.be.true;
});
});
});
});
});

describe('unpublishModel()', () => {
it('should reject when model does not exist', () => {
const nonExistingName = '00000000';
return admin.machineLearning().unpublishModel(nonExistingName)
.should.eventually.be.rejected.and.have.property(
'code', 'machine-learning/not-found');
});

it('rejects with invalid-argument when the ModelId is invalid', () => {
return admin.machineLearning().unpublishModel('invalid-model-id')
.should.eventually.be.rejected.and.have.property(
'code', 'machine-learning/invalid-argument');
});

it('unpublishes the model successfully', () => {
const modelOptions: admin.machineLearning.ModelOptions = {
displayName: 'node-integration-test-unpublish-1',
tfliteModel: {gcsTfliteUri: 'this will be replaced below'},
};
return uploadModelToGcs('model1.tflite', 'valid_model.tflite')
.then((fileName: string) => {
modelOptions.tfliteModel!.gcsTfliteUri = fileName;
createTemporaryModel(modelOptions)
.then((createdModel) => {
expect(createdModel.validationError).to.be.empty;
expect(createdModel.published).to.be.false;
admin.machineLearning().publishModel(createdModel.modelId)
.then((publishedModel) => {
expect(publishedModel.published).to.be.true;
admin.machineLearning().unpublishModel(publishedModel.modelId)
.then((unpublishedModel) => {
expect(unpublishedModel.published).to.be.false;
});
});
});
});
});
});


describe('getModel()', () => {
it('rejects with not-found when the Model does not exist', () => {
const nonExistingName = '00000000';
Expand Down Expand Up @@ -181,7 +353,11 @@ describe('admin.machineLearning', () => {
});

function verifyModel(model: admin.machineLearning.Model, expectedOptions: admin.machineLearning.ModelOptions) {
expect(model.displayName).to.equal(expectedOptions.displayName);
if (expectedOptions.displayName) {
expect(model.displayName).to.equal(expectedOptions.displayName);
} else {
expect(model.displayName).not.to.be.empty;
}
expect(model.createTime).to.not.be.empty;
expect(model.updateTime).to.not.be.empty;
expect(model.etag).to.not.be.empty;
Expand Down
Loading