Skip to content

[PLT-1677] Added get mal import functions to replace old BulkImportRequest class #1909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions libs/labelbox/src/labelbox/schema/annotation_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,20 @@ def parent_id(self) -> str:
"""
return self.project().uid

def delete(self) -> None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't use the deletable inheritance to make this work for this class due to this being a hard delete, not flipping a field to False.

"""
Deletes a MALPredictionImport job
"""

query_string = """
mutation deleteModelAssistedLabelingPredictionImportPyApi($id: ID!) {
deleteModelAssistedLabelingPredictionImport(where: { id: $id }) {
id
}
}
"""
self.client.execute(query_string, {"id": self.uid})

@classmethod
def create_from_file(
cls, client: "labelbox.Client", project_id: str, name: str, path: str
Expand Down
51 changes: 51 additions & 0 deletions libs/labelbox/src/labelbox/schema/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_args,
)

from labelbox.schema.annotation_import import LabelImport, MALPredictionImport
from lbox.exceptions import (
InvalidQueryError,
LabelboxError,
Expand Down Expand Up @@ -710,6 +711,56 @@ def _connect_default_labeling_front_end(self, ontology_as_dict: dict):
},
)

def get_mal_prediction_imports(self) -> PaginatedCollection:
"""Returns mal prediction import objects which are used in model-assisted labeling associated with the project.

Returns:
PaginatedCollection
"""

id_param = "projectId"
query_str = """
query getModelAssistedLabelingPredictionImportsPyApi($%s: ID!) {
modelAssistedLabelingPredictionImports(skip: %%d, first: %%d, where: { projectId: $%s }) { %s }}
""" % (
id_param,
id_param,
query.results_query_part(MALPredictionImport),
)

return PaginatedCollection(
self.client,
query_str,
{id_param: self.uid},
["modelAssistedLabelingPredictionImports"],
MALPredictionImport,
)

def get_label_imports(self) -> PaginatedCollection:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: Not going to add a delete path for regular label imports. This is because the mutation is internal and it can also put the data row in a bad state i.e. label with no annotations. We also already have SDk methods to delete and requeue labels so not needed.

"""Returns label import objects associated with the project.

Returns:
PaginatedCollection
"""

id_param = "projectId"
query_str = """
query getLabelImportsPyApi($%s: ID!) {
labelImports(skip: %%d, first: %%d, where: { projectId: $%s }) { %s }}
""" % (
id_param,
id_param,
query.results_query_part(LabelImport),
)

return PaginatedCollection(
self.client,
query_str,
{id_param: self.uid},
["labelImports"],
LabelImport,
)

def create_batch(
self,
name: str,
Expand Down
13 changes: 13 additions & 0 deletions libs/labelbox/tests/data/annotation_import/test_label_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,19 @@ def test_get(client, module_project, annotation_import_test_helpers):
annotation_import_test_helpers.check_running_state(label_import, name, url)


def test_get_import_jobs_from_project(client, configured_project):
name = str(uuid.uuid4())
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
label_import = LabelImport.create_from_url(
client=client, project_id=configured_project.uid, name=name, url=url
)
label_import.wait_until_done()

label_imports = list(configured_project.get_label_imports())
assert len(label_imports) == 1
assert label_imports[0].input_file_url == url


@pytest.mark.slow
def test_wait_till_done(client, module_project, predictions):
name = str(uuid.uuid4())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,20 @@ def test_create_with_path_arg(
annotation_import_test_helpers.assert_file_content(
label_import.input_file_url, object_predictions
)


def test_get_mal_import_jobs_from_project(client, configured_project):
name = str(uuid.uuid4())
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
label_import = MALPredictionImport.create(
client=client, id=configured_project.uid, name=name, url=url
)
label_import.wait_until_done()

label_imports = list(configured_project.get_mal_prediction_imports())
assert len(label_imports) == 1
assert label_imports[0].input_file_url == url

label_imports[0].delete()
label_imports = list(configured_project.get_mal_prediction_imports())
assert len(label_imports) == 0
Loading