diff --git a/CHANGELOG.md b/CHANGELOG.md index 252db04681..7f7b9559cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ This project adheres to [Semantic Versioning](https://semver.org/). ## [Unreleased] ### Added +- [#1952](https://github.com/plotly/dash/pull/1952) Improved callback_context + - Closes [#1818](https://github.com/plotly/dash/issues/1818) Closes [#1054](https://github.com/plotly/dash/issues/1054) + - adds `dash.ctx`, a more concise name for `dash.callback_context` + - adds `ctx.triggered_prop_ids`, a dictionary of the component ids and props that triggered the callback. + - adds `ctx.triggered_id`, the `id` of the component that triggered the callback. + - adds `ctx.args_grouping`, a dict of the inputs used with flexible callback signatures. - [#2009](https://github.com/plotly/dash/pull/2009) Add support for Promises within Client-side callbacks as requested in [#1364](https://github.com/plotly/dash/pull/1364). diff --git a/dash/__init__.py b/dash/__init__.py index 7e59891b8b..c90d6463c7 100644 --- a/dash/__init__.py +++ b/dash/__init__.py @@ -26,3 +26,4 @@ get_relative_path, strip_relative_path, ) +ctx = callback_context diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 99dc1d563d..20f0e887b6 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -1,9 +1,11 @@ import functools import warnings - +import json +from copy import deepcopy import flask from . import exceptions +from ._utils import stringify_id, AttributeDict def has_context(func): @@ -46,16 +48,158 @@ def states(self): @property @has_context def triggered(self): + """ + Returns a list of all the Input props that changed and caused the callback to execute. It is empty when the + callback is called on initial load, unless an Input prop got its value from another initial callback. + Callbacks triggered by user actions typically have one item in triggered, unless the same action changes + two props at once or the callback has several Input props that are all modified by another callback based on + a single user action. + + Example: To get the id of the component that triggered the callback: + `component_id = ctx.triggered[0]['prop_id'].split('.')[0]` + + Example: To detect initial call, empty triggered is not really empty, it's falsy so that you can use: + `if ctx.triggered:` + """ # For backward compatibility: previously `triggered` always had a # value - to avoid breaking existing apps, add a dummy item but # make the list still look falsy. So `if ctx.triggered` will make it # look empty, but you can still do `triggered[0]["prop_id"].split(".")` return getattr(flask.g, "triggered_inputs", []) or falsy_triggered + @property + @has_context + def triggered_prop_ids(self): + """ + Returns a dictionary of all the Input props that changed and caused the callback to execute. It is empty when + the callback is called on initial load, unless an Input prop got its value from another initial callback. + Callbacks triggered by user actions typically have one item in triggered, unless the same action changes + two props at once or the callback has several Input props that are all modified by another callback based + on a single user action. + + triggered_prop_ids (dict): + - keys (str) : the triggered "prop_id" composed of "component_id.component_property" + - values (str or dict): the id of the component that triggered the callback. Will be the dict id for pattern matching callbacks + + Example - regular callback + {"btn-1.n_clicks": "btn-1"} + + Example - pattern matching callbacks: + {'{"index":0,"type":"filter-dropdown"}.value': {"index":0,"type":"filter-dropdown"}} + + Example usage: + `if "btn-1.n_clicks" in ctx.triggered_prop_ids: + do_something()` + """ + triggered = getattr(flask.g, "triggered_inputs", []) + ids = AttributeDict({}) + for item in triggered: + component_id, _, _ = item["prop_id"].rpartition(".") + ids[item["prop_id"]] = component_id + if component_id.startswith("{"): + ids[item["prop_id"]] = AttributeDict(json.loads(component_id)) + return ids + + @property + @has_context + def triggered_id(self): + """ + Returns the component id (str or dict) of the Input component that triggered the callback. + + Note - use `triggered_prop_ids` if you need both the component id and the prop that triggered the callback or if + multiple Inputs triggered the callback. + + Example usage: + `if "btn-1" == ctx.triggered_id: + do_something()` + + """ + component_id = None + if self.triggered: + prop_id = self.triggered_prop_ids.first() + component_id = self.triggered_prop_ids[prop_id] + return component_id + @property @has_context def args_grouping(self): - return getattr(flask.g, "args_grouping", []) + """ + args_grouping is a dict of the inputs used with flexible callback signatures. The keys are the variable names + and the values are dictionaries containing: + - “id”: (string or dict) the component id. If it’s a pattern matching id, it will be a dict. + - “id_str”: (str) for pattern matching ids, it’s the strigified dict id with no white spaces. + - “property”: (str) The component property used in the callback. + - “value”: the value of the component property at the time the callback was fired. + - “triggered”: (bool)Whether this input triggered the callback. + + Example usage: + @app.callback( + Output("container", "children"), + inputs=dict(btn1=Input("btn-1", "n_clicks"), btn2=Input("btn-2", "n_clicks")), + ) + def display(btn1, btn2): + c = ctx.args_grouping + if c.btn1.triggered: + return f"Button 1 clicked {btn1} times" + elif c.btn2.triggered: + return f"Button 2 clicked {btn2} times" + else: + return "No clicks yet" + + """ + triggered = getattr(flask.g, "triggered_inputs", []) + triggered = [item["prop_id"] for item in triggered] + grouping = getattr(flask.g, "args_grouping", {}) + + def update_args_grouping(g): + if isinstance(g, dict) and "id" in g: + str_id = stringify_id(g["id"]) + prop_id = f"{str_id}.{g['property']}" + + new_values = { + "value": g.get("value"), + "str_id": str_id, + "triggered": prop_id in triggered, + "id": AttributeDict(g["id"]) + if isinstance(g["id"], dict) + else g["id"], + } + g.update(new_values) + + def recursive_update(g): + if isinstance(g, (tuple, list)): + for i in g: + update_args_grouping(i) + recursive_update(i) + if isinstance(g, dict): + for i in g.values(): + update_args_grouping(i) + recursive_update(i) + + recursive_update(grouping) + + return grouping + + # todo not sure whether we need this, but it removes a level of nesting so + # you don't need to use `.value` to get the value. + @property + @has_context + def args_grouping_values(self): + grouping = getattr(flask.g, "args_grouping", {}) + grouping = deepcopy(grouping) + + def recursive_update(g): + if isinstance(g, (tuple, list)): + for i in g: + recursive_update(i) + if isinstance(g, dict): + for k, v in g.items(): + if isinstance(v, dict) and "id" in v: + g[k] = v["value"] + recursive_update(v) + + recursive_update(grouping) + return grouping @property @has_context diff --git a/dash/_grouping.py b/dash/_grouping.py index 77f87fa30e..984932d2a3 100644 --- a/dash/_grouping.py +++ b/dash/_grouping.py @@ -14,6 +14,7 @@ """ from dash.exceptions import InvalidCallbackReturnValue +from ._utils import AttributeDict def flatten_grouping(grouping, schema=None): @@ -123,14 +124,14 @@ def map_grouping(fn, grouping): return [map_grouping(fn, g) for g in grouping] if isinstance(grouping, dict): - return {k: map_grouping(fn, g) for k, g in grouping.items()} + return AttributeDict({k: map_grouping(fn, g) for k, g in grouping.items()}) return fn(grouping) def make_grouping_by_key(schema, source, default=None): """ - Create a grouping from a schema by ujsing the schema's scalar values to look up + Create a grouping from a schema by using the schema's scalar values to look up items in the provided source object. :param schema: A grouping of potential keys in source diff --git a/dash/_utils.py b/dash/_utils.py index 5f4bc1f8b0..aa0470f43d 100644 --- a/dash/_utils.py +++ b/dash/_utils.py @@ -119,6 +119,8 @@ def first(self, *names): value = self.get(name) if value: return value + if not names: + return next(iter(self), {}) def create_callback_id(output): @@ -152,7 +154,7 @@ def stringify_id(id_): def inputs_to_dict(inputs_list): - inputs = {} + inputs = AttributeDict() for i in inputs_list: inputsi = i if isinstance(i, list) else [i] for ii in inputsi: @@ -161,6 +163,16 @@ def inputs_to_dict(inputs_list): return inputs +def convert_to_AttributeDict(nested_list): + new_dict = [] + for i in nested_list: + if isinstance(i, dict): + new_dict.append(AttributeDict(i)) + else: + new_dict.append([AttributeDict(ii) for ii in i]) + return new_dict + + def inputs_to_vals(inputs): return [ [ii.get("value") for ii in i] if isinstance(i, list) else i.get("value") diff --git a/dash/_validate.py b/dash/_validate.py index 19a52e9b45..06b78bb870 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -134,6 +134,11 @@ def validate_and_group_input_args(flat_args, arg_index_grouping): if isinstance(arg_index_grouping, dict): func_args = [] func_kwargs = args_grouping + for key in func_kwargs: + if not key.isidentifier(): + raise exceptions.CallbackException( + f"{key} is not a valid Python variable name" + ) elif isinstance(arg_index_grouping, (tuple, list)): func_args = list(args_grouping) func_kwargs = {} diff --git a/dash/dash.py b/dash/dash.py index 66bb58e4fc..b23367866f 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -49,6 +49,7 @@ patch_collections_abc, split_callback_id, to_json, + convert_to_AttributeDict, gen_salt, ) from . import _callback @@ -1297,6 +1298,7 @@ def callback(_triggers, user_store_data, user_callback_args): def dispatch(self): body = flask.request.get_json() + flask.g.inputs_list = inputs = body.get( # pylint: disable=assigning-non-slot "inputs", [] ) @@ -1331,9 +1333,12 @@ def dispatch(self): # Add args_grouping inputs_state_indices = cb["inputs_state_indices"] inputs_state = inputs + state + inputs_state = convert_to_AttributeDict(inputs_state) + args_grouping = map_grouping( lambda ind: inputs_state[ind], inputs_state_indices ) + flask.g.args_grouping = args_grouping # pylint: disable=assigning-non-slot flask.g.using_args_grouping = ( # pylint: disable=assigning-non-slot not isinstance(inputs_state_indices, int) diff --git a/tests/integration/callbacks/test_callback_context.py b/tests/integration/callbacks/test_callback_context.py index e19c5e63ac..1080687a2e 100644 --- a/tests/integration/callbacks/test_callback_context.py +++ b/tests/integration/callbacks/test_callback_context.py @@ -2,7 +2,7 @@ import operator import pytest -from dash import Dash, Input, Output, html, dcc, callback_context +from dash import Dash, ALL, Input, Output, html, dcc, callback_context, ctx from dash.exceptions import PreventUpdate, MissingCallbackContextException import dash.testing.wait as wait @@ -330,3 +330,59 @@ def update_results(n1, n2, nsum): assert len(keys1) == 2 assert "sum-number.value" in keys1 assert "input-number-2.value" in keys1 + + +def test_cbcx007_triggered_id(dash_duo): + app = Dash(__name__) + + btns = ["btn-{}".format(x) for x in range(1, 6)] + + app.layout = html.Div( + [html.Div([html.Button(btn, id=btn) for btn in btns]), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), [Input(x, "n_clicks") for x in btns]) + def on_click(*args): + if not ctx.triggered: + raise PreventUpdate + for btn in btns: + if btn in ctx.triggered_prop_ids.values(): + assert btn == ctx.triggered_id + return f"Just clicked {btn}" + + dash_duo.start_server(app) + + for i in range(1, 5): + for btn in btns: + dash_duo.find_element("#" + btn).click() + dash_duo.wait_for_text_to_equal("#output", f"Just clicked {btn}") + + +def test_cbcx008_triggered_id_pmc(dash_duo): + + app = Dash() + app.layout = html.Div( + [ + html.Button("Click me", id={"type": "btn", "index": "myindex"}), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), Input({"type": "btn", "index": ALL}, "n_clicks") + ) + def func(n_clicks): + if ctx.triggered: + triggered_id, dict_id = next(iter(ctx.triggered_prop_ids.items())) + + assert dict_id == ctx.triggered_id + + if dict_id == {"type": "btn", "index": "myindex"}: + return dict_id["index"] + + dash_duo.start_server(app) + + dash_duo.find_element( + '#\\{\\"index\\"\\:\\"myindex\\"\\,\\"type\\"\\:\\"btn\\"\\}' + ).click() + dash_duo.wait_for_text_to_equal("#output", "myindex") diff --git a/tests/integration/callbacks/test_wildcards.py b/tests/integration/callbacks/test_wildcards.py index 7359c55671..4f5a403e12 100644 --- a/tests/integration/callbacks/test_wildcards.py +++ b/tests/integration/callbacks/test_wildcards.py @@ -1,6 +1,7 @@ import pytest import re from selenium.webdriver.common.keys import Keys +import json from dash.testing import wait import dash @@ -10,6 +11,12 @@ from tests.assets.grouping_app import grouping_app +def stringify_id(id_): + if isinstance(id_, dict): + return json.dumps(id_, sort_keys=True, separators=(",", ":")) + return id_ + + def css_escape(s): sel = re.sub("[\\{\\}\\\"\\'.:,]", lambda m: "\\" + m.group(0), s) print(sel) @@ -413,14 +420,38 @@ def assert_callback_context(items_text): args_grouping = dict( items=dict( all=[ - {"id": {"item": i}, "property": "children", "value": text} + { + "id": {"item": i}, + "property": "children", + "value": text, + "str_id": stringify_id({"item": i}), + "triggered": False, + } for i, text in enumerate(items_text[:-1]) ], - new=dict(id="new-item", property="value", value=items_text[-1]), + new=dict( + id="new-item", + property="value", + value=items_text[-1], + str_id="new-item", + triggered=False, + ), ), triggers=[ - {"id": "add", "property": "n_clicks", "value": len(items_text)}, - {"id": "new-item", "property": "n_submit"}, + { + "id": "add", + "property": "n_clicks", + "value": len(items_text), + "str_id": "add", + "triggered": True, + }, + { + "id": "new-item", + "property": "n_submit", + "value": None, + "str_id": "new-item", + "triggered": False, + }, ], ) dash_duo.wait_for_text_to_equal("#cc-args-grouping", repr(args_grouping))