@@ -112,11 +112,51 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
 }
 
 /* *
- * \brief {o_0, o_1} = calc(i_0)
+ * \brief Normalization across feature maps.
  *
- * \param inputs[0] input value.
- * \param outputs[0] output value.
- * \param outputs[1] denoms.
+ * This Function comes from the paper
+ * "ImageNet Classification with Deep Convolutional Neural Networks".
+ *
+ * The original formula is:
+ *
+ *                                    Input(i, x, y)
+ *   Output(i, x, y) = ----------------------------------------------
+ *                                    -- upper
+ *                      (k + alpha * >      (Input(j, x, y))^2) ^ (beta)
+ *                                    -- j = lower
+ *
+ * where upper is `min(C, i + N/2)` and lower is `max(0, i - N/2)`.
+ *
+ * Function implementation:
+ *
+ * inputs and outputs are in NCHW format, and input.shape.ndims() equals 4.
+ * The meaning of dimensions 0-3 is, respectively, batch size,
+ * feature maps, rows, and columns.
+ *
+ * Input and Output in the formula above refer to one map (index i) of one
+ * image; Input(i, x, y) and Output(i, x, y) each denote a single element
+ * of that map.
+ *
+ * C is the number of feature maps of one image, and N is a hyper-parameter
+ * configured when the Function is initialized. The sum in the denominator
+ * runs over the same position (x, y) in the neighboring maps.
+ *
+ * In this implementation k is fixed to 1, so the Function takes no
+ * argument for k. A naive reference sketch of the whole computation is
+ * given at the end of this comment.
+ *
+ * Function Arguments:
+ *
+ * \param size_ represents N
+ * \param scale_ represents alpha
+ * \param pow_ represents beta
+ * \param inputs[0] represents Input
+ * \param outputs[0] represents Output
+ * \param outputs[1] represents the denominator in the formula
+ *                   (before it is raised to the power beta)
+ *
+ * Note:
+ *   outputs[1] is saved to simplify the backward calculation.
+ *   TODO: if only the forward calculation is needed, we can optimize by
+ *   removing outputs[1].
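+ *
+ * For illustration only, a naive single-image reference of the forward
+ * computation could look like the sketch below. It is a hypothetical helper,
+ * not part of this Function: it assumes contiguous float data of shape
+ * [C][H][W] and follows the boundary convention of the formula above
+ * (the exact boundary handling of the real kernels may differ).
+ *
+ * \code
+ * #include <algorithm>
+ * #include <cmath>
+ *
+ * void crossMapNormalRef(const float* in, float* out, float* denom,
+ *                        int C, int H, int W, int N,
+ *                        float alpha, float beta) {
+ *   for (int i = 0; i < C; ++i) {
+ *     int lower = std::max(0, i - N / 2);
+ *     int upper = std::min(C, i + N / 2);  // treated as exclusive here
+ *     for (int y = 0; y < H; ++y) {
+ *       for (int x = 0; x < W; ++x) {
+ *         // Sum of squares over the same position in the neighboring maps.
+ *         float sum = 0.0f;
+ *         for (int j = lower; j < upper; ++j) {
+ *           float v = in[(j * H + y) * W + x];
+ *           sum += v * v;
+ *         }
+ *         int idx = (i * H + y) * W + x;
+ *         float d = 1.0f + alpha * sum;  // k is fixed to 1
+ *         denom[idx] = d;                // corresponds to outputs[1]
+ *         out[idx] = in[idx] * std::pow(d, -beta);
+ *       }
+ *     }
+ *   }
+ * }
+ * \endcode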
  */
 template <DeviceType Device>
 class CrossMapNormalFunc : public FunctionBase {
@@ -161,13 +201,36 @@ class CrossMapNormalFunc : public FunctionBase {
 };
 
 /* *
- * \brief {o_0} = calc(i_0, i_1, i_2, i_3)
+ * \brief Backward calculation for normalization across feature maps.
+ *
+ * Function implementation:
+ *
+ * The implementation of this Function is derived from the
+ * CrossMapNormalFunc implementation:
+ *
+ *   InputGrad = OutputGrad * denoms ^ (-beta)
+ *       -- upper
+ *     + >     (OutputGrad * OutputValue * (-2 * alpha * beta) / denoms) * InputValue
+ *       -- j = lower
+ *
+ * The inputs/outputs data format is the same as in the forward interface:
+ * NCHW.
+ *
+ * upper and lower are the same as in the forward pass, and the sum follows
+ * the same logic. A naive reference sketch is given at the end of this
+ * comment.
+ *
+ * Function Arguments:
  *
- * \param inputs[0] input value.
- * \param inputs[1] output value.
- * \param inputs[2] output grad.
- * \param inputs[3] denoms.
- * \param outputs[0] input grad.
+ * \param size_ represents N
+ * \param scale_ represents alpha
+ * \param pow_ represents beta
+ * \param inputs[0] represents InputValue, inputs[0] of CrossMapNormalFunc
+ * \param inputs[1] represents OutputValue, outputs[0] of CrossMapNormalFunc
+ * \param inputs[2] represents OutputGrad
+ * \param inputs[3] represents denoms, outputs[1] of CrossMapNormalFunc;
+ *                  this is the intermediate result preserved in the
+ *                  forward calculation
+ * \param outputs[0] represents InputGrad
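+ *
+ * For illustration only, a naive single-image reference of this gradient
+ * could look like the sketch below. As with the forward sketch, it is a
+ * hypothetical helper that assumes contiguous float data of shape [C][H][W]
+ * and follows the formula above, not the actual kernel.
+ *
+ * \code
+ * #include <algorithm>
+ * #include <cmath>
+ *
+ * void crossMapNormalGradRef(const float* in, const float* outVal,
+ *                            const float* outGrad, const float* denom,
+ *                            float* inGrad, int C, int H, int W, int N,
+ *                            float alpha, float beta) {
+ *   for (int i = 0; i < C; ++i) {
+ *     int lower = std::max(0, i - N / 2);
+ *     int upper = std::min(C, i + N / 2);  // same window as the forward sketch
+ *     for (int y = 0; y < H; ++y) {
+ *       for (int x = 0; x < W; ++x) {
+ *         int idx = (i * H + y) * W + x;
+ *         // First term: OutputGrad * denoms ^ (-beta).
+ *         float grad = outGrad[idx] * std::pow(denom[idx], -beta);
+ *         // Second term: sum over the neighboring maps, scaled by InputValue.
+ *         float sum = 0.0f;
+ *         for (int j = lower; j < upper; ++j) {
+ *           int jdx = (j * H + y) * W + x;
+ *           sum += outGrad[jdx] * outVal[jdx] *
+ *                  (-2.0f * alpha * beta) / denom[jdx];
+ *         }
+ *         inGrad[idx] = grad + sum * in[idx];
+ *       }
+ *     }
+ *   }
+ * }
+ * \endcode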
  */
 template <DeviceType Device>
 class CrossMapNormalGradFunc : public FunctionBase {