Skip to content

Commit 795c266

Browse files
authored
Merge pull request #266 from Labelbox/fs/mea-enable-model-model-runs-annotation-groups-deletions
[DIAG-635] Enable interfaces to execute MEA deletions
2 parents a8fcd34 + 42c178d commit 795c266

File tree

10 files changed

+147
-50
lines changed

10 files changed

+147
-50
lines changed

labelbox/schema/model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,14 @@ def create_model_run(self, name):
3434
model_id_param: self.uid
3535
})
3636
return ModelRun(self.client, res["createModelRun"])
37+
38+
def delete(self):
39+
""" Deletes specified model.
40+
41+
Returns:
42+
Query execution success.
43+
"""
44+
ids_param = "ids"
45+
query_str = """mutation DeleteModelPyApi($%s: ID!) {
46+
deleteModels(where: {ids: [$%s]})}""" % (ids_param, ids_param)
47+
self.client.execute(query_str, {ids_param: str(self.uid)})

labelbox/schema/model_run.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,36 @@ def annotation_groups(self):
7474
lambda client, res: AnnotationGroup(client, self.model_id, res),
7575
['annotationGroups', 'pageInfo', 'endCursor'])
7676

77+
def delete(self):
78+
""" Deletes specified model run.
79+
80+
Returns:
81+
Query execution success.
82+
"""
83+
ids_param = "ids"
84+
query_str = """mutation DeleteModelRunPyApi($%s: ID!) {
85+
deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param, ids_param)
86+
self.client.execute(query_str, {ids_param: str(self.uid)})
87+
88+
def delete_annotation_groups(self, data_row_ids):
89+
""" Deletes annotation groups by data row ids for a model run.
90+
91+
Args:
92+
data_row_ids (list): List of data row ids to delete annotation groups.
93+
Returns:
94+
Query execution success.
95+
"""
96+
model_run_id_param = "modelRunId"
97+
data_row_ids_param = "dataRowIds"
98+
query_str = """mutation DeleteModelRunDataRowsPyApi($%s: ID!, $%s: [ID!]!) {
99+
deleteModelRunDataRows(where: {modelRunId: $%s, dataRowIds: $%s})}""" % (
100+
model_run_id_param, data_row_ids_param, model_run_id_param,
101+
data_row_ids_param)
102+
self.client.execute(query_str, {
103+
model_run_id_param: self.uid,
104+
data_row_ids_param: data_row_ids
105+
})
106+
77107

78108
class AnnotationGroup(DbObject):
79109
label_id = Field.String("label_id")

tests/integration/bulk_import/conftest.py renamed to tests/integration/mal_and_mea/conftest.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,25 @@ def predictions(object_predictions, classification_predictions):
297297

298298

299299
@pytest.fixture
300-
def model_run(client, rand_gen, configured_project, annotation_submit_fn,
301-
model_run_predictions):
302-
configured_project.enable_model_assisted_labeling()
300+
def model(client, rand_gen, configured_project):
303301
ontology = configured_project.ontology()
304302

303+
data = {"name": rand_gen(str), "ontology_id": ontology.uid}
304+
return client.create_model(data["name"], data["ontology_id"])
305+
306+
307+
@pytest.fixture
308+
def model_run(rand_gen, model):
309+
name = rand_gen(str)
310+
return model.create_model_run(name)
311+
312+
313+
@pytest.fixture
314+
def model_run_annotation_groups(client, configured_project,
315+
annotation_submit_fn, model_run_predictions,
316+
model_run):
317+
configured_project.enable_model_assisted_labeling()
318+
305319
upload_task = MALPredictionImport.create_from_objects(
306320
client, configured_project.uid, f'mal-import-{uuid.uuid4()}',
307321
model_run_predictions)
@@ -310,15 +324,10 @@ def model_run(client, rand_gen, configured_project, annotation_submit_fn,
310324
for data_row_id in {x['dataRow']['id'] for x in model_run_predictions}:
311325
annotation_submit_fn(configured_project.uid, data_row_id)
312326

313-
data = {"name": rand_gen(str), "ontology_id": ontology.uid}
314-
model = client.create_model(data["name"], data["ontology_id"])
315-
name = rand_gen(str)
316-
model_run_s = model.create_model_run(name)
317-
318327
time.sleep(3)
319328
labels = configured_project.export_labels(download=True)
320-
model_run_s.upsert_labels([label['ID'] for label in labels])
329+
model_run.upsert_labels([label['ID'] for label in labels])
321330
time.sleep(3)
322331

323-
yield model_run_s
332+
yield model_run
324333
# TODO: Delete resources when that is possible ..

tests/integration/bulk_import/test_mea_annotation_import.py renamed to tests/integration/mal_and_mea/test_mea_annotation_import.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,56 +20,58 @@ def check_running_state(req, name, url=None):
2020
assert req.state == AnnotationImportState.RUNNING
2121

2222

23-
def test_create_from_url(model_run):
23+
def test_create_from_url(model_run_annotation_groups):
2424
name = str(uuid.uuid4())
2525
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
26-
annotation_import = model_run.add_predictions(name=name, predictions=url)
27-
assert annotation_import.model_run_id == model_run.uid
26+
annotation_import = model_run_annotation_groups.add_predictions(
27+
name=name, predictions=url)
28+
assert annotation_import.model_run_id == model_run_annotation_groups.uid
2829
check_running_state(annotation_import, name, url)
2930

3031

31-
def test_create_from_objects(model_run, object_predictions):
32+
def test_create_from_objects(model_run_annotation_groups, object_predictions):
3233
name = str(uuid.uuid4())
3334

34-
annotation_import = model_run.add_predictions(
35+
annotation_import = model_run_annotation_groups.add_predictions(
3536
name=name, predictions=object_predictions)
3637

37-
assert annotation_import.model_run_id == model_run.uid
38+
assert annotation_import.model_run_id == model_run_annotation_groups.uid
3839
check_running_state(annotation_import, name)
3940
assert_file_content(annotation_import.input_file_url, object_predictions)
4041

4142

42-
def test_create_from_local_file(tmp_path, model_run, object_predictions):
43+
def test_create_from_local_file(tmp_path, model_run_annotation_groups,
44+
object_predictions):
4345
name = str(uuid.uuid4())
4446
file_name = f"{name}.ndjson"
4547
file_path = tmp_path / file_name
4648
with file_path.open("w") as f:
4749
ndjson.dump(object_predictions, f)
4850

49-
annotation_import = model_run.add_predictions(name=name,
50-
predictions=str(file_path))
51+
annotation_import = model_run_annotation_groups.add_predictions(
52+
name=name, predictions=str(file_path))
5153

52-
assert annotation_import.model_run_id == model_run.uid
54+
assert annotation_import.model_run_id == model_run_annotation_groups.uid
5355
check_running_state(annotation_import, name)
5456
assert_file_content(annotation_import.input_file_url, object_predictions)
5557

5658

57-
def test_get(client, model_run):
59+
def test_get(client, model_run_annotation_groups):
5860
name = str(uuid.uuid4())
5961
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
60-
model_run.add_predictions(name=name, predictions=url)
62+
model_run_annotation_groups.add_predictions(name=name, predictions=url)
6163

6264
annotation_import = MEAPredictionImport.from_name(
63-
client, model_run_id=model_run.uid, name=name)
65+
client, model_run_id=model_run_annotation_groups.uid, name=name)
6466

65-
assert annotation_import.model_run_id == model_run.uid
67+
assert annotation_import.model_run_id == model_run_annotation_groups.uid
6668
check_running_state(annotation_import, name, url)
6769

6870

6971
@pytest.mark.slow
70-
def test_wait_till_done(model_run_predictions, model_run):
72+
def test_wait_till_done(model_run_predictions, model_run_annotation_groups):
7173
name = str(uuid.uuid4())
72-
annotation_import = model_run.add_predictions(
74+
annotation_import = model_run_annotation_groups.add_predictions(
7375
name=name, predictions=model_run_predictions)
7476

7577
assert len(annotation_import.inputs) == len(model_run_predictions)

tests/integration/test_model.py renamed to tests/integration/mal_and_mea/test_model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,14 @@ def test_model(client, configured_project, rand_gen):
1818

1919
model = client.get_model(model.uid)
2020
assert model.name == data["name"]
21+
22+
23+
def test_model_delete(client, model):
24+
before = list(client.get_models())
25+
26+
model = before[0]
27+
model.delete()
28+
29+
after = list(client.get_models())
30+
31+
assert len(before) == len(after) + 1
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import time
2+
3+
4+
def test_model_run(client, configured_project_with_label, rand_gen):
5+
project = configured_project_with_label
6+
ontology = project.ontology()
7+
data = {"name": rand_gen(str), "ontology_id": ontology.uid}
8+
model = client.create_model(data["name"], data["ontology_id"])
9+
10+
name = rand_gen(str)
11+
model_run = model.create_model_run(name)
12+
assert model_run.name == name
13+
assert model_run.model_id == model.uid
14+
assert model_run.created_by_id == client.get_user().uid
15+
16+
label = project.export_labels(download=True)[0]
17+
model_run.upsert_labels([label['ID']])
18+
time.sleep(3)
19+
20+
annotation_group = next(model_run.annotation_groups())
21+
assert annotation_group.label_id == label['ID']
22+
assert annotation_group.model_run_id == model_run.uid
23+
assert annotation_group.data_row().uid == next(
24+
next(project.datasets()).data_rows()).uid
25+
26+
27+
def test_model_run_delete(client, model_run):
28+
models_before = list(client.get_models())
29+
model_before = models_before[0]
30+
before = list(model_before.model_runs())
31+
32+
model_run = before[0]
33+
model_run.delete()
34+
35+
models_after = list(client.get_models())
36+
model_after = models_after[0]
37+
after = list(model_after.model_runs())
38+
39+
assert len(before) == len(after) + 1
40+
41+
42+
def test_model_run_annotation_groups_delete(client,
43+
model_run_annotation_groups):
44+
models = list(client.get_models())
45+
model = models[0]
46+
model_runs = list(model.model_runs())
47+
model_run = model_runs[0]
48+
49+
before = list(model_run.annotation_groups())
50+
annotation_group = before[0]
51+
52+
data_row_id = annotation_group.data_row().uid
53+
model_run.delete_annotation_groups(data_row_ids=[data_row_id])
54+
55+
after = list(model_run.annotation_groups())
56+
57+
assert len(before) == len(after) + 1

tests/integration/test_data_row_metadata.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def test_large_bulk_delete_datarow_metadata(big_dataset, mdo):
156156
break
157157

158158

159+
@pytest.mark.skip
159160
def test_bulk_delete_datarow_enum_metadata(datarow: DataRow, mdo):
160161
"""test bulk deletes for non non fields"""
161162
n_fields = len(datarow.metadata["fields"])

tests/integration/test_model_run.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)