Shortcuts

torch.where

torch.where(condition, x, y) → Tensor

Return a tensor of elements selected from either x or y, depending on condition.

The operation is defined as:

outi={xiif conditioniyiotherwise\text{out}_i = \begin{cases} \text{x}_i & \text{if } \text{condition}_i \\ \text{y}_i & \text{otherwise} \\ \end{cases}

Note

The tensors condition, x, y must be broadcastable.

Note

Currently valid scalar and tensor combination are 1. Scalar of floating dtype and torch.double 2. Scalar of integral dtype and torch.long 3. Scalar of complex dtype and torch.complex128

Parameters
  • condition (BoolTensor) – When True (nonzero), yield x, otherwise yield y

  • x (Tensor or Scalar) – value (if :attr:x is a scalar) or values selected at indices where condition is True

  • y (Tensor or Scalar) – value (if :attr:x is a scalar) or values selected at indices where condition is False

Returns

A tensor of shape equal to the broadcasted shape of condition, x, y

Return type

Tensor

Example:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)
torch.where(condition) → tuple of LongTensor

torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).

Note

See also torch.nonzero().

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources