diff --git a/tensorboard/backend/application.py b/tensorboard/backend/application.py index 37950020c2..a9aa08c280 100644 --- a/tensorboard/backend/application.py +++ b/tensorboard/backend/application.py @@ -97,17 +97,20 @@ logger = tb_logging.get_logger() -def tensor_size_guidance_from_flags(flags): - """Apply user per-summary size guidance overrides.""" - - tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE) +def _parse_samples_per_plugin(flags): + result = {} if not flags or not flags.samples_per_plugin: - return tensor_size_guidance - + return result for token in flags.samples_per_plugin.split(","): k, v = token.strip().split("=") - tensor_size_guidance[k] = int(v) + result[k] = int(v) + return result + +def _apply_tensor_size_guidance(sampling_hints): + """Apply user per-summary size guidance overrides.""" + tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE) + tensor_size_guidance.update(sampling_hints) return tensor_size_guidance @@ -151,9 +154,10 @@ def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider): multiplexer = _DbModeMultiplexer(flags.db, db_connection_provider) else: # Regular logdir loading mode. + sampling_hints = _parse_samples_per_plugin(flags) multiplexer = event_multiplexer.EventMultiplexer( size_guidance=DEFAULT_SIZE_GUIDANCE, - tensor_size_guidance=tensor_size_guidance_from_flags(flags), + tensor_size_guidance=_apply_tensor_size_guidance(sampling_hints), purge_orphaned_data=flags.purge_orphaned_data, max_reload_threads=flags.max_reload_threads, event_file_active_filter=_get_event_file_active_filter(flags), @@ -238,6 +242,7 @@ def TensorBoardWSGIApp( multiplexer=deprecated_multiplexer, assets_zip_provider=assets_zip_provider, plugin_name_to_instance=plugin_name_to_instance, + sampling_hints=_parse_samples_per_plugin(flags), window_title=flags.window_title, ) tbplugins = [] diff --git a/tensorboard/plugins/base_plugin.py b/tensorboard/plugins/base_plugin.py index 339ca9cb60..51134052ab 100644 --- a/tensorboard/plugins/base_plugin.py +++ b/tensorboard/plugins/base_plugin.py @@ -254,6 +254,7 @@ def __init__( logdir=None, multiplexer=None, plugin_name_to_instance=None, + sampling_hints=None, window_title=None, ): """Instantiates magic container. @@ -291,6 +292,10 @@ def __init__( plugin may be absent from this mapping until it is registered. Plugin logic should handle cases in which a plugin is absent from this mapping, lest a KeyError is raised. + sampling_hints: Map from plugin name to `int` or `NoneType`, where + the value represents the user-specified downsampling limit as + given to the `--samples_per_plugin` flag, or `None` if none was + explicitly given for this plugin. window_title: A string specifying the window title. """ self.assets_zip_provider = assets_zip_provider @@ -301,6 +306,7 @@ def __init__( self.logdir = logdir self.multiplexer = multiplexer self.plugin_name_to_instance = plugin_name_to_instance + self.sampling_hints = sampling_hints self.window_title = window_title diff --git a/tensorboard/plugins/histogram/histograms_plugin.py b/tensorboard/plugins/histogram/histograms_plugin.py index b4e48e1fca..0890084529 100644 --- a/tensorboard/plugins/histogram/histograms_plugin.py +++ b/tensorboard/plugins/histogram/histograms_plugin.py @@ -39,6 +39,9 @@ from tensorboard.util import tensor_util +_DEFAULT_DOWNSAMPLING = 500 # histograms per time series + + class HistogramsPlugin(base_plugin.TBPlugin): """Histograms Plugin for TensorBoard. @@ -62,6 +65,9 @@ def __init__(self, context): """ self._multiplexer = context.multiplexer self._db_connection_provider = context.db_connection_provider + self._downsample_to = (context.sampling_hints or {}).get( + self.plugin_name, _DEFAULT_DOWNSAMPLING + ) if context.flags and context.flags.generic_data == "true": self._data_provider = context.data_provider else: @@ -174,20 +180,21 @@ def histograms_impl(self, tag, run, experiment, downsample_to=None): """Result of the form `(body, mime_type)`. At most `downsample_to` events will be returned. If this value is - `None`, then no downsampling will be performed. + `None`, then default downsampling will be performed. Raises: tensorboard.errors.PublicError: On invalid request. """ if self._data_provider: - # Downsample reads to 500 histograms per time series, which is - # the default size guidance for histograms under the multiplexer - # loading logic. - SAMPLE_COUNT = downsample_to if downsample_to is not None else 500 + sample_count = ( + downsample_to + if downsample_to is not None + else self._downsample_to + ) all_histograms = self._data_provider.read_tensors( experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME, - downsample=SAMPLE_COUNT, + downsample=sample_count, run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), ) histograms = all_histograms.get(run, {}).get(tag, None) diff --git a/tensorboard/plugins/image/images_plugin.py b/tensorboard/plugins/image/images_plugin.py index 7458f4ded0..bc06ec4864 100644 --- a/tensorboard/plugins/image/images_plugin.py +++ b/tensorboard/plugins/image/images_plugin.py @@ -43,6 +43,7 @@ } _DEFAULT_IMAGE_MIMETYPE = "application/octet-stream" +_DEFAULT_DOWNSAMPLING = 10 # images per time series # Extend imghdr.tests to include svg. @@ -69,6 +70,9 @@ def __init__(self, context): """ self._multiplexer = context.multiplexer self._db_connection_provider = context.db_connection_provider + self._downsample_to = (context.sampling_hints or {}).get( + self.plugin_name, _DEFAULT_DOWNSAMPLING + ) if context.flags and context.flags.generic_data == "true": self._data_provider = context.data_provider else: @@ -239,14 +243,10 @@ def _image_response_for_run(self, experiment, run, tag, sample): parameters. """ if self._data_provider: - # Downsample reads to 10 images per time series, which is the - # default size guidance for images under the multiplexer loading - # logic. - SAMPLE_COUNT = 10 all_images = self._data_provider.read_blob_sequences( experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME, - downsample=SAMPLE_COUNT, + downsample=self._downsample_to, run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), ) images = all_images.get(run, {}).get(tag, None) diff --git a/tensorboard/plugins/scalar/scalars_plugin.py b/tensorboard/plugins/scalar/scalars_plugin.py index 8a1faf4839..59ab6d10cc 100644 --- a/tensorboard/plugins/scalar/scalars_plugin.py +++ b/tensorboard/plugins/scalar/scalars_plugin.py @@ -40,6 +40,9 @@ from tensorboard.util import tensor_util +_DEFAULT_DOWNSAMPLING = 1000 # scalars per time series + + class OutputFormat(object): """An enum used to list the valid output formats for API calls.""" @@ -60,6 +63,9 @@ def __init__(self, context): """ self._multiplexer = context.multiplexer self._db_connection_provider = context.db_connection_provider + self._downsample_to = (context.sampling_hints or {}).get( + self.plugin_name, _DEFAULT_DOWNSAMPLING + ) if context.flags and context.flags.generic_data != "false": self._data_provider = context.data_provider else: @@ -169,14 +175,10 @@ def index_impl(self, experiment=None): def scalars_impl(self, tag, run, experiment, output_format): """Result of the form `(body, mime_type)`.""" if self._data_provider: - # Downsample reads to 1000 scalars per time series, which is the - # default size guidance for scalars under the multiplexer loading - # logic. - SAMPLE_COUNT = 1000 all_scalars = self._data_provider.read_scalars( experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME, - downsample=SAMPLE_COUNT, + downsample=self._downsample_to, run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), ) scalars = all_scalars.get(run, {}).get(tag, None) diff --git a/tensorboard/plugins/text/text_plugin.py b/tensorboard/plugins/text/text_plugin.py index 248d76741e..6ff564973b 100644 --- a/tensorboard/plugins/text/text_plugin.py +++ b/tensorboard/plugins/text/text_plugin.py @@ -48,6 +48,8 @@ 2d tables are supported. Showing a 2d slice of the data instead.""" ) +_DEFAULT_DOWNSAMPLING = 100 # text tensors per time series + def make_table_row(contents, tag="td"): """Given an iterable of string contents, make a table row. @@ -212,6 +214,9 @@ def __init__(self, context): context: A base_plugin.TBContext instance. """ self._multiplexer = context.multiplexer + self._downsample_to = (context.sampling_hints or {}).get( + self.plugin_name, _DEFAULT_DOWNSAMPLING + ) if context.flags and context.flags.generic_data == "true": self._data_provider = context.data_provider else: @@ -261,7 +266,7 @@ def text_impl(self, run, tag, experiment): all_text = self._data_provider.read_tensors( experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME, - downsample=100, + downsample=self._downsample_to, run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), ) text = all_text.get(run, {}).get(tag, None)