This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Doc for Batchnorm, fix argument order, cleanup some comments #2046

Closed
wants to merge 15 commits
2 changes: 1 addition & 1 deletion doc/sphinx/ngraph.doxyfile
@@ -1807,7 +1807,7 @@ SEARCH_INCLUDES = YES
# preprocessor.
# This tag requires that the tag SEARCH_INCLUDES is set to YES.

-INCLUDE_PATH =
+INCLUDE_PATH = ../../src

# You can use the INCLUDE_FILE_PATTERNS tag to specify one or more wildcard
# patterns (like *.h and *.hpp) to filter out the header-files in the
105 changes: 0 additions & 105 deletions doc/sphinx/source/ops/batch_norm.rst

This file was deleted.

80 changes: 80 additions & 0 deletions doc/sphinx/source/ops/batch_norm_inference.rst
@@ -0,0 +1,80 @@
.. batch_norm_inference.rst:

##################
BatchNormInference
##################

.. code-block:: cpp

BatchNormInference // Adjust input for mean and variance


Description
===========

Normalizes ``input`` per channel, using the supplied ``mean`` and
``variance``, then scales by ``gamma`` and shifts by ``beta``.

Inputs
------

+---------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+=====================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+---------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``mean`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``variance``        | same as ``input``       | :math:`(C)`                  |
+---------------------+-------------------------+------------------------------+


Attributes
----------

+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+

Outputs
-------

+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``normalized`` | same as ``gamma`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+

Mathematical Definition
=======================

The axes of the input fall into two categories: positional and channel, with
channel being axis 1. For each position, there are :math:`C` channel values,
each normalized independently.

Normalization of a channel sample is controlled by two values:

* the `mean` :math:`\mu`, and

* the `variance` :math:`\sigma^2`;

and by two scaling inputs: :math:`\gamma` and :math:`\beta`.

.. math::

\mathtt{normalized}_{\bullet, c, \ldots} = \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
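
A minimal NumPy sketch of this formula (an illustrative reference, not the
nGraph kernel; the helper name and ``eps`` default are assumptions):

.. code-block:: python

   import numpy as np

   def batch_norm_inference_reference(data, gamma, beta, mean, variance, eps=1e-5):
       """Normalize (N, C, ...) data per channel, as in the formula above."""
       # Reshape each (C,) vector so it broadcasts across channel axis 1.
       shape = (1, -1) + (1,) * (data.ndim - 2)
       m, v = mean.reshape(shape), variance.reshape(shape)
       g, b = gamma.reshape(shape), beta.reshape(shape)
       return (data - m) / np.sqrt(v + eps) * g + b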


C++ Interface
==============

.. doxygenclass:: ngraph::op::BatchNormInference
:project: ngraph
:members:


89 changes: 89 additions & 0 deletions doc/sphinx/source/ops/batch_norm_training.rst
@@ -0,0 +1,89 @@
.. batch_norm_training.rst:

#################
BatchNormTraining
#################

.. code-block:: cpp

BatchNormTraining // Compute mean and variance from the input.


Description
===========

Normalizes ``input`` per channel with a mean and variance computed from the
batch itself, then scales by ``gamma`` and shifts by ``beta``.

Inputs
------

+---------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+=====================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+---------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+


Attributes
----------

+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+

Outputs
-------

+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``normalized`` | same as ``gamma`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
| ``batch_mean`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
| ``batch_variance`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+

The ``batch_mean`` and ``batch_variance`` outputs are computed per-channel from
``input``.


Mathematical Definition
=======================

The axes of the input fall into two categories: positional and channel, with
channel being axis 1. For each position, there are :math:`C` channel values,
each normalized independently.

Normalization of a channel sample is controlled by two values:

* the `batch_mean` :math:`\mu`, and

* the `batch_variance` :math:`\sigma^2`;

and by two scaling inputs: :math:`\gamma` and :math:`\beta`.

The values for :math:`\mu` and :math:`\sigma^2` come from computing the
mean and variance of ``input``.

.. math::

\mu_c &= \mathop{\mathbb{E}}\left(\mathtt{input}_{\bullet, c, \ldots}\right)\\
\sigma^2_c &= \mathop{\mathtt{Var}}\left(\mathtt{input}_{\bullet, c, \ldots}\right)\\
\mathtt{normalized}_{\bullet, c, \ldots} &= \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
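
A minimal NumPy sketch of these formulas (an illustrative reference, not the
nGraph kernel; the helper name and ``eps`` default are assumptions):

.. code-block:: python

   import numpy as np

   def batch_norm_training_reference(data, gamma, beta, eps=1e-5):
       """Return (normalized, batch_mean, batch_variance) for (N, C, ...) data."""
       # Reduce over every axis except the channel axis (axis 1).
       reduce_axes = (0,) + tuple(range(2, data.ndim))
       batch_mean = data.mean(axis=reduce_axes)
       batch_variance = data.var(axis=reduce_axes)
       shape = (1, -1) + (1,) * (data.ndim - 2)
       normalized = ((data - batch_mean.reshape(shape))
                     / np.sqrt(batch_variance.reshape(shape) + eps)
                     * gamma.reshape(shape) + beta.reshape(shape))
       return normalized, batch_mean, batch_variance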


C++ Interface
==============

.. doxygenclass:: ngraph::op::BatchNormTraining
:project: ngraph
:members:


71 changes: 71 additions & 0 deletions doc/sphinx/source/ops/batch_norm_training_backprop.rst
@@ -0,0 +1,71 @@
.. batch_norm_training_backprop.rst:

#########################
BatchNormTrainingBackprop
#########################

.. code-block:: cpp

BatchNormTrainingBackprop // Backprop through BatchNormTraining: gradients for input, gamma, and beta.


Description
===========

Computes the gradients of :doc:`batch_norm_training` with respect to
``input``, ``gamma``, and ``beta``, given ``normalized_delta``, the gradient
with respect to the ``normalized`` output.

Inputs
------

+----------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+======================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+----------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``mean`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``variance`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``normalized_delta`` | same as ``input`` | same as ``input`` |
+----------------------+-------------------------+------------------------------+


Attributes
----------

+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+

Outputs
-------

+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``input_delta`` | same as ``input`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
| ``gamma_delta`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
| ``beta_delta`` | same as ``beta`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+


Mathematical Definition
=======================
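
The gradients below follow the standard batch normalization derivation and
are a sketch rather than a transcription of the nGraph kernels. Here
:math:`\hat{x}` denotes the normalized input and :math:`M` the number of
values per channel; the sums run over all positions of channel :math:`c`.

.. math::

   \hat{x}_{\bullet, c, \ldots} &= \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\\
   \mathtt{beta\_delta}_c &= \sum \mathtt{normalized\_delta}_{\bullet, c, \ldots}\\
   \mathtt{gamma\_delta}_c &= \sum \mathtt{normalized\_delta}_{\bullet, c, \ldots}\,\hat{x}_{\bullet, c, \ldots}\\
   \mathtt{input\_delta}_{\bullet, c, \ldots} &= \frac{\gamma_c}{\sqrt{\sigma^2_c+\epsilon}}\left(\mathtt{normalized\_delta}_{\bullet, c, \ldots}-\frac{\mathtt{beta\_delta}_c}{M}-\hat{x}_{\bullet, c, \ldots}\frac{\mathtt{gamma\_delta}_c}{M}\right)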


C++ Interface
==============

.. doxygenclass:: ngraph::op::BatchNormTrainingBackprop
:project: ngraph
:members:


8 changes: 6 additions & 2 deletions doc/sphinx/source/ops/index.rst
@@ -56,7 +56,9 @@ Not currently a comprehensive list.
* :doc:`atan`
* :doc:`avg_pool`
* :doc:`avg_pool_backprop`
-* :doc:`batch_norm`
+* :doc:`batch_norm_inference`
+* :doc:`batch_norm_training`
+* :doc:`batch_norm_training_backprop`
* :doc:`broadcast`
* :doc:`ceiling`
* :doc:`concat`
@@ -123,7 +125,9 @@ Not currently a comprehensive list.
atan.rst
avg_pool.rst
avg_pool_backprop.rst
-batch_norm.rst
+batch_norm_inference.rst
+batch_norm_training.rst
+batch_norm_training_backprop.rst
broadcast.rst
ceiling.rst
concat.rst
4 changes: 2 additions & 2 deletions python/ngraph/ops.py
@@ -924,9 +924,9 @@ def batch_norm(eps, # type: float
# type: (...) -> Node
"""Return batch normalization node."""
if mean is None and variance is None:
-return BatchNormTraining(eps, gamma, beta, data)
+return BatchNormTraining(data, gamma, beta, eps)
else:
-return BatchNormInference(eps, gamma, beta, data, mean, variance)
+return BatchNormInference(data, gamma, beta, mean, variance, eps)


@nameable_op
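
Taken together, the corrected argument order can be exercised through the
wrapper. A short usage sketch, assuming the wrapper's parameter order is
``eps, gamma, beta, data, mean, variance`` and an ``ng.parameter`` helper
with illustrative shapes and dtypes (both assumptions, not shown in this
diff):

.. code-block:: python

   import numpy as np
   import ngraph as ng

   # Hypothetical placeholders for an NCHW tensor and its per-channel values.
   data = ng.parameter([16, 8, 32, 32], name='data', dtype=np.float32)
   gamma = ng.parameter([8], name='gamma', dtype=np.float32)
   beta = ng.parameter([8], name='beta', dtype=np.float32)

   # Training mode: no mean/variance given, so BatchNormTraining is built
   # and the statistics are computed from the batch itself.
   norm_train = ng.batch_norm(1e-5, gamma, beta, data)

   # Inference mode: supplying precomputed statistics selects BatchNormInference.
   mean = ng.parameter([8], name='mean', dtype=np.float32)
   variance = ng.parameter([8], name='variance', dtype=np.float32)
   norm_infer = ng.batch_norm(1e-5, gamma, beta, data, mean, variance)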