diff --git a/examples/model_diagnostics/custom-metrics.ipynb b/examples/model_diagnostics/custom-metrics.ipynb new file mode 100644 index 000000000..317de4c79 --- /dev/null +++ b/examples/model_diagnostics/custom-metrics.ipynb @@ -0,0 +1,708 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "mounted-asian", + "metadata": { + "id": "EyNkbpW7ouEf" + }, + "source": [ + " \n", + "\n", + " \n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "attractive-lemon", + "metadata": {}, + "source": [ + "----\n", + "\n", + "# Model Diagnostics - Custom Metrics\n", + "\n", + "\n", + "* The main idea behind custom metrics is that by enabling users to provide specific metrics, that closely align with company objectives, users will be able to iterate faster and more track model quality over time with less noise.\n", + "\n", + "* For example, a self driving car dashcam might track the precision, recall, and iou of their model. However, nearby objects, such as people, might matter much more than far away objects. This important aspect of model quality is lost to the noise of these broad metrics. \n", + "* Continuing with this example, the organization might want to report safety in terms of performance of the model on nearby objects. They maybe also want to know how much marginal value they are getting out of each training example. Both of these use cases rely on the metric specifically tracking the company's target objective.\n", + "\n", + "\n", + "Topics Covered\n", + "* Custom metrics basics\n", + "* Complete diagnostics demo using custom metrics\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "subsequent-magic", + "metadata": { + "id": "subsequent-magic" + }, + "source": [ + "## Environment Setup\n", + "\n", + "Install dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "voluntary-minister", + "metadata": { + "id": "voluntary-minister" + }, + "outputs": [], + "source": [ + "!pip install -q \"labelbox[data]\" \\\n", + " scikit-image \\\n", + " tensorflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "wooden-worship", + "metadata": { + "id": "wooden-worship" + }, + "outputs": [], + "source": [ + "# Run these if running in a colab notebook\n", + "COLAB = \"google.colab\" in str(get_ipython())\n", + "\n", + "if COLAB:\n", + " !git clone https://github.com/Labelbox/labelbox-python.git\n", + " !cd labelbox-python\n", + " !mv labelbox-python/examples/model_assisted_labeling/*.py ." + ] + }, + { + "cell_type": "markdown", + "id": "latter-leone", + "metadata": { + "id": "latter-leone" + }, + "source": [ + "Import libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "committed-richards", + "metadata": { + "id": "committed-richards" + }, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('../model_assisted_labeling')\n", + "\n", + "import uuid\n", + "import numpy as np\n", + "from skimage import measure\n", + "import requests\n", + "from tqdm import notebook\n", + "import requests\n", + "import csv\n", + "import os\n", + "\n", + "from labelbox.schema.ontology import OntologyBuilder, Tool\n", + "from labelbox.data.metrics.group import get_label_pairs\n", + "from labelbox import Client, LabelingFrontend, MALPredictionImport\n", + "from labelbox.data.metrics.iou import data_row_miou, feature_miou_metric\n", + "from labelbox.data.serialization import NDJsonConverter\n", + "from labelbox.data.annotation_types import (\n", + " ScalarMetric, \n", + " LabelList, \n", + " Label, \n", + " ImageData, \n", + " MaskData,\n", + " Mask, \n", + " Polygon,\n", + " Point, \n", + " Rectangle, \n", + " ObjectAnnotation\n", + ")\n", + "\n", + "try:\n", + " from image_model import predict, load_model\n", + "except ModuleNotFoundError: \n", + " # !git clone https://github.com/Labelbox/labelbox-python.git\n", + " # !cd labelbox-python && git checkout mea-dev\n", + " # !mv labelbox-python/examples/model_assisted_labeling/*.py .\n", + " raise Exception(\"You will need to run from the labelbox-python git repo\")" + ] + }, + { + "cell_type": "markdown", + "id": "enclosed-tribe", + "metadata": {}, + "source": [ + "## Custom Metrics\n", + "* Users can provide metrics to provide metric information at different levels of granularity.\n", + " * Users can provide metrics for \n", + " 1. data rows\n", + " 2. features\n", + " 3. subclasses\n", + " * Additionally, metrics can be given custom names to best describe what they are measuring.\n", + " \n", + "* Limits and Behavior:\n", + " * At a data row cannot have more than 20 metrics\n", + " * Metrics are upserted, so if a metric already exists, its value will be replaced\n", + " * Metrics can have values in the range [0,100000]\n", + "* Currently only `ScalarMetric`s are supported. A `ScalarMetric` is a metric with just a single scalar value." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "palestinian-continent", + "metadata": {}, + "outputs": [], + "source": [ + "from labelbox.data.annotation_types import ScalarMetric, MetricAggregation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "sudden-danger", + "metadata": {}, + "outputs": [], + "source": [ + "data_row_metric = ScalarMetric(\n", + " metric_name = \"iou\",\n", + " value = 0.5\n", + ")\n", + "\n", + "feature_metric = ScalarMetric(\n", + " metric_name = \"iou\",\n", + " feature_name = \"cat\",\n", + " value = 0.5\n", + ")\n", + "\n", + "subclass_metric = ScalarMetric(\n", + " metric_name = \"iou\",\n", + " feature_name = \"cat\",\n", + " subclass_name = \"organge\",\n", + " value = 0.5\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "circular-gather", + "metadata": {}, + "source": [ + "### Aggregations\n", + "* This is an optional field on the `ScalarMetric` object (by default it uses Arithmetic Mean).\n", + "* Aggregations occur in two cases:\n", + " 1. When a user provides a feature or subclass level metric, Labelbox automatically aggregates all metrics with the same parent to create a value for that parent.\n", + " * E.g. A user provides cat and dog iou. The data row level metric for iou is the average of both of those.\n", + " * The exception to this is when the data row level iou is explicitly set, then the aggregation will not take effect (on a per data row basis). \n", + " 2. When users create slices or want aggregate statistics on their models, the selected aggregation is applied." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acquired-distributor", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "If the following metrics are uploaded then\n", + "in the web app, users will see:\n", + "true positives dog = 4\n", + "true positives cat = 3\n", + "true positives = 7\n", + "\"\"\"\n", + "\n", + "feature_metric = ScalarMetric(\n", + " metric_name = \"true_positives\",\n", + " feature_name = \"cat\",\n", + " value = 3,\n", + " aggregation = MetricAggregation.SUM\n", + ")\n", + "\n", + "feature_metric = ScalarMetric(\n", + " metric_name = \"true_positives\",\n", + " feature_name = \"dog\",\n", + " value = 4,\n", + " aggregation = MetricAggregation.SUM\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "tropical-ambassador", + "metadata": {}, + "source": [ + "## Complete Example\n", + "* Custom metrics are uploaded exactly the same way that iou was previously uploaded.\n", + "* A metric must be added to a `Label` to create an association between a data row and the metric" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "economic-chase", + "metadata": { + "id": "economic-chase" + }, + "outputs": [], + "source": [ + "API_KEY = None\n", + "PROJECT_NAME = \"Diagnostics Demo Custom Metrics\"\n", + "MODEL_NAME = \"MSCOCO-Mapillary-Custom-Metrics\"\n", + "MODEL_VERSION = \"0.0.0\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "affecting-myanmar", + "metadata": { + "id": "affecting-myanmar" + }, + "outputs": [], + "source": [ + "client = Client(api_key=API_KEY)\n", + "load_model() # initialize Tensorflow Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "modern-program", + "metadata": { + "id": "modern-program" + }, + "outputs": [], + "source": [ + "# Configure for whatever combination of tools and class names that you would like.\n", + "class_mappings = {\n", + " 1: {\"name\": 'person', \"kind\": Tool.Type.POLYGON},\n", + " 2: {\"name\": 'bicycle', \"kind\": Tool.Type.SEGMENTATION, 'color' : 64},\n", + " 3: {\"name\": 'car', \"kind\": Tool.Type.BBOX},\n", + " 4: {\"name\": 'motorcycle', \"kind\": Tool.Type.BBOX},\n", + " 6: {\"name\": 'bus', \"kind\": Tool.Type.POLYGON},\n", + " 7: {\"name\": 'train', \"kind\": Tool.Type.POLYGON},\n", + " 8: {\"name\": 'truck', \"kind\": Tool.Type.POLYGON},\n", + " 10: {\"name\": 'traffic light', \"kind\": Tool.Type.POINT},\n", + " 11: {\"name\": 'fire hydrant', \"kind\": Tool.Type.BBOX},\n", + " 13: {\"name\": 'stop sign', \"kind\": Tool.Type.SEGMENTATION, 'color' : 255},\n", + " 14: {\"name\": 'parking meter', \"kind\": Tool.Type.POINT},\n", + " 28: {\"name\": 'umbrella', \"kind\": Tool.Type.SEGMENTATION, 'color' : 128}, \n", + " 31: {\"name\": 'handbag', \"kind\": Tool.Type.POINT}, \n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "dated-burden", + "metadata": { + "id": "dated-burden" + }, + "source": [ + "## Create Predictions\n", + "* Loop over data_rows, make predictions, and create ndjson" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "blank-flower", + "metadata": {}, + "outputs": [], + "source": [ + "# --- setup dataset ---\n", + "# load mapillary sample\n", + "sample_csv_url = \"https://raw.githubusercontent.com/Labelbox/labelbox-python/develop/examples/assets/mapillary_sample.csv\"\n", + "with requests.get(sample_csv_url, stream=True) as r:\n", + " image_data = [row.split(',') for row in (line.decode('utf-8') for line in r.iter_lines())]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "asian-savings", + "metadata": { + "id": "asian-savings" + }, + "outputs": [], + "source": [ + "predictions = LabelList()\n", + "for (image_url, external_id) in notebook.tqdm(image_data[:5]):\n", + " image = ImageData(url = image_url, external_id = external_id)\n", + " height, width = image.value.shape[:2]\n", + " prediction = predict(np.array([image.im_bytes]), min_score=0.5, height=height, width = width)\n", + " boxes, classes, seg_masks = prediction[\"boxes\"], prediction[\"class_indices\"], prediction[\"seg_masks\"]\n", + " annotations = []\n", + " for box, class_idx, seg in zip(boxes, classes, seg_masks):\n", + " if class_idx in class_mappings:\n", + " class_info = class_mappings.get(class_idx)\n", + " if class_info['kind'] == Tool.Type.POLYGON:\n", + " contours = measure.find_contours(seg, 0.5)\n", + " pts = contours[0].astype(np.int32)\n", + " value = Polygon(points = [Point(x = x, y = y) for x,y in np.roll(pts, 1, axis=-1)])\n", + " elif class_info['kind'] == Tool.Type.BBOX:\n", + " value = Rectangle(start = Point(x = box[1], y = box[0]), end = Point(x=box[3], y=box[2]))\n", + " elif class_info['kind'] == Tool.Type.POINT:\n", + " value = Point(x=(box[1] + box[3]) / 2., y = (box[0] + box[2]) / 2.)\n", + " elif class_info['kind'] == Tool.Type.SEGMENTATION:\n", + " value = Mask(mask = MaskData.from_2D_arr(seg * class_info['color']), color = (class_info['color'],)* 3)\n", + " else:\n", + " raise ValueError(f\"Unsupported kind found. {class_info['kind']}\")\n", + " annotations.append(ObjectAnnotation(name = class_info['name'], value = value))\n", + " predictions.append(Label(data = image, annotations = annotations))" + ] + }, + { + "cell_type": "markdown", + "id": "together-suicide", + "metadata": {}, + "source": [ + "## Setup a project" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "received-nigeria", + "metadata": {}, + "outputs": [], + "source": [ + "tools = []\n", + "for target in class_mappings.values():\n", + " tools.append(Tool(tool=target['kind'], name=target[\"name\"]))\n", + "ontology_builder = OntologyBuilder(tools=tools)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "stopped-phrase", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Setting up: {PROJECT_NAME}\")\n", + "\n", + "project = client.create_project(name=PROJECT_NAME)\n", + "editor = next(client.get_labeling_frontends(where=LabelingFrontend.name == \"Editor\"))\n", + "project.setup(editor, ontology_builder.asdict())\n", + "\n", + "dataset = client.create_dataset(name=\"Mapillary Diagnostics Demo\")\n", + "print(f\"Dataset Created: {dataset.uid}\")\n", + "project.datasets.connect(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "secure-shelf", + "metadata": {}, + "source": [ + "## Prepare for upload\n", + "* Our local annotations need the following:\n", + " 1. signed url for segmentation masks\n", + " 2. data rows in labelbox\n", + " 3. feature schema ids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "unavailable-egyptian", + "metadata": {}, + "outputs": [], + "source": [ + "signer = lambda _bytes: client.upload_data(content=_bytes, sign=True)\n", + "predictions.add_url_to_masks(signer) \\\n", + " .add_url_to_data(signer) \\\n", + " .assign_feature_schema_ids(OntologyBuilder.from_project(project)) \\\n", + " .add_to_dataset(dataset, client.upload_data)" + ] + }, + { + "cell_type": "markdown", + "id": "perfect-seafood", + "metadata": { + "id": "perfect-seafood" + }, + "source": [ + "## **Optional** - Create labels with [Model Assisted Labeling](https://docs.labelbox.com/en/core-concepts/model-assisted-labeling)\n", + "\n", + "* Pre-label image so that we can quickly create ground truth\n", + "* Create ground truth data for Model Diagnostics\n", + "* Click on link below to label" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "subject-painting", + "metadata": { + "id": "subject-painting" + }, + "outputs": [], + "source": [ + "RUN_MAL = True\n", + "if RUN_MAL:\n", + " project.enable_model_assisted_labeling()\n", + " # Convert from annotation types to import format\n", + " ndjson_predictions = NDJsonConverter.serialize(predictions)\n", + " upload_task = MALPredictionImport.create_from_objects(client, project.uid, f'mal-import-{uuid.uuid4()}',ndjson_predictions )\n", + " upload_task.wait_until_done()\n", + " print(upload_task.state , '\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "MV4U1W4H_eMq", + "metadata": { + "id": "MV4U1W4H_eMq" + }, + "outputs": [], + "source": [ + "print(f\"https://app.labelbox.com/go-label/{project.uid}\")" + ] + }, + { + "cell_type": "markdown", + "id": "stopped-mandate", + "metadata": { + "id": "stopped-mandate" + }, + "source": [ + "## Export Labels\n", + "\n", + "We do not support `Skipped` labels and have a limit of **2000**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "excited-seminar", + "metadata": { + "id": "excited-seminar" + }, + "outputs": [], + "source": [ + "MAX_LABELS = 2000\n", + "labels = [l for idx, l in enumerate(project.label_generator()) if idx < MAX_LABELS]" + ] + }, + { + "cell_type": "markdown", + "id": "smoking-catering", + "metadata": { + "id": "smoking-catering" + }, + "source": [ + "## Setup Model & Model Run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "mental-minnesota", + "metadata": { + "id": "mental-minnesota" + }, + "outputs": [], + "source": [ + "lb_model = client.create_model(name = MODEL_NAME, ontology_id = project.ontology().uid)\n", + "lb_model_run = lb_model.create_model_run(MODEL_VERSION)" + ] + }, + { + "cell_type": "markdown", + "id": "cu8h6h0g-Fe2", + "metadata": { + "id": "cu8h6h0g-Fe2" + }, + "source": [ + "Select label ids to upload" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "static-coordinate", + "metadata": { + "id": "static-coordinate" + }, + "outputs": [], + "source": [ + "lb_model_run.upsert_labels([label.uid for label in labels])" + ] + }, + { + "cell_type": "markdown", + "id": "g_u1ak2n7qn5", + "metadata": { + "id": "g_u1ak2n7qn5" + }, + "source": [ + "### Compute Metrics\n", + "* First get pairs of labels and predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "hungry-collective", + "metadata": {}, + "outputs": [], + "source": [ + "pairs = get_label_pairs(labels, predictions, filter_mismatch = True)" + ] + }, + { + "cell_type": "markdown", + "id": "talented-netherlands", + "metadata": {}, + "source": [ + "* Create helper functions for our metrics\n", + "* All functions will accept ground truth and prediction annotations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "muslim-telling", + "metadata": {}, + "outputs": [], + "source": [ + "from shapely.ops import cascaded_union\n", + "\n", + "def nearby_cars_iou(ground_truths, predictions, area_threshold = 17000):\n", + " \"\"\"\n", + " Metric to track the iou score for cars that are nearby (determined by pixel size).\n", + " \n", + " This might be useful to investigate why the model poorly when vehicles are nearby.\n", + " Or this might just be a metric we care a lot about optimizing because our self driving car needs to \n", + " be aware of its immediate surroundings for safety reasons.\n", + " \"\"\"\n", + " ground_truths = [gt for gt in ground_truths if gt.name == 'car']\n", + " predictions = [pred for pred in predictions if pred.name == 'car']\n", + " ground_truths = cascaded_union([gt.value.shapely for gt in ground_truths if gt.value.shapely.area > area_threshold])\n", + " predictions = cascaded_union([pred.value.shapely for pred in predictions if pred.value.shapely.area > area_threshold])\n", + " union = ground_truths.union(predictions).area\n", + " # If there is no prediction or label then the score is undefined\n", + " if union == 0:\n", + " return []\n", + " return [ScalarMetric(\n", + " value = ground_truths.intersection(predictions).area / union,\n", + " metric_name = \"iou\",\n", + " feature_name = \"car\",\n", + " subclass_name = \"nearby\" # Doesn't necessarily need to be a subclass in the ontology\n", + " )] \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "voluntary-rugby", + "metadata": {}, + "source": [ + "* Compute and sssign each metric to prediction label" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "committed-fairy", + "metadata": { + "id": "committed-fairy" + }, + "outputs": [], + "source": [ + "for (ground_truth, prediction) in pairs.values():\n", + " metrics = []\n", + " metrics.extend(feature_miou_metric(ground_truth.annotations, prediction.annotations))\n", + " metrics.extend(nearby_cars_iou(ground_truth.annotations, prediction.annotations))\n", + " prediction.annotations.extend(metrics)" + ] + }, + { + "cell_type": "markdown", + "id": "eastern-illinois", + "metadata": {}, + "source": [ + "### Upload to Labelbox" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "anonymous-addition", + "metadata": { + "id": "anonymous-addition" + }, + "outputs": [], + "source": [ + "upload_task = lb_model_run.add_predictions(f'diagnostics-import-{uuid.uuid4()}', NDJsonConverter.serialize(predictions))\n", + "upload_task.wait_until_done()\n", + "print(upload_task.state)" + ] + }, + { + "cell_type": "markdown", + "id": "uTjGOyIW-3op", + "metadata": { + "id": "uTjGOyIW-3op" + }, + "source": [ + "### Open Model Run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "zrll9K6Q9tGK", + "metadata": { + "id": "zrll9K6Q9tGK" + }, + "outputs": [], + "source": [ + "for idx, annotation_group in enumerate(lb_model_run.annotation_groups()):\n", + " if idx == 5:\n", + " break\n", + " print(annotation_group.url)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Model Diagnostics Demo", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/model_diagnostics/model_diagnostics_demo.ipynb b/examples/model_diagnostics/model_diagnostics_demo.ipynb index 8c2c58848..91ac9cc97 100644 --- a/examples/model_diagnostics/model_diagnostics_demo.ipynb +++ b/examples/model_diagnostics/model_diagnostics_demo.ipynb @@ -73,7 +73,7 @@ }, "outputs": [], "source": [ - "!pip install \"labelbox[data]\" \\\n", + "!pip install -q \"labelbox[data]\" \\\n", " scikit-image \\\n", " tensorflow" ] @@ -128,8 +128,9 @@ "import os\n", "\n", "from labelbox.schema.ontology import OntologyBuilder, Tool\n", + "from labelbox.data.metrics.group import get_label_pairs\n", "from labelbox import Client, LabelingFrontend, MALPredictionImport\n", - "from labelbox.data.metrics.iou import data_row_miou\n", + "from labelbox.data.metrics.iou import data_row_miou, feature_miou_metric\n", "from labelbox.data.serialization import NDJsonConverter\n", "from labelbox.data.annotation_types import (\n", " ScalarMetric, \n", @@ -253,7 +254,7 @@ "outputs": [], "source": [ "predictions = LabelList()\n", - "for (image_url, external_id) in notebook.tqdm(image_data):\n", + "for (image_url, external_id) in notebook.tqdm(image_data[:10]):\n", " image = ImageData(url = image_url, external_id = external_id)\n", " height, width = image.value.shape[:2]\n", " prediction = predict(np.array([image.im_bytes]), min_score=0.5, height=height, width = width)\n", @@ -481,21 +482,17 @@ }, "outputs": [], "source": [ - "label_lookup = {label.data.uid : label for label in labels}\n", - "\n", - "for pred in predictions:\n", - " label = label_lookup.get(pred.data.uid)\n", - " if label is None:\n", - " # No label that matches the prediction\n", - " continue\n", - " \n", - " score = data_row_miou(label, pred)\n", - " if score is None:\n", - " continue\n", - " \n", - " pred.annotations.append(\n", - " ScalarMetric(value = score)\n", - " )" + "pairs = get_label_pairs(labels, predictions, filter_mismatch = True)\n", + "for (label, prediction) in pairs.values():\n", + " prediction.annotations.extend(feature_miou_metric(label.annotations, prediction.annotations))" + ] + }, + { + "cell_type": "markdown", + "id": "devoted-vatican", + "metadata": {}, + "source": [ + "### Upload to Labelbox" ] }, { @@ -540,7 +537,7 @@ { "cell_type": "code", "execution_count": null, - "id": "wound-newfoundland", + "id": "martial-kenya", "metadata": {}, "outputs": [], "source": [] diff --git a/examples/model_diagnostics/model_diagnostics_guide.ipynb b/examples/model_diagnostics/model_diagnostics_guide.ipynb index 816aa950b..e5ca9e01b 100644 --- a/examples/model_diagnostics/model_diagnostics_guide.ipynb +++ b/examples/model_diagnostics/model_diagnostics_guide.ipynb @@ -351,21 +351,9 @@ }, "outputs": [], "source": [ - "label_lookup = {label.data.uid : label for label in labels}\n", - "\n", - "for pred in predictions:\n", - " label = label_lookup.get(pred.data.uid)\n", - " if label is None:\n", - " # No label for the prediction..\n", - " continue\n", - "\n", - " score = data_row_miou(label, pred)\n", - " if score is None:\n", - " continue\n", - " \n", - " pred.annotations.append(\n", - " ScalarMetric(value = score)\n", - " )" + "pairs = get_label_pairs(labels, predictions, filter_mismatch = True)\n", + "for (label, prediction) in pairs.values():\n", + " prediction.annotations.extend(feature_miou_metric(label.annotations, prediction.annotations))" ] }, { diff --git a/labelbox/data/metrics/group.py b/labelbox/data/metrics/group.py index f2d456193..627274468 100644 --- a/labelbox/data/metrics/group.py +++ b/labelbox/data/metrics/group.py @@ -67,7 +67,7 @@ def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]: def get_label_pairs(labels_a: LabelList, labels_b: LabelList, match_on="uid", - filter=False) -> Dict[str, Tuple[Label, Label]]: + filter_mismatch=False) -> Dict[str, Tuple[Label, Label]]: """ This is a function to pairing a list of prediction labels and a list of ground truth labels easier. There are a few potentiall problems with this function. @@ -79,7 +79,7 @@ def get_label_pairs(labels_a: LabelList, labels_a (LabelList): A collection of labels to match with labels_b labels_b (LabelList): A collection of labels to match with labels_a match_on ('uid' or 'external_id'): The data row key to match labels by. Can either be uid or external id. - filter (bool): Whether or not to ignore mismatches + filter_mismatch (bool): Whether or not to ignore mismatches Returns: A dict containing the union of all either uids or external ids and values as a tuple of the matched labels @@ -106,14 +106,14 @@ def get_label_pairs(labels_a: LabelList, for key in all_keys: a, b = label_lookup_a.pop(key, None), label_lookup_b.pop(key, None) if a is None or b is None: - if not filter: + if not filter_mismatch: raise ValueError( f"{match_on} {key} is not available in both LabelLists. " - "Set `filter = True` to filter out these examples, assign the ids manually, or create your own matching function." + "Set `filter_mismatch = True` to filter out these examples, assign the ids manually, or create your own matching function." ) else: continue - pairs[key].append([a, b]) + pairs[key].extend([a, b]) return pairs