torch.vmap

torch.vmap(func, in_dims=0, out_dims=0)
vmap is the vectorizing map. Returns a new function that maps func over some dimension of the inputs. Semantically, vmap pushes the map into PyTorch operations called by func, effectively vectorizing those operations.

vmap is useful for handling batch dimensions: one can write a function func that runs on single examples and then lift it to a function that can take batches of examples with vmap(func). vmap can also be used to compute batched gradients when composed with autograd.

Warning

torch.vmap is an experimental prototype that is subject to change and/or deletion. Please use at your own risk.

Note

If you’re interested in using vmap for your use case, please contact us! We’re interested in gathering feedback from early adopters to inform the design.

Parameters
- func (function) – A Python function that takes one or more arguments. Must return one or more Tensors. 
- in_dims (int or nested structure) – Specifies which dimension of the inputs should be mapped over. in_dims should have a structure like the inputs. If the in_dim for a particular input is None, that indicates there is no map dimension for that input, as in the sketch following this parameter list. Default: 0.
- out_dims (int or Tuple[int]) – Specifies where the mapped dimension should appear in the outputs. If out_dims is a Tuple, it should have one element per output (see the sketch after the Returns entry). Default: 0.
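
As a minimal sketch of the in_dims structure (the tensors here are invented for illustration): passing None for one input shares that input across the batch instead of mapping over it.

>>> # Map over dim 0 of x, but share a single unbatched w across the batch.
>>> x = torch.randn(3, 5)   # batch of 3 feature vectors
>>> w = torch.randn(5)      # one shared vector, no batch dimension
>>> torch.vmap(torch.dot, in_dims=(0, None))(x, w)  # shape [3]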
 
Returns

A new “batched” function. It takes the same inputs as func, except each input has an extra dimension at the index specified by in_dims. It returns the same outputs as func, except each output has an extra dimension at the index specified by out_dims.
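
As a minimal sketch of out_dims placement (function and tensors invented for illustration): with out_dims=1, the mapped dimension appears at index 1 of the output instead of index 0.

>>> # func maps a [5] vector to a [5] vector; vmap over 2 such vectors.
>>> x = torch.randn(2, 5)
>>> torch.vmap(lambda v: v * 2, out_dims=1)(x).shape
torch.Size([5, 2])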
One example of using vmap is to compute batched dot products. PyTorch doesn’t provide a batched torch.dot API; instead of unsuccessfully rummaging through docs, use vmap to construct a new function.

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)

vmap can be helpful in hiding batch dimensions, leading to a simpler model authoring experience.

>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # Very simple linear model with activation
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)

vmap can also help vectorize computations that were previously difficult or impossible to batch. One example is higher-order gradient computation. The PyTorch autograd engine computes vjps (vector-Jacobian products). Computing a full Jacobian matrix for some function f: R^N -> R^N usually requires N calls to autograd.grad, one per Jacobian row. Using vmap, we can vectorize the whole computation, computing the Jacobian in a single call to autograd.grad.

>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>>                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # Vectorized gradient computation
>>> def get_vjp(v):
>>>     return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)

Note

vmap does not provide general autobatching or handle variable-length sequences out of the box.
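
One common workaround for variable-length data (a minimal sketch, not part of the vmap API; the sequences and the per-example reduction are invented for illustration) is to pad the sequences into a rectangular batch first, e.g. with torch.nn.utils.rnn.pad_sequence, and carry the true lengths alongside.

>>> # Pad to a rectangular [3, 5] batch, then vmap a per-example reduction.
>>> seqs = [torch.randn(n) for n in (2, 3, 5)]
>>> padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)
>>> lengths = torch.tensor([2., 3., 5.])
>>> # Zero padding does not affect the sum; divide by the true length.
>>> mean_over_length = lambda row, n: row.sum() / n
>>> torch.vmap(mean_over_length)(padded, lengths)  # shape [3]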