torch.tensor_split
torch.tensor_split(input, indices_or_sections, dim=0) → List of Tensors

Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension dim according to the indices or number of sections specified by indices_or_sections. This function is based on NumPy's numpy.array_split().

Parameters
input (Tensor) – the tensor to split
indices_or_sections (Tensor, int or list or tuple of ints) –
If indices_or_sections is an integer n or a zero-dimensional long tensor with value n, input is split into n sections along dimension dim. If input is divisible by n along dimension dim, each section will be of equal size, input.size(dim) / n. If input is not divisible by n, the first int(input.size(dim) % n) sections will have size int(input.size(dim) / n) + 1, and the rest will have size int(input.size(dim) / n).

If indices_or_sections is a list or tuple of ints, or a one-dimensional long tensor, then input is split along dimension dim at each of the indices in the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors input[:2], input[2:3], and input[3:].

If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional long tensor on the CPU.
dim (int, optional) – dimension along which to split the tensor. Default: 0
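The section-size rule for an integer n can be worked through without PyTorch. The following is a minimal pure-Python sketch of that arithmetic only; the helper name section_sizes is hypothetical and is not part of the torch API.

```python
from typing import List


def section_sizes(length: int, n: int) -> List[int]:
    """Sizes of the n sections when `length` elements are split (hypothetical helper)."""
    if n <= 0:
        raise ValueError("n must be a positive integer")
    # base is the floor of length / n; extra is the remainder length % n.
    base, extra = divmod(length, n)
    # The first `extra` sections get one additional element, the rest get `base`.
    return [base + 1] * extra + [base] * (n - extra)


print(section_sizes(8, 3))  # [3, 3, 2]
print(section_sizes(7, 3))  # [3, 2, 2]
```

Here length stands in for input.size(dim), so the two printed lists correspond to splitting 8 and 7 elements into 3 sections.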
Example::

    >>> x = torch.arange(8)
    >>> torch.tensor_split(x, 3)
    (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))

    >>> x = torch.arange(7)
    >>> torch.tensor_split(x, 3)
    (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
    >>> torch.tensor_split(x, (1, 6))
    (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))

    >>> x = torch.arange(14).reshape(2, 7)
    >>> x
    tensor([[ 0,  1,  2,  3,  4,  5,  6],
            [ 7,  8,  9, 10, 11, 12, 13]])
    >>> torch.tensor_split(x, 3, dim=1)
    (tensor([[0, 1, 2],
            [7, 8, 9]]),
     tensor([[ 3,  4],
            [10, 11]]),
     tensor([[ 5,  6],
            [12, 13]]))
    >>> torch.tensor_split(x, (1, 6), dim=1)
    (tensor([[0],
            [7]]),
     tensor([[ 1,  2,  3,  4,  5],
            [ 8,  9, 10, 11, 12]]),
     tensor([[ 6],
            [13]]))
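The index-based form can be read as a sequence of half-open slices along dim. The sketch below illustrates that reading on a plain Python list rather than a tensor; split_bounds is a hypothetical helper, and the sketch assumes sorted, in-range, non-negative indices (it does not try to model how PyTorch treats other inputs).

```python
from typing import List, Sequence, Tuple


def split_bounds(length: int, indices: Sequence[int]) -> List[Tuple[int, int]]:
    """(start, stop) pairs for cutting `length` elements at `indices` (hypothetical helper)."""
    # Each index ends one piece and starts the next one.
    starts = [0, *indices]
    stops = [*indices, length]
    return list(zip(starts, stops))


data = list(range(7))
pieces = [data[a:b] for a, b in split_bounds(len(data), (1, 6))]
print(pieces)  # [[0], [1, 2, 3, 4, 5], [6]]
```

With indices (1, 6) the bounds are (0, 1), (1, 6) and (6, 7), which mirrors the one-dimensional index example above. Unlike the tensors returned by torch.tensor_split, the list slices here are copies, not views.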