The batch normalization primitive performs a forward or backward batch normalization operation on 0D, 2D, or 3D spatial data.
The batch normalization operation is defined by the following formulas. We show formulas only for 2D spatial data; they generalize straightforwardly to cases of higher and lower dimensions. Variable names follow the standard Naming Conventions.
\[ \dst(n, c, h, w) = \gamma(c) \cdot \frac{\src(n, c, h, w) - \mu(c)} {\sqrt{\sigma^2(c) + \varepsilon}} + \beta(c), \]
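where

- \(\gamma(c), \beta(c)\) are the optional scale and shift for a channel (see the #dnnl_use_scaleshift flag),
- \(\mu(c), \sigma^2(c)\) are the mean and variance for a channel, either computed at runtime or provided by a user (see the #dnnl_use_global_stats flag), and
- \(\varepsilon\) is a constant that improves numerical stability.

When mean and variance are computed at runtime, the following formulas are used:

\[ \mu(c) = \frac{1}{NHW} \sum\limits_{nhw} \src(n, c, h, w), \\ \sigma^2(c) = \frac{1}{NHW} \sum\limits_{nhw} \left(\src(n, c, h, w) - \mu(c)\right)^2. \]

In training mode, the primitive also optionally supports fusion with ReLU activation with zero negative slope applied to the result (see the #dnnl_fuse_norm_relu flag).

The primitive computes the population mean and variance, not the unbiased sample versions that are typically used for running statistics. Using the mean and variance computed by the primitive, the running mean and variance \(\hat\mu\) and \(\hat\sigma^2\) can be computed as follows, where \(\alpha\) is a factor chosen by the user:

\[ \hat\mu := \alpha \cdot \hat\mu + (1 - \alpha) \cdot \mu, \\ \hat\sigma^2 := \alpha \cdot \hat\sigma^2 + (1 - \alpha) \cdot \sigma^2. \]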
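The forward computation and the running-statistics update can be illustrated with a plain C++ reference sketch. It assumes a dense row-major NCHW buffer (the dnnl_nchw layout) and treats empty \(\gamma\)/\(\beta\) vectors as running without scale and shift. The names `BnStats`, `bn_forward_training`, and `update_running_stats` are hypothetical and unrelated to the library's API; this is not how the optimized implementations compute the result.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Per-channel statistics produced by the forward training pass (hypothetical type).
struct BnStats {
    std::vector<float> mean;      // mu(c)
    std::vector<float> variance;  // sigma^2(c): population (biased) variance
};

// Reference batch normalization forward over a dense NCHW buffer.
// Empty gamma/beta behave like gamma(c) = 1 and beta(c) = 0.
BnStats bn_forward_training(const std::vector<float>& src, std::vector<float>& dst,
                            std::size_t N, std::size_t C, std::size_t H, std::size_t W,
                            const std::vector<float>& gamma,
                            const std::vector<float>& beta, float eps) {
    const std::size_t sp = H * W;
    const float count = static_cast<float>(N * sp);  // NHW elements per channel
    BnStats s{std::vector<float>(C, 0.f), std::vector<float>(C, 0.f)};
    dst.assign(src.size(), 0.f);
    for (std::size_t c = 0; c < C; ++c) {
        // mu(c): average over n, h, w.
        float sum = 0.f;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < sp; ++i) sum += src[(n * C + c) * sp + i];
        const float mu = sum / count;
        // sigma^2(c): squared deviations from mu(c), divided by NHW.
        float sq = 0.f;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < sp; ++i) {
                const float d = src[(n * C + c) * sp + i] - mu;
                sq += d * d;
            }
        const float var = sq / count;
        const float g = gamma.empty() ? 1.f : gamma[c];
        const float b = beta.empty() ? 0.f : beta[c];
        const float inv_std = 1.f / std::sqrt(var + eps);
        // dst = gamma * (src - mu) / sqrt(var + eps) + beta
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < sp; ++i) {
                const std::size_t idx = (n * C + c) * sp + i;
                dst[idx] = g * (src[idx] - mu) * inv_std + b;
            }
        s.mean[c] = mu;
        s.variance[c] = var;
    }
    return s;
}

// Running statistics with a user-chosen factor alpha:
// hat_mu := alpha * hat_mu + (1 - alpha) * mu, and likewise for the variance.
void update_running_stats(std::vector<float>& run_mean, std::vector<float>& run_var,
                          const BnStats& s, float alpha) {
    for (std::size_t c = 0; c < run_mean.size(); ++c) {
        run_mean[c] = alpha * run_mean[c] + (1.f - alpha) * s.mean[c];
        run_var[c] = alpha * run_var[c] + (1.f - alpha) * s.variance[c];
    }
}
```

Because the squared deviations are divided by \(NHW\), the sketch yields the population variance; dividing by \(NHW - 1\) instead would give the unbiased estimate.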
The backward propagation computes \(\diffsrc(n, c, h, w)\), \(\diffgamma(c)^*\), and \(\diffbeta(c)^*\) based on \(\diffdst(n, c, h, w)\), \(\src(n, c, h, w)\), \(\mu(c)\), \(\sigma^2(c)\), \(\gamma(c)^*\), and \(\beta(c)^*\).

The tensors marked with an asterisk are used only when the primitive is configured to use \(\gamma(c)\) and \(\beta(c)\) (i.e., #dnnl_use_scaleshift is set).
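For reference, with \(\hat{x} = (\src - \mu(c)) / \sqrt{\sigma^2(c) + \varepsilon}\), the textbook gradients are \(\diffbeta(c) = \sum_{nhw} \diffdst\), \(\diffgamma(c) = \sum_{nhw} \diffdst \cdot \hat{x}\), and \(\diffsrc = \frac{\gamma(c)}{\sqrt{\sigma^2(c) + \varepsilon}} \left(\diffdst - \frac{\diffbeta(c)}{NHW} - \hat{x} \cdot \frac{\diffgamma(c)}{NHW}\right)\). The C++ sketch below implements these formulas under two assumptions: a dense NCHW layout, and statistics computed from the same batch (#dnnl_use_global_stats not set). The function `bn_backward` is hypothetical and only illustrates the math, not the library's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference backward pass over a dense NCHW buffer, assuming mean/variance were
// computed from this batch (dnnl_use_global_stats not set). Empty gamma means
// gamma(c) = 1; diff_gamma/diff_beta are still produced for illustration.
void bn_backward(const std::vector<float>& src, const std::vector<float>& diff_dst,
                 const std::vector<float>& mean, const std::vector<float>& variance,
                 const std::vector<float>& gamma, float eps,
                 std::size_t N, std::size_t C, std::size_t H, std::size_t W,
                 std::vector<float>& diff_src, std::vector<float>& diff_gamma,
                 std::vector<float>& diff_beta) {
    const std::size_t sp = H * W;
    const float count = static_cast<float>(N * sp);  // NHW
    diff_src.assign(src.size(), 0.f);
    diff_gamma.assign(C, 0.f);
    diff_beta.assign(C, 0.f);
    for (std::size_t c = 0; c < C; ++c) {
        const float inv_std = 1.f / std::sqrt(variance[c] + eps);
        // Pass 1: diff_beta = sum(diff_dst), diff_gamma = sum(diff_dst * x_hat).
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < sp; ++i) {
                const std::size_t idx = (n * C + c) * sp + i;
                const float x_hat = (src[idx] - mean[c]) * inv_std;
                diff_beta[c] += diff_dst[idx];
                diff_gamma[c] += diff_dst[idx] * x_hat;
            }
        const float g = gamma.empty() ? 1.f : gamma[c];
        // Pass 2: remove the gradient components that flow through mu and sigma^2.
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < sp; ++i) {
                const std::size_t idx = (n * C + c) * sp + i;
                const float x_hat = (src[idx] - mean[c]) * inv_std;
                diff_src[idx] = g * inv_std *
                                (diff_dst[idx] - diff_beta[c] / count -
                                 x_hat * diff_gamma[c] / count);
            }
    }
}
```

When #dnnl_use_global_stats is set, \(\mu\) and \(\sigma^2\) do not depend on \(\src\), so the two subtracted terms drop out and \(\diffsrc\) reduces to \(\gamma(c) \cdot \diffdst / \sqrt{\sigma^2(c) + \varepsilon}\).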
@section autotoc_md107 Execution Arguments
Depending on the @ref dnnl_normalization_flags_t "flags" and
@ref dnnl_prop_kind_t "propagation kind", the batch normalization primitive
requires different inputs and outputs. For clarity, a summary is shown below.
<table class="markdownTable">
<tr class="markdownTableHead"> <th class="markdownTableHeadLeft"> </th> <th class="markdownTableHeadLeft"> #dnnl_forward_inference </th> <th class="markdownTableHeadLeft"> #dnnl_forward_training </th> <th class="markdownTableHeadLeft"> #dnnl_backward </th> <th class="markdownTableHeadLeft"> #dnnl_backward_data </th> </tr>
<tr class="markdownTableRowOdd"> <td class="markdownTableBodyLeft"> #dnnl_normalization_flags_none </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\src\)<br><em>Outputs</em>: \(\dst\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\src\)<br><em>Outputs</em>: \(\dst\), \(\mu\), \(\sigma^2\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\)<br><em>Outputs</em>: \(\diffsrc\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\)<br><em>Outputs</em>: \(\diffsrc\) </td> </tr>
<tr class="markdownTableRowEven"> <td class="markdownTableBodyLeft"> #dnnl_use_global_stats </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\src\), \(\mu\), \(\sigma^2\)<br><em>Outputs</em>: \(\dst\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\src\), \(\mu\), \(\sigma^2\)<br><em>Outputs</em>: \(\dst\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\)<br><em>Outputs</em>: \(\diffsrc\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\)<br><em>Outputs</em>: \(\diffsrc\) </td> </tr>
<tr class="markdownTableRowOdd"> <td class="markdownTableBodyLeft"> #dnnl_use_scaleshift </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\src\), \(\gamma\), \(\beta\)<br><em>Outputs</em>: \(\dst\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\src\), \(\gamma\), \(\beta\)<br><em>Outputs</em>: \(\dst\), \(\mu\), \(\sigma^2\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\)<br><em>Outputs</em>: \(\diffsrc\), \(\diffgamma\), \(\diffbeta\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\)<br><em>Outputs</em>: \(\diffsrc\) </td> </tr>
<tr class="markdownTableRowEven"> <td class="markdownTableBodyLeft"> #dnnl_use_global_stats | #dnnl_use_scaleshift </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\)<br><em>Outputs</em>: \(\dst\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\)<br><em>Outputs</em>: \(\dst\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\)<br><em>Outputs</em>: \(\diffsrc\), \(\diffgamma\), \(\diffbeta\) </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\)<br><em>Outputs</em>: \(\diffsrc\) </td> </tr>
<tr class="markdownTableRowOdd"> <td class="markdownTableBodyLeft"> <tt>flags</tt> | #dnnl_fuse_norm_relu </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: same as with <tt>flags</tt><br><em>Outputs</em>: same as with <tt>flags</tt> </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: same as with <tt>flags</tt><br><em>Outputs</em>: same as with <tt>flags</tt>, Workspace </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: same as with <tt>flags</tt>, Workspace<br><em>Outputs</em>: same as with <tt>flags</tt> </td> <td class="markdownTableBodyLeft"> <em>Inputs</em>: same as with <tt>flags</tt>, Workspace<br><em>Outputs</em>: same as with <tt>flags</tt> </td> </tr>
</table>

Here, <tt>flags</tt> stands for any of the flag combinations listed in the rows above.

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

| Primitive input/output | Execution argument index |
|---|---|
| \(\src\) | DNNL_ARG_SRC |
| \(\gamma, \beta\) | DNNL_ARG_SCALE_SHIFT |
| mean (\(\mu\)) | DNNL_ARG_MEAN |
| variance (\(\sigma^2\)) | DNNL_ARG_VARIANCE |
| \(\dst\) | DNNL_ARG_DST |
| workspace | DNNL_ARG_WORKSPACE |
| \(\diffdst\) | DNNL_ARG_DIFF_DST |
| \(\diffsrc\) | DNNL_ARG_DIFF_SRC |
| \(\diffgamma, \diffbeta\) | DNNL_ARG_DIFF_SCALE_SHIFT |

@section autotoc_md108 Implementation Details
@subsection autotoc_md109 General Notes
1. The different flavors of the primitive are partially controlled by the @p flags parameter that is passed to the operation descriptor initialization function (e.g., dnnl::batch_normalization_forward::desc::desc()). Multiple flags can be set using the bitwise OR operator (<tt>|</tt>).
2. For forward propagation, the mean and variance may be either computed at runtime (in which case they are outputs of the primitive) or provided by a user (in which case they are inputs). In the latter case, a user must set the #dnnl_use_global_stats flag. For backward propagation, the mean and variance are always input parameters.
3. The memory format and data type for <tt>src</tt> and <tt>dst</tt> are assumed to be the same, and in the API they are typically referred to as <tt>data</tt> (e.g., see <tt>data_desc</tt> in dnnl::batch_normalization_forward::desc::desc()). The same is true for <tt>diff_src</tt> and <tt>diff_dst</tt>, whose shared memory descriptor is referred to as <tt>diff_data_desc</tt>.
4. Both forward and backward propagation support in-place operations, meaning that <tt>src</tt> can be used as both input and output for forward propagation, and <tt>diff_dst</tt> can be used as both input and output for backward propagation. In case of an in-place operation, the original data will be overwritten.
5. The batch normalization primitive can be fused with ReLU activation (#dnnl_fuse_norm_relu) even in training mode. In this case, on forward propagation the primitive has one additional output, <tt>workspace</tt>, which must be passed as an input during backward propagation.
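The inputs and outputs summarized in the table above can be encoded as a small lookup function, which also shows how flags combine with the bitwise OR operator. The enum `BnFlags`, its values, `Prop`, and `required_args` are hypothetical stand-ins defined only for this sketch; they do not reproduce the actual numeric values of #dnnl_normalization_flags_t, and the returned strings merely echo the DNNL_ARG_* suffixes.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical flag bits for this sketch only (not the real dnnl values).
enum BnFlags : unsigned {
    kFlagsNone = 0u,
    kUseGlobalStats = 1u << 0,
    kUseScaleShift = 1u << 1,
    kFuseNormRelu = 1u << 2,
};

enum class Prop { kForwardInference, kForwardTraining, kBackward, kBackwardData };

struct ArgLists {
    std::vector<std::string> inputs;
    std::vector<std::string> outputs;
};

// Encodes the inputs/outputs table; names echo the DNNL_ARG_* suffixes.
ArgLists required_args(unsigned flags, Prop prop) {
    const bool global_stats = (flags & kUseGlobalStats) != 0;
    const bool scale_shift = (flags & kUseScaleShift) != 0;
    const bool fuse_relu = (flags & kFuseNormRelu) != 0;
    ArgLists a;
    if (prop == Prop::kForwardInference || prop == Prop::kForwardTraining) {
        a.inputs.push_back("SRC");
        if (global_stats) {  // user-provided statistics are inputs
            a.inputs.push_back("MEAN");
            a.inputs.push_back("VARIANCE");
        }
        if (scale_shift) a.inputs.push_back("SCALE_SHIFT");
        a.outputs.push_back("DST");
        if (prop == Prop::kForwardTraining && !global_stats) {  // computed statistics are outputs
            a.outputs.push_back("MEAN");
            a.outputs.push_back("VARIANCE");
        }
        if (prop == Prop::kForwardTraining && fuse_relu) a.outputs.push_back("WORKSPACE");
    } else {
        // Backward propagation: the statistics are always inputs.
        a.inputs = {"DIFF_DST", "SRC", "MEAN", "VARIANCE"};
        if (scale_shift) a.inputs.push_back("SCALE_SHIFT");
        if (fuse_relu) a.inputs.push_back("WORKSPACE");
        a.outputs.push_back("DIFF_SRC");
        // Only full backward (not backward_data) produces diff_gamma / diff_beta.
        if (scale_shift && prop == Prop::kBackward) a.outputs.push_back("DIFF_SCALE_SHIFT");
    }
    return a;
}
```

For example, `required_args(kUseGlobalStats | kUseScaleShift, Prop::kForwardTraining)` lists the mean and variance among the inputs and only the destination among the outputs, matching the fourth row of the table.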
@subsection autotoc_md110 Data Type Support
The operation supports the following combinations of data types:
<table class="markdownTable">
<tr class="markdownTableHead"> <th class="markdownTableHeadLeft"> Propagation </th> <th class="markdownTableHeadLeft"> Source / Destination </th> <th class="markdownTableHeadLeft"> Mean / Variance / ScaleShift </th> </tr>
<tr class="markdownTableRowOdd"> <td class="markdownTableBodyLeft"> forward / backward </td> <td class="markdownTableBodyLeft"> f32, bf16 </td> <td class="markdownTableBodyLeft"> f32 </td> </tr>
<tr class="markdownTableRowEven"> <td class="markdownTableBodyLeft"> forward </td> <td class="markdownTableBodyLeft"> f16 </td> <td class="markdownTableBodyLeft"> f32 </td> </tr>
<tr class="markdownTableRowOdd"> <td class="markdownTableBodyLeft"> forward </td> <td class="markdownTableBodyLeft"> s8 </td> <td class="markdownTableBodyLeft"> f32 </td> </tr>
</table>

@warning There might be hardware- or implementation-specific restrictions. Check the @ref dg_bnorm_impl_limits "Implementation Limitations" section below.

@subsection autotoc_md111 Data Representation
@subsubsection autotoc_md112 Mean and Variance
The mean (\(\mu\)) and variance (\(\sigma^2\)) are separate 1D tensors of size \(C\).
The format of the corresponding memory object must be dnnl_x (dnnl_a).
If used, the scale (\(\gamma\)) and shift (\(\beta\)) are combined in a single 2D tensor of shape \(2 \times C\).
The format of the corresponding memory object must be dnnl_nc (dnnl_ab).
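A sketch of filling the combined scale-and-shift tensor in the row-major dnnl_nc (dnnl_ab) layout. It assumes, following the \(\gamma, \beta\) ordering used throughout this page, that row 0 holds \(\gamma\) and row 1 holds \(\beta\); the helper `pack_scale_shift` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Packs per-channel gamma and beta into one 2 x C row-major buffer.
// Assumption: row 0 = gamma(c), row 1 = beta(c).
std::vector<float> pack_scale_shift(const std::vector<float>& gamma,
                                    const std::vector<float>& beta) {
    const std::size_t C = gamma.size();
    assert(beta.size() == C);  // both must have one value per channel
    std::vector<float> packed(2 * C);
    for (std::size_t c = 0; c < C; ++c) {
        packed[0 * C + c] = gamma[c];  // logical element (0, c)
        packed[1 * C + c] = beta[c];   // logical element (1, c)
    }
    return packed;
}
```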
Like other CNN primitives, the batch normalization primitive expects data to be an \(N \times C \times SP_n \times \cdots \times SP_0\) tensor.
The batch normalization primitive is optimized for the following memory formats:
| Spatial | Logical tensor | Implementations optimized for memory formats |
|---|---|---|
| 0D | NC | dnnl_nc (dnnl_ab) |
| 2D | NCHW | dnnl_nchw (dnnl_abcd), dnnl_nhwc (dnnl_acdb), optimized^ |
| 3D | NCDHW | dnnl_ncdhw (dnnl_abcde), dnnl_ndhwc (dnnl_acdeb), optimized^ |
Here optimized^ means the format that comes out of any preceding compute-intensive primitive.
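To make the plain format tags concrete, the following sketch computes the physical offset of a logical element \((n, c, h, w)\) for dnnl_nchw (dnnl_abcd) and dnnl_nhwc (dnnl_acdb). It covers only these two plain layouts, not the formats denoted by optimized^; the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// dnnl_nchw (abcd): w varies fastest, then h, then c, then n.
std::size_t offset_nchw(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                        std::size_t C, std::size_t H, std::size_t W) {
    return ((n * C + c) * H + h) * W + w;
}

// dnnl_nhwc (acdb): c varies fastest, so all channels of one pixel are contiguous.
std::size_t offset_nhwc(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                        std::size_t C, std::size_t H, std::size_t W) {
    return ((n * H + h) * W + w) * C + c;
}
```

In dnnl_nchw each channel's \(H \times W\) plane within one image is contiguous, while in dnnl_nhwc the \(C\) values of a single spatial point are adjacent; the per-channel reductions over \(N\), \(H\), and \(W\) therefore traverse memory very differently in the two layouts.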
Post-ops and attributes enable you to modify the behavior of the batch normalization primitive by chaining certain operations after the batch normalization operation. The following post-ops are supported by batch normalization primitives:
| Propagation | Type | Operation | Description |
|---|---|---|---|
| forward | post-op | eltwise | Applies an Eltwise operation to the result (currently, only the dnnl_eltwise_relu algorithm is supported) |
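Conceptually, the eltwise ReLU post-op applies \(\max(x, 0)\) to each normalized output. The sketch below also records which elements stayed positive, which is the information a backward pass through the ReLU needs. That mask is only a hypothetical stand-in for the workspace produced with #dnnl_fuse_norm_relu in training, whose actual format is implementation-specific; `relu_forward` and `relu_backward` are hypothetical helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Applies ReLU (zero negative slope) in place to already-normalized values and
// returns a mask of positive outputs (hypothetical stand-in for the workspace).
std::vector<bool> relu_forward(std::vector<float>& dst) {
    std::vector<bool> mask(dst.size());
    for (std::size_t i = 0; i < dst.size(); ++i) {
        mask[i] = dst[i] > 0.f;
        dst[i] = std::max(dst[i], 0.f);
    }
    return mask;
}

// Backward through the ReLU: the gradient flows only where the output was positive.
void relu_backward(std::vector<float>& diff_dst, const std::vector<bool>& mask) {
    for (std::size_t i = 0; i < diff_dst.size(); ++i)
        if (!mask[i]) diff_dst[i] = 0.f;
}
```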
The following example demonstrates the primitive:

| Engine | Name | Comments |
|---|---|---|
| CPU/GPU | Batch Normalization Primitive Example | This C++ API example demonstrates how to create and execute a Batch Normalization primitive in forward training propagation mode. The example also includes several key optimizations. |