Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ def create_model(model, app=None):
return Model.from_dict(mlkit_service.create_model(model), app=app)


def update_model(model, app=None):
"""Updates a model in Firebase ML Kit.

Args:
model: The mlkit.Model to update.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The updated model.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.update_model(model), app=app)


def get_model(model_id, app=None):
"""Gets a model from Firebase ML Kit.

Expand Down Expand Up @@ -469,10 +483,10 @@ def _validate_and_parse_name(name):
return matcher.group('project_id'), matcher.group('model_id')


def _validate_model(model):
def _validate_model(model, update_mask=None):
if not isinstance(model, Model):
raise TypeError('Model must be an mlkit.Model.')
if not model.display_name:
if update_mask is None and not model.display_name:
raise ValueError('Model must have a display name.')


Expand Down Expand Up @@ -634,6 +648,17 @@ def create_model(self, model):
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def update_model(self, model, update_mask=None):
_validate_model(model, update_mask)
data = {'model': model.as_dict()}
if update_mask is not None:
data['updateMask'] = update_mask
try:
return self.handle_operation(
self._client.body('patch', url='models/{0}'.format(model.model_id), json=data))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def get_model(self, model_id):
_validate_model_id(model_id)
try:
Expand Down
127 changes: 119 additions & 8 deletions tests/test_mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from tests import testutils

BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'

PROJECT_ID = 'myProject1'
PAGE_TOKEN = 'pageToken'
NEXT_PAGE_TOKEN = 'nextPageToken'
Expand Down Expand Up @@ -122,7 +121,7 @@
}
TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2)

CREATED_MODEL_JSON_1 = {
CREATED_UPDATED_MODEL_JSON_1 = {
'name': MODEL_NAME_1,
'displayName': DISPLAY_NAME_1,
'createTime': CREATE_TIME_JSON,
Expand All @@ -132,7 +131,7 @@
'modelHash': MODEL_HASH,
'tags': TAGS,
}
CREATED_MODEL_1 = mlkit.Model.from_dict(CREATED_MODEL_JSON_1)
CREATED_UPDATED_MODEL_1 = mlkit.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1)

LOCKED_MODEL_JSON_1 = {
'name': MODEL_NAME_1,
Expand All @@ -155,19 +154,16 @@
OPERATION_DONE_MODEL_JSON_1 = {
'name': OPERATION_NAME_1,
'done': True,
'response': CREATED_MODEL_JSON_1
'response': CREATED_UPDATED_MODEL_JSON_1
}

OPERATION_MALFORMED_JSON_1 = {
'name': OPERATION_NAME_1,
'done': True,
# if done is true then either response or error should be populated
}

OPERATION_MISSING_NAME = {
'done': False
}

OPERATION_ERROR_CODE = 400
OPERATION_ERROR_MSG = "Invalid argument"
OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT'
Expand Down Expand Up @@ -524,7 +520,7 @@ def _get_url(project_id, model_id):
def test_immediate_done(self):
instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
model = mlkit.create_model(MODEL_1)
assert model == CREATED_MODEL_1
assert model == CREATED_UPDATED_MODEL_1

def test_returns_locked(self):
recorder = instrument_mlkit_service(
Expand Down Expand Up @@ -615,6 +611,121 @@ def test_invalid_op_name(self, op_name):
check_error(excinfo, ValueError, 'Operation name format is invalid.')


class TestUpdateModel(object):
"""Tests mlkit.update_model."""
@classmethod
def setup_class(cls):
cred = testutils.MockCredential()
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test

@classmethod
def teardown_class(cls):
testutils.cleanup_apps()

@staticmethod
def _url(project_id, model_id):
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)

@staticmethod
def _op_url(project_id, model_id):
return BASE_URL + \
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)

def test_immediate_done(self):
instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
model = mlkit.update_model(MODEL_1)
assert model == CREATED_UPDATED_MODEL_1

def test_returns_locked(self):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = mlkit.update_model(MODEL_1)

assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)

def test_operation_error(self):
instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
# The http request succeeded, the operation returned contains a create failure
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)

def test_malformed_operation(self):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = mlkit.update_model(MODEL_1)
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)

def test_rpc_error_create(self):
create_recorder = instrument_mlkit_service(
status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
check_firebase_error(
excinfo,
ERROR_STATUS_BAD_REQUEST,
ERROR_CODE_BAD_REQUEST,
ERROR_MSG_BAD_REQUEST
)
assert len(create_recorder) == 1

@pytest.mark.parametrize('model', [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is used in multiple tests, turn it into a constant INVALID_MODEL_ARGS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

'abc',
4.2,
list(),
dict(),
True,
-1,
0,
None
])
def test_not_model(self, model):
with pytest.raises(Exception) as excinfo:
mlkit.update_model(model)
check_error(excinfo, TypeError, 'Model must be an mlkit.Model.')

def test_missing_display_name(self):
with pytest.raises(Exception) as excinfo:
mlkit.update_model(mlkit.Model.from_dict({}))
check_error(excinfo, ValueError, 'Model must have a display name.')

def test_missing_op_name(self):
instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
check_error(excinfo, TypeError)

@pytest.mark.parametrize('op_name', [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. If used more than once INVALID_MODEL_NAMES

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

'abc',
'123',
'projects/operations/project/1234/model/abc/operation/123',
'operations/project/model/abc/operation/123',
'operations/project/123/model/$#@/operation/123',
'operations/project/1234/model/abc/operation/123/extrathing',
])
def test_invalid_op_name(self, op_name):
payload = json.dumps({'name': op_name})
instrument_mlkit_service(status=200, payload=payload)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
check_error(excinfo, ValueError, 'Operation name format is invalid.')


class TestGetModel(object):
"""Tests mlkit.get_model."""
@classmethod
Expand Down