InstanceNorm1d

class torch.nn.InstanceNorm1d(num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)

Applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with an optional additional channel dimension) as described in the paper "Instance Normalization: The Missing Ingredient for Fast Stylization".

y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

The mean and standard deviation are calculated per channel, separately for each sample in a mini-batch. \gamma and \beta are learnable parameter vectors of size C (where C is the number of features or channels of the input) if affine is True. The variance is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False).
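
As a minimal sketch of the computation described above, the following reproduces the default layer (no affine parameters, no running statistics) by hand and compares the result with nn.InstanceNorm1d. The tensor sizes and variable names are illustrative assumptions, not part of the API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 8)  # (N, C, L); sizes chosen only for illustration

# Statistics are taken over the length dimension L,
# separately for every sample n and every channel c.
mean = x.mean(dim=2, keepdim=True)
var = x.var(dim=2, unbiased=False, keepdim=True)  # biased estimator
eps = 1e-5
manual = (x - mean) / torch.sqrt(var + eps)

layer = nn.InstanceNorm1d(3)  # defaults: affine=False, track_running_stats=False
reference = layer(x)

# Largest element-wise difference between the hand-written and built-in result.
max_abs_diff = (manual - reference).abs().max().item()
print(max_abs_diff)
```

The difference should be on the order of floating-point rounding error.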

By default, this layer uses instance statistics computed from input data in both training and evaluation modes.

If track_running_stats is set to True, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.
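
The following sketch illustrates this behavior under assumed tensor sizes: one forward pass in training mode updates the running_mean and running_var buffers, and a subsequent forward pass in evaluation mode normalizes with those buffers rather than with per-instance statistics. The comparison tensor `expected` is computed by hand here purely for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.InstanceNorm1d(3, track_running_stats=True)
x = torch.randn(4, 3, 8) * 2.0 + 5.0  # (N, C, L), shifted away from zero mean

# Training mode: normalize with instance statistics and update the buffers.
layer.train()
_ = layer(x)
running_mean = layer.running_mean.clone()  # shape (C,)
running_var = layer.running_var.clone()    # shape (C,)

# Evaluation mode: normalize with the stored running estimates instead.
layer.eval()
eval_out = layer(x)

# Hand-written normalization using the running estimates, for comparison.
expected = (x - running_mean.view(1, -1, 1)) / torch.sqrt(
    running_var.view(1, -1, 1) + layer.eps
)
matches_running_stats = torch.allclose(eval_out, expected, atol=1e-4)
print(matches_running_stats)
```

Because the running estimates after a single update are still far from the true statistics of `x`, the evaluation-mode output in this sketch is not zero-mean per instance.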

Note

This momentum argument is different from the one used in optimizer classes and from the conventional notion of momentum. Mathematically, the update rule for running statistics here is \hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t, where \hat{x} is the estimated statistic and x_t is the new observed value.
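
A small worked example of this update rule in plain Python may help. The helper name `update_running_stat` is hypothetical and exists only for this sketch; it is not a PyTorch function.

```python
def update_running_stat(running, observed, momentum=0.1):
    """Hypothetical helper: exponential moving average as in the rule above."""
    return (1 - momentum) * running + momentum * observed


# Start from 0.0 and observe the value 1.0 three times.
running_mean = 0.0
history = []
for observed in [1.0, 1.0, 1.0]:
    running_mean = update_running_stat(running_mean, observed)
    history.append(running_mean)

print(history)  # approximately [0.1, 0.19, 0.271]
```

With momentum 0.1 the estimate moves only 10% of the way toward each new observation, so a larger momentum makes the running statistic track recent values more closely.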

Note

InstanceNorm1d and LayerNorm are very similar, but have some subtle differences. InstanceNorm1d is applied to each channel of channeled data such as multidimensional time series, whereas LayerNorm is usually applied to an entire sample, often in NLP tasks. Additionally, LayerNorm applies an elementwise affine transform by default, while InstanceNorm1d usually does not apply an affine transform.
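
To make the difference in normalization axes concrete, here is a hedged sketch that compares the two layers on the same (N, C, L) tensor. The per-channel offsets and sizes are assumptions chosen to make the effect visible, and LayerNorm is constructed with elementwise_affine=False so that only the normalization itself is compared.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, L = 2, 3, 16
# Give each channel a different offset (0, 10, 20) so the layers differ visibly.
offsets = torch.arange(C, dtype=torch.float32).view(1, C, 1) * 10
x = torch.randn(N, C, L) + offsets

inst = nn.InstanceNorm1d(C)                             # statistics per (sample, channel)
layer = nn.LayerNorm([C, L], elementwise_affine=False)  # statistics per sample

y_inst = inst(x)
y_layer = layer(x)

inst_channel_means = y_inst.mean(dim=2)        # (N, C): every entry is close to 0
layer_channel_means = y_layer.mean(dim=2)      # (N, C): channel offsets survive
layer_sample_means = y_layer.mean(dim=(1, 2))  # (N,): close to 0 per sample

print(inst_channel_means)
print(layer_channel_means)
print(layer_sample_means)
```

InstanceNorm1d removes the per-channel offsets, while LayerNorm over the whole sample keeps the relative differences between channels.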

Parameters
  • num_features – C from an expected input of size (N, C, L) or L from input of size (N, L)

  • eps – a value added to the denominator for numerical stability. Default: 1e-5

  • momentum – the value used for the running_mean and running_var computation. Default: 0.1

  • affine – a boolean value; when set to True, this module has learnable affine parameters, initialized the same way as for batch normalization. Default: False

  • track_running_stats – a boolean value; when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses instance statistics in both training and eval modes. Default: False
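
The effect of the affine and track_running_stats flags on the module's state can be inspected directly. This is a minimal sketch with an assumed num_features of 5; the variable names are illustrative.

```python
import torch
import torch.nn as nn

plain = nn.InstanceNorm1d(5)                              # all defaults
affine = nn.InstanceNorm1d(5, affine=True)                # learnable weight and bias
tracked = nn.InstanceNorm1d(5, track_running_stats=True)  # running-statistics buffers

# With the defaults there are no parameters and no buffers.
print(plain.weight, plain.running_mean)

# affine=True adds per-channel weight (gamma) and bias (beta) of size C.
print(affine.weight.shape, affine.bias.shape)

# track_running_stats=True adds per-channel running_mean and running_var.
print(tracked.running_mean.shape, tracked.running_var.shape)
```

As with batch normalization, the weight starts at ones and the bias at zeros, so a freshly constructed affine layer initially produces the same output as the non-affine one.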

Shape:
  • Input: (N, C, L)

  • Output: (N, C, L) (same shape as input)

Examples:

>>> import torch
>>> import torch.nn as nn
>>> # Without Learnable Parameters
>>> m = nn.InstanceNorm1d(100)
>>> # With Learnable Parameters
>>> m = nn.InstanceNorm1d(100, affine=True)
>>> input = torch.randn(20, 100, 40)
>>> output = m(input)
>>> output.shape
torch.Size([20, 100, 40])
