torch.sparse
Introduction
PyTorch provides torch.Tensor to represent a
multi-dimensional array containing elements of a single data type. By
default, array elements are stored contiguously in memory, leading to
efficient implementations of various array processing algorithms that
rely on fast access to array elements. However, there exists an
important class of multi-dimensional arrays, so-called sparse arrays,
for which contiguous memory storage of array elements turns out to be
suboptimal. In a sparse array, the vast majority of elements are equal
to zero, which means that a lot of memory as well as processor
resources can be spared if only the non-zero elements are stored
and/or processed. Various sparse storage formats (such as COO,
CSR/CSC, LIL, etc.) have been developed, each optimized for a
particular structure of non-zero elements in sparse arrays as well as
for specific operations on the arrays.
Note
When talking about storing only non-zero elements of a sparse array, the adjective “non-zero” is not strict: one is allowed to store zeros in the sparse array data structure as well. Hence, in the following, we use “specified elements” for those array elements that are actually stored. In addition, the unspecified elements are typically, but not always, assumed to have zero value, hence we use the term “fill value” to denote the value of such elements.
Note
Using a sparse storage format for storing sparse arrays can be advantageous only when the size and sparsity levels of arrays are high. Otherwise, for small-sized or low-sparsity arrays using the contiguous memory storage format is likely the most efficient approach.
Warning
The PyTorch API of sparse tensors is in beta and may change in the near future.
Sparse COO tensors
Currently, PyTorch implements the so-called Coordinate format, or COO format, as the default sparse storage format for storing sparse tensors. In COO format, the specified elements are stored as tuples of element indices and the corresponding values. In particular,
the indices of specified elements are collected in an indices tensor of size (ndim, nse) and with element type torch.int64,
the corresponding values are collected in a values tensor of size (nse,) and with an arbitrary integer or floating point number element type,
where ndim is the dimensionality of the tensor and nse is the number of specified elements.
Note
The memory consumption of a sparse COO tensor is at least (ndim * 8 + <size of element type in bytes>) * nse bytes (plus a constant overhead from storing other tensor data).
The memory consumption of a strided tensor is at least product(<tensor shape>) * <size of element type in bytes>.
For example, the memory consumption of a 10 000 x 10 000 tensor with 100 000 non-zero 32-bit floating point numbers is at least (2 * 8 + 4) * 100 000 = 2 000 000 bytes when using the COO tensor layout and 10 000 * 10 000 * 4 = 400 000 000 bytes when using the default strided tensor layout. Notice the 200-fold memory saving from using the COO storage format.
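The two lower bounds in the note above can be checked with a few lines of plain Python. This is only an illustrative sketch: the helper names coo_min_bytes and strided_min_bytes are made up for this example and are not part of PyTorch, and the constant overhead mentioned in the note is ignored.

```python
from math import prod


def coo_min_bytes(ndim, nse, itemsize, index_itemsize=8):
    """Lower bound for COO storage: ndim int64 indices plus one value per specified element."""
    return (ndim * index_itemsize + itemsize) * nse


def strided_min_bytes(shape, itemsize):
    """Lower bound for strided storage: one value per array element."""
    return prod(shape) * itemsize


shape = (10_000, 10_000)
coo_bytes = coo_min_bytes(ndim=len(shape), nse=100_000, itemsize=4)
strided_bytes = strided_min_bytes(shape, itemsize=4)
print(coo_bytes, strided_bytes, strided_bytes // coo_bytes)
```

With the numbers from the note, the ratio of the two bounds is the 200-fold saving quoted above.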
Construction
A sparse COO tensor can be constructed by providing the two tensors of
indices and values, as well as the size of the sparse tensor (when it
cannot be inferred from the indices and values tensors), to the function
torch.sparse_coo_tensor().
Suppose we want to define a sparse tensor with the entry 3 at location (0, 2), entry 4 at location (1, 0), and entry 5 at location (1, 2). Unspecified elements are assumed to have the same value, fill value, which is zero by default. We would then write:
>>> i = [[0, 1, 1],
...      [2, 0, 2]]
>>> v = [3, 4, 5]
>>> s = torch.sparse_coo_tensor(i, v, (2, 3))
>>> s
tensor(indices=tensor([[0, 1, 1],
                       [2, 0, 2]]),
       values=tensor([3, 4, 5]),
       size=(2, 3), nnz=3, layout=torch.sparse_coo)
>>> s.to_dense()
tensor([[0, 0, 3],
        [4, 0, 5]])
Note that the input i is NOT a list of index tuples. If you want
to write your indices this way, you should transpose them before passing them to
the sparse constructor:
>>> i = [[0, 2], [1, 0], [1, 2]]
>>> v = [3, 4, 5]
>>> s = torch.sparse_coo_tensor(list(zip(*i)), v, (2, 3))
>>> # Or another equivalent formulation to get s
>>> s = torch.sparse_coo_tensor(torch.tensor(i).t(), v, (2, 3))
>>> # i is a Python list here, so convert it to a tensor before transposing
>>> torch.sparse_coo_tensor(torch.tensor(i).t(), v, torch.Size([2, 3])).to_dense()
tensor([[0, 0, 3],
        [4, 0, 5]])
An empty sparse COO tensor can be constructed by specifying its size only:
>>> torch.sparse_coo_tensor(size=(2, 3))
tensor(indices=tensor([], size=(2, 0)),
       values=tensor([], size=(0,)),
       size=(2, 3), nnz=0, layout=torch.sparse_coo)
Hybrid sparse COO tensors
PyTorch implements an extension of sparse tensors with scalar values to sparse tensors with (contiguous) tensor values. Such tensors are called hybrid tensors.
A PyTorch hybrid COO tensor extends the sparse COO tensor by allowing
the values tensor to be a multi-dimensional tensor so that we have:
the indices of specified elements are collected in an indices tensor of size (sparse_dims, nse) and with element type torch.int64,
the corresponding (tensor) values are collected in a values tensor of size (nse, dense_dims) and with an arbitrary integer or floating point number element type.
Note
We use (M + K)-dimensional tensor to denote an N-dimensional hybrid sparse tensor, where M and K are the numbers of sparse and dense dimensions, respectively, such that M + K == N holds.
Suppose we want to create a (2 + 1)-dimensional tensor with the entry [3, 4] at location (0, 2), entry [5, 6] at location (1, 0), and entry [7, 8] at location (1, 2). We would write:
>>> i = [[0, 1, 1],
...      [2, 0, 2]]
>>> v = [[3, 4], [5, 6], [7, 8]]
>>> s = torch.sparse_coo_tensor(i, v, (2, 3, 2))
>>> s
tensor(indices=tensor([[0, 1, 1],
                       [2, 0, 2]]),
       values=tensor([[3, 4],
                      [5, 6],
                      [7, 8]]),
       size=(2, 3, 2), nnz=3, layout=torch.sparse_coo)
>>> s.to_dense()
tensor([[[0, 0],
         [0, 0],
         [3, 4]],
        [[5, 6],
         [0, 0],
         [7, 8]]])
In general, if s is a sparse COO tensor and M = s.sparse_dim(), K = s.dense_dim(), then we have the following invariants:
M + K == len(s.shape) == s.ndim - the dimensionality of a tensor is the sum of the number of sparse and dense dimensions,
s.indices().shape == (M, nse) - sparse indices are stored explicitly,
s.values().shape == (nse,) + s.shape[M : M + K] - the values of a hybrid tensor are K-dimensional tensors,
s.values().layout == torch.strided - values are stored as strided tensors.
Note
Dense dimensions always follow sparse dimensions, that is, mixing of dense and sparse dimensions is not supported.
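The invariants listed above can be verified directly on the (2 + 1)-dimensional example tensor. The snippet below is a small sketch rather than part of the reference text; it coalesces the tensor first because indices() and values() require a coalesced tensor, and the variable name invariants is chosen only for this illustration.

```python
import torch

i = [[0, 1, 1],
     [2, 0, 2]]
v = [[3, 4], [5, 6], [7, 8]]
# coalesce() so that indices() and values() may be called
s = torch.sparse_coo_tensor(i, v, (2, 3, 2)).coalesce()

M, K = s.sparse_dim(), s.dense_dim()
nse = s.values().shape[0]  # number of specified elements

invariants = {
    "M + K == len(s.shape) == s.ndim": M + K == len(s.shape) == s.ndim,
    "indices shape is (M, nse)": tuple(s.indices().shape) == (M, nse),
    "values shape is (nse,) + s.shape[M:M + K]":
        tuple(s.values().shape) == (nse,) + tuple(s.shape[M:M + K]),
    "values are strided": s.values().layout == torch.strided,
}
for name, holds in invariants.items():
    print(name, holds)
```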
Uncoalesced sparse COO tensors
The PyTorch sparse COO tensor format permits uncoalesced sparse tensors,
where there may be duplicate coordinates in the indices; in this case,
the interpretation is that the value at that index is the sum of all
duplicate value entries. For example, one can specify multiple values,
3 and 4, for the same index 1, which leads to a 1-D
uncoalesced tensor:
>>> i = [[1, 1]]
>>> v = [3, 4]
>>> s = torch.sparse_coo_tensor(i, v, (3,))
>>> s
tensor(indices=tensor([[1, 1]]),
       values=tensor([3, 4]),
       size=(3,), nnz=2, layout=torch.sparse_coo)
while the coalescing process will accumulate the multi-valued elements into a single value using summation:
>>> s.coalesce()
tensor(indices=tensor([[1]]),
       values=tensor([7]),
       size=(3,), nnz=1, layout=torch.sparse_coo)
In general, the output of the torch.Tensor.coalesce() method is a
sparse tensor with the following properties:
the indices of specified tensor elements are unique,
the indices are sorted in lexicographical order,
torch.Tensor.is_coalesced() returns True.
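To make the first two properties concrete, here is a pure-Python sketch of what coalescing does to COO data held in plain lists. It is not how PyTorch implements coalesce(); the function name coalesce_coo is hypothetical, and the sketch assumes at least one specified element.

```python
from collections import defaultdict


def coalesce_coo(indices, values):
    """Sum values that share a coordinate and sort the unique coordinates.

    indices: list of ndim lists, each of length nse (the (ndim, nse) layout).
    values:  list of nse scalars.
    """
    sums = defaultdict(int)
    # zip(*indices) turns the (ndim, nse) layout into nse coordinate tuples
    for coord, value in zip(zip(*indices), values):
        sums[coord] += value
    coords = sorted(sums)  # tuples compare lexicographically
    new_indices = [list(axis) for axis in zip(*coords)]
    new_values = [sums[coord] for coord in coords]
    return new_indices, new_values


one_d = coalesce_coo([[1, 1]], [3, 4])
two_d = coalesce_coo([[1, 0, 1], [2, 0, 2]], [5, 1, 6])
print(one_d)
print(two_d)
```

The first call mirrors the 1-D example above: the two entries at index 1 collapse into the single value 7.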
Note
For the most part, you shouldn’t have to care whether a sparse tensor is coalesced, as most operations will work identically given a coalesced or uncoalesced sparse tensor.
However, some operations can be implemented more efficiently on uncoalesced tensors, and some on coalesced tensors.
For instance, addition of sparse COO tensors is implemented by simply concatenating the indices and values tensors:
>>> a = torch.sparse_coo_tensor([[1, 1]], [5, 6], (2,))
>>> b = torch.sparse_coo_tensor([[0, 0]], [7, 8], (2,))
>>> a + b
tensor(indices=tensor([[0, 0, 1, 1]]),
       values=tensor([7, 8, 5, 6]),
       size=(2,), nnz=4, layout=torch.sparse_coo)
If you repeatedly perform an operation that can produce duplicate
entries (e.g., torch.Tensor.add()), you should occasionally
coalesce your sparse tensors to prevent them from growing too large.
On the other hand, the lexicographical ordering of indices can be advantageous for implementing algorithms that involve many element selection operations, such as slicing or matrix products.
Working with sparse COO tensors
Let’s consider the following example:
>>> i = [[0, 1, 1],
...      [2, 0, 2]]
>>> v = [[3, 4], [5, 6], [7, 8]]
>>> s = torch.sparse_coo_tensor(i, v, (2, 3, 2))
As mentioned above, a sparse COO tensor is a torch.Tensor
instance and to distinguish it from the Tensor instances that use
some other layout, one can use the torch.Tensor.is_sparse or
torch.Tensor.layout properties:
>>> isinstance(s, torch.Tensor)
True
>>> s.is_sparse
True
>>> s.layout == torch.sparse_coo
True
The number of sparse and dense dimensions can be acquired using the
methods torch.Tensor.sparse_dim() and torch.Tensor.dense_dim(),
respectively. For instance:
>>> s.sparse_dim(), s.dense_dim()
(2, 1)
If s is a sparse COO tensor, then its COO format data can be
acquired using the methods torch.Tensor.indices() and
torch.Tensor.values().
Note
Currently, one can acquire the COO format data only when the tensor instance is coalesced:
>>> s.indices()
RuntimeError: Cannot get indices on an uncoalesced tensor, please call .coalesce() first
For acquiring the COO format data of an uncoalesced tensor, use
torch.Tensor._values() and torch.Tensor._indices():
>>> s._indices()
tensor([[0, 1, 1],
        [2, 0, 2]])
Constructing a new sparse COO tensor results in a tensor that is not coalesced:
>>> s.is_coalesced()
False
but one can construct a coalesced copy of a sparse COO tensor using
the torch.Tensor.coalesce() method:
>>> s2 = s.coalesce()
>>> s2.indices()
tensor([[0, 1, 1],
        [2, 0, 2]])
When working with uncoalesced sparse COO tensors, one must take into
account the additive nature of uncoalesced data: the values at the
same index are the terms of a sum whose evaluation gives the value of
the corresponding tensor element. For example, scalar
multiplication on an uncoalesced sparse tensor can be implemented by
multiplying all the uncoalesced values by the scalar because
c * (a + b) == c * a + c * b holds. However, any nonlinear operation,
say, a square root, cannot be implemented by applying the operation to
uncoalesced data because sqrt(a + b) == sqrt(a) + sqrt(b) does not
hold in general.
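The difference between linear and nonlinear operations on uncoalesced data can be seen on a tiny tensor. The following is an illustrative sketch that manipulates the raw values through _indices() and _values(); it rebuilds tensors by hand for demonstration and is not meant to describe how PyTorch implements these operations internally.

```python
import torch

# Uncoalesced 1-D tensor: element 1 is stored as the two terms 4 and 5 (sum 9).
s = torch.sparse_coo_tensor([[1, 1]], [4.0, 5.0], (3,))

# Linear operation: scaling every stored term scales the represented sum.
scaled = torch.sparse_coo_tensor(s._indices(), 2 * s._values(), s.shape)

# Nonlinear operation applied term by term gives sqrt(4) + sqrt(5), not sqrt(9).
termwise = torch.sparse_coo_tensor(s._indices(), s._values().sqrt(), s.shape)

# Coalescing first sums the duplicates, so the square root sees the true value 9.
c = s.coalesce()
correct = torch.sparse_coo_tensor(c.indices(), c.values().sqrt(), c.shape)

print(scaled.to_dense())    # element 1 is 2 * (4 + 5)
print(termwise.to_dense())  # element 1 is sqrt(4) + sqrt(5)
print(correct.to_dense())   # element 1 is sqrt(4 + 5)
```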
Slicing (with positive step) of a sparse COO tensor is supported only for dense dimensions. Indexing is supported for both sparse and dense dimensions:
>>> s[1]
tensor(indices=tensor([[0, 2]]),
       values=tensor([[5, 6],
                      [7, 8]]),
       size=(3, 2), nnz=2, layout=torch.sparse_coo)
>>> s[1, 0, 1]
tensor(6)
>>> s[1, 0, 1:]
tensor([6])
In PyTorch, the fill value of a sparse tensor cannot be specified
explicitly and is assumed to be zero in general. However, there exist
operations that may interpret the fill value differently. For
instance, torch.sparse.softmax() computes the softmax with the
assumption that the fill value is negative infinity.
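As a rough check of the negative-infinity interpretation, the sketch below compares torch.sparse.softmax() with a dense softmax over a tensor whose unspecified entries were filled with -inf by hand. The example tensor is arbitrary and is chosen so that every row has at least one specified element; otherwise the dense reference would produce NaN for that row.

```python
import torch

idx = torch.tensor([[0, 0, 1],
                    [0, 2, 1]])
val = torch.tensor([1.0, 2.0, 3.0])
s = torch.sparse_coo_tensor(idx, val, (2, 3)).coalesce()

# Sparse softmax along dim 1; unspecified entries stay unspecified (zero when densified).
sparse_result = torch.sparse.softmax(s, dim=1).to_dense()

# Dense reference: fill unspecified entries with -inf, then apply the usual softmax.
filled = torch.full((2, 3), float("-inf"))
filled[idx[0], idx[1]] = val
dense_result = torch.softmax(filled, dim=1)

print(sparse_result)
print(dense_result)
```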
Supported Linear Algebra operations
The following table summarizes supported Linear Algebra operations on
sparse matrices where the operands' layouts may vary. Here
T[layout] denotes a tensor with a given layout. Similarly,
M[layout] denotes a matrix (2-D PyTorch tensor), and V[layout]
denotes a vector (1-D PyTorch tensor). In addition, f denotes a
scalar (float or 0-D PyTorch tensor), * is element-wise
multiplication, and @ is matrix multiplication.
PyTorch operation | Sparse grad? | Layout signature
---|---|---
 | no | 
 | no | 
 | no | 
 | yes | 
 | no | 
 | no | 
 | no | 
 | no | 
 | yes | 
 | no | 
 | no | 
 | yes | 
 | yes | 
where the “Sparse grad?” column indicates whether the PyTorch operation supports
backward with respect to the sparse matrix argument. All PyTorch operations,
except torch.smm(), support backward with respect to strided
matrix arguments.
Note
Currently, PyTorch does not support matrix multiplication with the
layout signature M[strided] @ M[sparse_coo]. However,
applications can still compute this using the matrix relation
D @ S == (S.t() @ D.t()).t().
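The transpose relation from the note can be wrapped in a small helper. This is a sketch under the assumption that torch.sparse.mm() accepts the transposed sparse operand as its first argument; the helper name strided_at_sparse is invented for this example, and the result is compared against a dense reference product.

```python
import torch

S = torch.sparse_coo_tensor([[0, 1, 1],
                             [2, 0, 2]], [3.0, 4.0, 5.0], (2, 3))
D = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0],
                  [7.0, 8.0]])  # strided matrix of shape (4, 2)


def strided_at_sparse(D, S):
    """Compute D @ S via the relation D @ S == (S.t() @ D.t()).t()."""
    return torch.sparse.mm(S.t(), D.t()).t()


result = strided_at_sparse(D, S)
expected = D @ S.to_dense()  # dense reference, for checking only
print(result)
```

The helper keeps the sparse operand on the left of the product, which is the layout signature the table lists as supported.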
- class torch.Tensor
The following methods are specific to sparse tensors:
- is_sparse
Is True if the Tensor uses sparse storage layout, False otherwise.
- dense_dim() → int
Return the number of dense dimensions in a sparse tensor self.
Warning
Throws an error if self is not a sparse tensor.
See also Tensor.sparse_dim() and hybrid tensors.
- sparse_dim() → int
Return the number of sparse dimensions in a sparse tensor self.
Warning
Throws an error if self is not a sparse tensor.
See also Tensor.dense_dim() and hybrid tensors.
- sparse_mask(mask) → Tensor
Returns a new sparse tensor with values from a strided tensor self filtered by the indices of the sparse tensor mask. The values of the mask sparse tensor are ignored. The self and mask tensors must have the same shape.
Note
The returned sparse tensor has the same indices as the sparse tensor mask, even when the corresponding values in self are zeros.
- Parameters
mask (Tensor) – a sparse tensor whose indices are used as a filter
Example:
>>> nse = 5
>>> dims = (5, 5, 2, 2)
>>> I = torch.cat([torch.randint(0, dims[0], size=(nse,)),
...                torch.randint(0, dims[1], size=(nse,))], 0).reshape(2, nse)
>>> V = torch.randn(nse, dims[2], dims[3])
>>> S = torch.sparse_coo_tensor(I, V, dims).coalesce()
>>> D = torch.randn(dims)
>>> D.sparse_mask(S)
tensor(indices=tensor([[0, 0, 0, 2],
                       [0, 1, 4, 3]]),
       values=tensor([[[ 1.6550,  0.2397],
                       [-0.1611, -0.0779]],
                      [[ 0.2326, -1.0558],
                       [ 1.4711,  1.9678]],
                      [[-0.5138, -0.0411],
                       [ 1.9417,  0.5158]],
                      [[ 0.0793,  0.0036],
                       [-0.2569, -0.1055]]]),
       size=(5, 5, 2, 2), nnz=4, layout=torch.sparse_coo)
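Because the example above uses random data, a deterministic variant may be easier to follow. The sketch below uses small hand-picked tensors (the names D, mask and picked are chosen only for this illustration) and shows that only the positions listed in the mask survive, with the mask's own values ignored.

```python
import torch

D = torch.arange(1.0, 10.0).reshape(3, 3)  # strided tensor holding 1..9
mask = torch.sparse_coo_tensor([[0, 1, 2],
                                [2, 0, 1]], [1.0, 1.0, 1.0], (3, 3)).coalesce()

# Keep only the entries of D at the coordinates specified in mask.
picked = D.sparse_mask(mask)
print(picked.to_dense())
```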
- sparse_resize_(size, sparse_dim, dense_dim) → Tensor
Resizes self sparse tensor to the desired size and the number of sparse and dense dimensions.
Note
If the number of specified elements in self is zero, then size, sparse_dim, and dense_dim can be any size and positive integers such that len(size) == sparse_dim + dense_dim.
If self specifies one or more elements, however, then each dimension in size must not be smaller than the corresponding dimension of self, sparse_dim must equal the number of sparse dimensions in self, and dense_dim must equal the number of dense dimensions in self.
Warning
Throws an error if self is not a sparse tensor.
- sparse_resize_and_clear_(size, sparse_dim, dense_dim) → Tensor
Removes all specified elements from a sparse tensor self and resizes self to the desired size and the number of sparse and dense dimensions.
- to_dense() → Tensor
Creates a strided copy of self.
Warning
Throws an error if self is a strided tensor.
Example:
>>> s = torch.sparse_coo_tensor(
...        torch.tensor([[1, 1],
...                      [0, 2]]),
...        torch.tensor([9, 10]),
...        size=(3, 3))
>>> s.to_dense()
tensor([[ 0,  0,  0],
        [ 9,  0, 10],
        [ 0,  0,  0]])
- to_sparse(sparseDims) → Tensor
Returns a sparse copy of the tensor. PyTorch supports sparse tensors in coordinate format.
- Parameters
sparseDims (int, optional) – the number of sparse dimensions to include in the new sparse tensor
Example:
>>> d = torch.tensor([[0, 0, 0], [9, 0, 10], [0, 0, 0]])
>>> d
tensor([[ 0,  0,  0],
        [ 9,  0, 10],
        [ 0,  0,  0]])
>>> d.to_sparse()
tensor(indices=tensor([[1, 1],
                       [0, 2]]),
       values=tensor([ 9, 10]),
       size=(3, 3), nnz=2, layout=torch.sparse_coo)
>>> d.to_sparse(1)
tensor(indices=tensor([[1]]),
       values=tensor([[ 9,  0, 10]]),
       size=(3, 3), nnz=1, layout=torch.sparse_coo)
- coalesce() → Tensor
Returns a coalesced copy of self if self is an uncoalesced tensor.
Returns self if self is a coalesced tensor.
Warning
Throws an error if self is not a sparse COO tensor.
- is_coalesced() → bool
Returns True if self is a sparse COO tensor that is coalesced, False otherwise.
Warning
Throws an error if self is not a sparse COO tensor.
See coalesce() and uncoalesced tensors.
- indices() → Tensor
Return the indices tensor of a sparse COO tensor.
Warning
Throws an error if self is not a sparse COO tensor.
See also Tensor.values().
Note
This method can only be called on a coalesced sparse tensor. See Tensor.coalesce() for details.
- values() → Tensor
Return the values tensor of a sparse COO tensor.
Warning
Throws an error if self is not a sparse COO tensor.
See also Tensor.indices().
Note
This method can only be called on a coalesced sparse tensor. See Tensor.coalesce() for details.
The following torch.Tensor methods support sparse COO tensors:
add()
add_()
addmm()
addmm_()
any()
asin()
asin_()
arcsin()
arcsin_()
bmm()
clone()
deg2rad()
deg2rad_()
detach()
detach_()
dim()
div()
div_()
floor_divide()
floor_divide_()
get_device()
index_select()
isnan()
log1p()
log1p_()
mm()
mul()
mul_()
mv()
narrow_copy()
neg()
neg_()
negative()
negative_()
numel()
rad2deg()
rad2deg_()
resize_as_()
size()
pow()
sqrt()
square()
smm()
sspaddmm()
sub()
sub_()
t()
t_()
transpose()
transpose_()
zero_()
Sparse tensor functions
- torch.sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) → Tensor
Constructs a sparse tensor in COO(rdinate) format with specified values at the given indices.
Note
This function returns an uncoalesced tensor.
- Parameters
indices (array_like) – Initial data for the tensor. Can be a list, tuple, NumPy ndarray, scalar, and other types. Will be cast to a torch.LongTensor internally. The indices are the coordinates of the non-zero values in the matrix, and thus should be two-dimensional where the first dimension is the number of tensor dimensions and the second dimension is the number of non-zero values.
values (array_like) – Initial values for the tensor. Can be a list, tuple, NumPy ndarray, scalar, and other types.
size (list, tuple, or torch.Size, optional) – Size of the sparse tensor. If not provided the size will be inferred as the minimum size big enough to hold all non-zero elements.
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of returned tensor. Default: if None, infers data type from values.
device (torch.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type (see torch.set_default_tensor_type()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> i = torch.tensor([[0, 1, 1],
...                   [2, 0, 2]])
>>> v = torch.tensor([3, 4, 5], dtype=torch.float32)
>>> torch.sparse_coo_tensor(i, v, [2, 4])
tensor(indices=tensor([[0, 1, 1],
                       [2, 0, 2]]),
       values=tensor([3., 4., 5.]),
       size=(2, 4), nnz=3, layout=torch.sparse_coo)
>>> torch.sparse_coo_tensor(i, v)  # Shape inference
tensor(indices=tensor([[0, 1, 1],
                       [2, 0, 2]]),
       values=tensor([3., 4., 5.]),
       size=(2, 3), nnz=3, layout=torch.sparse_coo)
>>> torch.sparse_coo_tensor(i, v, [2, 4],
...                         dtype=torch.float64,
...                         device=torch.device('cuda:0'))
tensor(indices=tensor([[0, 1, 1],
                       [2, 0, 2]]),
       values=tensor([3., 4., 5.]),
       device='cuda:0', size=(2, 4), nnz=3, dtype=torch.float64,
       layout=torch.sparse_coo)
# Create an empty sparse tensor with the following invariants:
#   1. sparse_dim + dense_dim = len(SparseTensor.shape)
#   2. SparseTensor._indices().shape = (sparse_dim, nnz)
#   3. SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:])
#
# For instance, to create an empty sparse tensor with nnz = 0, dense_dim = 0 and
# sparse_dim = 1 (hence indices is a 2D tensor of shape = (1, 0))
>>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), [], [1])
tensor(indices=tensor([], size=(1, 0)),
       values=tensor([], size=(0,)),
       size=(1,), nnz=0, layout=torch.sparse_coo)
# and to create an empty sparse tensor with nnz = 0, dense_dim = 1 and
# sparse_dim = 1
>>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), torch.empty([0, 2]), [1, 2])
tensor(indices=tensor([], size=(1, 0)),
       values=tensor([], size=(0, 2)),
       size=(1, 2), nnz=0, layout=torch.sparse_coo)
- torch.sparse.sum(input, dim=None, dtype=None)
Returns the sum of each row of the sparse tensor input in the given dimensions dim. If dim is a list of dimensions, reduce over all of them. When summing over all sparse_dim, this method returns a dense tensor instead of a sparse tensor.
All summed dim are squeezed (see torch.squeeze()), resulting in an output tensor having dim fewer dimensions than input.
During backward, only gradients at nnz locations of input will propagate back. Note that the gradient of input is coalesced.
- Parameters
Example:
>>> nnz = 3
>>> dims = [5, 5, 2, 3]
>>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
...                torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
>>> V = torch.randn(nnz, dims[2], dims[3])
>>> size = torch.Size(dims)
>>> S = torch.sparse_coo_tensor(I, V, size)
>>> S
tensor(indices=tensor([[2, 0, 3],
                       [2, 4, 1]]),
       values=tensor([[[-0.6438, -1.6467,  1.4004],
                       [ 0.3411,  0.0918, -0.2312]],
                      [[ 0.5348,  0.0634, -2.0494],
                       [-0.7125, -1.0646,  2.1844]],
                      [[ 0.1276,  0.1874, -0.6334],
                       [-1.9682, -0.5340,  0.7483]]]),
       size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)
# when sum over only part of sparse_dims, return a sparse tensor
>>> torch.sparse.sum(S, [1, 3])
tensor(indices=tensor([[0, 2, 3]]),
       values=tensor([[-1.4512,  0.4073],
                      [-0.8901,  0.2017],
                      [-0.3183, -1.7539]]),
       size=(5, 2), nnz=3, layout=torch.sparse_coo)
# when sum over all sparse dim, return a dense tensor
# with summed dims squeezed
>>> torch.sparse.sum(S, [0, 1, 3])
tensor([-2.6596, -1.1450])
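A deterministic counterpart to the random example may help to see when the result stays sparse and when it becomes dense. The following is a minimal sketch on a 2-D tensor with hand-picked values; the variable names row_sums and total are chosen for this illustration only.

```python
import torch

S = torch.sparse_coo_tensor([[0, 1, 1],
                             [2, 0, 2]], [3.0, 4.0, 5.0], (2, 3))

# Reducing only one of the two sparse dimensions keeps the result sparse.
row_sums = torch.sparse.sum(S, dim=[1])

# Reducing every sparse dimension returns a dense (strided) tensor.
total = torch.sparse.sum(S, dim=[0, 1])

print(row_sums.to_dense())  # per-row sums of the specified elements
print(total)
```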
- torch.sparse.addmm(mat, mat1, mat2, beta=1.0, alpha=1.0)
This function does the exact same thing as torch.addmm() in the forward, except that it supports backward for the sparse matrix mat1. mat1 needs to have sparse_dim = 2. Note that the gradient of mat1 is a coalesced sparse tensor.
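The sketch below illustrates both claims on small hand-picked matrices: the forward result is compared with the dense expression beta * mat + alpha * (mat1 @ mat2), which is what torch.addmm() computes, and the gradient with respect to the sparse operand is inspected after a backward pass. The specific values of beta and alpha are arbitrary.

```python
import torch

mat = torch.ones(2, 2)
mat1 = torch.sparse_coo_tensor([[0, 1, 1],
                                [2, 0, 2]], [3.0, 4.0, 5.0], (2, 3),
                               requires_grad=True)
mat2 = torch.arange(6.0).reshape(3, 2)

out = torch.sparse.addmm(mat, mat1, mat2, beta=0.5, alpha=2.0)

# Dense reference for the forward pass (no gradient tracking needed here).
with torch.no_grad():
    reference = 0.5 * mat + 2.0 * (mat1.to_dense() @ mat2)

# Backward with respect to the sparse matrix argument.
out.sum().backward()
print(out)
print(mat1.grad)
```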
- torch.sparse.mm(mat1, mat2)
Performs a matrix multiplication of the sparse matrix mat1 and the (sparse or strided) matrix mat2. Similar to torch.mm(), if mat1 is a (n x m) tensor and mat2 is a (m x p) tensor, out will be a (n x p) tensor. mat1 needs to have sparse_dim = 2. This function also supports backward for both matrices. Note that the gradient of mat1 is a coalesced sparse tensor.
- Parameters
mat1 (SparseTensor) – the first sparse matrix to be multiplied
mat2 (Tensor) – the second matrix to be multiplied, which could be sparse or dense
- Shape:
The format of the output tensor of this function follows:
sparse x sparse -> sparse
sparse x dense -> dense
Example:
>>> a = torch.randn(2, 3).to_sparse().requires_grad_(True)
>>> a
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                       [0, 1, 2, 0, 1, 2]]),
       values=tensor([ 1.5901,  0.0183, -0.6146,  1.8061, -0.0112,  0.6302]),
       size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)
>>> b = torch.randn(3, 2, requires_grad=True)
>>> b
tensor([[-0.6479,  0.7874],
        [-1.2056,  0.5641],
        [-1.1716, -0.9923]], requires_grad=True)
>>> y = torch.sparse.mm(a, b)
>>> y
tensor([[-0.3323,  1.8723],
        [-1.8951,  0.7904]], grad_fn=<SparseAddmmBackward>)
>>> y.sum().backward()
>>> a.grad
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                       [0, 1, 2, 0, 1, 2]]),
       values=tensor([ 0.1394, -0.6415, -2.1639,  0.1394, -0.6415, -2.1639]),
       size=(2, 3), nnz=6, layout=torch.sparse_coo)
- torch.sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) → Tensor
Matrix multiplies a sparse tensor mat1 with a dense tensor mat2, then adds the sparse tensor input to the result.
Note: This function is equivalent to torch.addmm(), except input and mat1 are sparse.
- Parameters
- Keyword Arguments
beta (Number, optional) – multiplier for input (β)
alpha (Number, optional) – multiplier for mat1 @ mat2 (α)
out (Tensor, optional) – the output tensor.
- torch.hspmm(mat1, mat2, *, out=None) → Tensor
Performs a matrix multiplication of a sparse COO matrix mat1 and a strided matrix mat2. The result is a (1 + 1)-dimensional hybrid COO matrix.
- torch.smm(input, mat) → Tensor
Performs a matrix multiplication of the sparse matrix input with the dense matrix mat.
- torch.sparse.softmax(input, dim, dtype=None)
Applies a softmax function.
Softmax is defined as:
$\text{Softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$
where $i, j$ run over sparse tensor indices and unspecified entries are ignored. This is equivalent to defining unspecified entries as negative infinity so that $\exp(x_k) = 0$ when the entry with index $k$ is not specified.
It is applied to all slices along dim, and will re-scale them so that the elements lie in the range [0, 1] and sum to 1.
- Parameters
input (Tensor) – input
dim (int) – A dimension along which softmax will be computed.
dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None
- torch.sparse.log_softmax(input, dim, dtype=None)
Applies a softmax function followed by logarithm.
See softmax for more details.
- Parameters
input (Tensor) – input
dim (int) – A dimension along which softmax will be computed.
dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None
Other functions
The following torch functions support sparse COO tensors:
cat()
dstack()
empty()
empty_like()
hstack()
index_select()
is_complex()
is_floating_point()
is_nonzero()
is_same_size()
is_signed()
is_tensor()
lobpcg()
mm()
native_norm()
pca_lowrank()
select()
stack()
svd_lowrank()
unsqueeze()
vstack()
zeros()
zeros_like()