torch.nn.utils.weight_norm
torch.nn.utils.weight_norm(module, name='weight', dim=0)[source]
Applies weight normalization to a parameter in the given module:

    w = g * v / ||v||

Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. It replaces the parameter specified by name (e.g. 'weight') with two parameters: one specifying the magnitude (e.g. 'weight_g') and one specifying the direction (e.g. 'weight_v'). Weight normalization is implemented via a hook that recomputes the weight tensor from the magnitude and direction before every forward() call.

By default, with dim=0, the norm is computed independently per output channel/plane. To compute a norm over the entire weight tensor, use dim=None.

See https://arxiv.org/abs/1602.07868

Parameters
- module (Module) – containing module
- name (str, optional) – name of the weight parameter
- dim (int, optional) – dimension over which to compute the norm
Returns
- The original module with the weight norm hook
Example:

>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_g.size()
torch.Size([40, 1])
>>> m.weight_v.size()
torch.Size([40, 20])
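To make the hook mechanics concrete, here is a minimal sketch (assuming a PyTorch version where this module is available; the variable names are illustrative) that rebuilds the weight by hand from weight_g and weight_v and checks it against m.weight. With dim=0 the norm is taken over each output row, matching the [40, 1] shape of weight_g above; with dim=None a single norm is taken over the whole tensor.

>>> import torch
>>> import torch.nn as nn
>>> from torch.nn.utils import weight_norm
>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> _ = m(torch.randn(8, 20))  # forward() triggers the hook that recomputes m.weight
>>> v, g = m.weight_v, m.weight_g
>>> # weight is recomputed as g * v / ||v||, norm per output row when dim=0
>>> torch.allclose(m.weight, g * v / v.norm(dim=1, keepdim=True))
True
>>> # dim=None: one norm over the entire tensor, so weight_g holds a single scalar
>>> m_all = weight_norm(nn.Linear(20, 40), name='weight', dim=None)
>>> m_all.weight_g.numel()
1
>>> torch.allclose(m_all.weight, m_all.weight_g * m_all.weight_v / m_all.weight_v.norm())
True

Because g and v are separate parameters, gradient descent updates the magnitude and direction of the weight independently, which is the motivation given in the paper linked above.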