RandomStructured
class torch.nn.utils.prune.RandomStructured(amount, dim=-1)
- Prune entire (currently unpruned) channels in a tensor at random.
- Parameters
- amount (int or float) – quantity of channels to prune along dim. If float, it should be between 0.0 and 1.0 and represents the fraction of channels to prune. If int, it represents the absolute number of channels to prune.
- dim (int, optional) – index of the dim along which we define channels to prune. Default: -1. 
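A minimal usage sketch, assuming a recent PyTorch release: the module-level helper torch.nn.utils.prune.random_structured applies this pruning method to a named parameter. The Conv2d layer, its sizes, the seed, and amount=0.25 are illustrative choices; dim=0 selects output channels because a Conv2d weight has shape (out_channels, in_channels, kH, kW).

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # make the random channel choice reproducible

conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# Randomly prune 25% of the output channels (dim=0 of the weight).
prune.random_structured(conv, name="weight", amount=0.25, dim=0)

# A channel counts as pruned when every element of its slice is zero.
zeroed = (conv.weight.flatten(start_dim=1) == 0).all(dim=1)
print(zeroed.sum().item())  # 2, i.e. round(0.25 * 8)
```

Because whole slices are zeroed, the pruned output channels produce only the bias term, which is what distinguishes structured pruning from element-wise pruning.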
 
classmethod apply(module, name, amount, dim=-1)
- Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- Parameters
- module (nn.Module) – module containing the tensor to prune 
- name (str) – parameter name within module on which pruning will act.
- amount (int or float) – quantity of channels to prune along dim. If float, it should be between 0.0 and 1.0 and represents the fraction of channels to prune. If int, it represents the absolute number of channels to prune.
- dim (int, optional) – index of the dim along which we define channels to prune. Default: -1. 
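A minimal sketch of calling the classmethod directly, assuming a recent PyTorch release; the nn.Linear layer, its sizes, and amount=3 are illustrative. It checks the reparametrization described above: the original tensor becomes the parameter weight_orig, the mask is registered as the buffer weight_mask, and weight is recomputed from them before each forward pass.

```python
import torch
import torch.nn as nn
from torch.nn.utils.prune import RandomStructured

torch.manual_seed(0)  # reproducible random channel choice
linear = nn.Linear(in_features=4, out_features=6)  # weight shape: (6, 4)

# Prune 3 of the 6 rows (output units) of linear.weight; apply returns the method.
method = RandomStructured.apply(linear, "weight", amount=3, dim=0)

print(sorted(name for name, _ in linear.named_parameters()))  # ['bias', 'weight_orig']
print([name for name, _ in linear.named_buffers()])           # ['weight_mask']

# The forward pre-hook refreshes linear.weight from weight_orig and weight_mask.
out = linear(torch.randn(2, 4))
print(torch.equal(linear.weight, linear.weight_orig * linear.weight_mask))  # True
print((linear.weight == 0).all(dim=1).sum().item())  # 3
```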
 
 
apply_mask(module)
- Handles the multiplication between the parameter being pruned and the generated mask: fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- Parameters
- module (nn.Module) – module containing the tensor to prune 
- Returns
- pruned version of the input tensor 
- Return type
- pruned_tensor (torch.Tensor) 
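A short sketch, assuming a recent PyTorch release, showing that apply_mask is just the element-wise product of the stored original tensor and mask. The layer sizes and amount=0.5 are illustrative; the method instance comes from the apply classmethod, which records the parameter name it acts on.

```python
import torch
import torch.nn as nn
from torch.nn.utils.prune import RandomStructured

torch.manual_seed(0)
linear = nn.Linear(in_features=6, out_features=4)  # weight shape: (4, 6)

# Prune round(0.5 * 6) = 3 columns (input features) of the weight.
method = RandomStructured.apply(linear, "weight", amount=0.5, dim=1)

# apply_mask reads weight_orig and weight_mask from the module and multiplies
# them; it returns a new tensor and does not modify the module.
pruned = method.apply_mask(linear)

print(torch.equal(pruned, linear.weight_orig * linear.weight_mask))  # True
print((pruned == 0).all(dim=0).sum().item())  # 3
```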
 
compute_mask(t, default_mask)
- Computes and returns a mask for the input tensor t. Starting from a base default_mask (which should be a mask of ones if the tensor has not been pruned yet), generates a random mask to apply on top of default_mask by randomly zeroing out channels along the specified dim of the tensor.
- Parameters
- t (torch.Tensor) – tensor representing the parameter to prune 
- default_mask (torch.Tensor) – base mask from previous pruning iterations, which must be respected after the new mask is applied. Same dims as t.
 
- Returns
- mask to apply to t, of same dims as t
- Return type
- mask (torch.Tensor) 
- Raises
- IndexError – if self.dim >= len(t.shape)
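A minimal sketch of calling compute_mask directly, assuming a recent PyTorch release; the tensor shape, seed, and amounts are illustrative. An int amount is used as the channel count as-is, while a float is converted to round(amount * t.shape[dim]). The sketch checks three behaviors from the description above: exactly that many channels are zeroed, zeros already present in default_mask survive, and an out-of-range dim raises IndexError.

```python
import torch
from torch.nn.utils.prune import RandomStructured

torch.manual_seed(0)
t = torch.randn(5, 3)  # 5 channels along dim 0
method = RandomStructured(amount=2, dim=0)

# Unpruned tensor: start from a mask of ones.
mask = method.compute_mask(t, default_mask=torch.ones_like(t))
print((mask == 0).all(dim=1).sum().item())  # 2 rows zeroed out

# A previous (here unstructured) mask is respected: its zeros survive.
prior = torch.ones_like(t)
prior[0, 0] = 0
mask2 = method.compute_mask(t, default_mask=prior)
print(mask2[0, 0].item())  # 0.0

# dim must index an existing dimension of t (t has only dims 0 and 1).
raised = False
try:
    RandomStructured(amount=1, dim=2).compute_mask(t, torch.ones_like(t))
except IndexError:
    raised = True
print(raised)  # True
```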
 
prune(t, default_mask=None, importance_scores=None)
- Computes and returns a pruned version of input tensor t according to the pruning rule specified in compute_mask().
- Parameters
- t (torch.Tensor) – tensor to prune (of same dimensions as default_mask).
- importance_scores (torch.Tensor) – tensor of importance scores (of same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements in t. If unspecified or None, the tensor t will be used in its place.
- default_mask (torch.Tensor, optional) – mask from the previous pruning iteration, if any. It is taken into account when determining which portion of the tensor pruning should act on. If None, defaults to a mask of ones.
 
- Returns
- pruned version of tensor t.
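A short sketch of prune on a standalone tensor, assuming a recent PyTorch release; the tensor values, seed, and amount=0.5 are illustrative. With default_mask=None a mask of ones is used, and the result is t multiplied by the mask from compute_mask, so the surviving columns keep their original values.

```python
import torch
from torch.nn.utils.prune import RandomStructured

torch.manual_seed(0)
t = torch.arange(1.0, 13.0).reshape(3, 4)    # no zeros; 4 channels along dim 1
method = RandomStructured(amount=0.5, dim=1)  # prune round(0.5 * 4) = 2 columns

pruned = method.prune(t)  # default_mask=None, so a mask of ones is the base
print(pruned)

zero_cols = (pruned == 0).all(dim=0)
print(zero_cols.sum().item())  # 2
```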
 
remove(module)
- Removes the pruning reparameterization from a module. The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.
- Note
- Pruning itself is NOT undone or reversed!
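A minimal sketch of making pruning permanent, assuming a recent PyTorch release. It uses the module-level helper torch.nn.utils.prune.remove(module, name), which calls this remove method and also deletes the forward pre-hook, so the module no longer recomputes weight on each forward pass. The Conv2d layer and amount=0.5 are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)
prune.random_structured(conv, name="weight", amount=0.5, dim=0)

# Drop weight_orig / weight_mask and the pre-hook; weight becomes a plain Parameter.
prune.remove(conv, "weight")

print(sorted(name for name, _ in conv.named_parameters()))  # ['bias', 'weight']
print([name for name, _ in conv.named_buffers()])           # []

# The pruned channels stay zero: removal does not restore the original values.
print((conv.weight.flatten(start_dim=1) == 0).all(dim=1).sum().item())  # 4
```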