torch.use_deterministic_algorithms
torch.use_deterministic_algorithms(d)[source]
Sets whether PyTorch operations must use "deterministic" algorithms. That is, algorithms which, given the same input, and when run on the same software and hardware, always produce the same output. When True, operations will use deterministic algorithms when available, and if only nondeterministic algorithms are available, they will throw a RuntimeError when called.
Warning
This feature is in beta, and its design and implementation may change in the future.
The following normally-nondeterministic operations will act deterministically when d=True:
- torch.nn.Conv1d when called on CUDA tensor
- torch.nn.Conv2d when called on CUDA tensor
- torch.nn.Conv3d when called on CUDA tensor
- torch.nn.ConvTranspose1d when called on CUDA tensor
- torch.nn.ConvTranspose2d when called on CUDA tensor
- torch.nn.ConvTranspose3d when called on CUDA tensor
- torch.bmm() when called on sparse-dense CUDA tensors
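For example, a minimal sketch of enabling the flag and running one of the convolutions above on CUDA (this assumes a CUDA-capable device is available; reproducibility across separate runs additionally requires fixed random seeds and consistent software versions):

    import torch

    torch.use_deterministic_algorithms(True)

    if torch.cuda.is_available():
        conv = torch.nn.Conv2d(3, 8, kernel_size=3).cuda()
        x = torch.randn(1, 3, 32, 32, device="cuda")
        # With the flag enabled, repeated forward passes over the same input
        # select deterministic kernels and produce identical outputs.
        assert torch.equal(conv(x), conv(x))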
The following normally-nondeterministic operations will throw a RuntimeError when d=True:
- torch.nn.AvgPool3d when called on a CUDA tensor that requires grad
- torch.nn.AdaptiveAvgPool2d when called on a CUDA tensor that requires grad
- torch.nn.AdaptiveAvgPool3d when called on a CUDA tensor that requires grad
- torch.nn.MaxPool3d when called on a CUDA tensor that requires grad
- torch.nn.AdaptiveMaxPool2d when called on a CUDA tensor that requires grad
- torch.nn.FractionalMaxPool2d when called on a CUDA tensor that requires grad
- torch.nn.FractionalMaxPool3d when called on a CUDA tensor that requires grad
- torch.nn.functional.interpolate() when called on a CUDA tensor that requires grad and one of the following modes is used: linear, bilinear, bicubic, trilinear
- torch.nn.ReflectionPad1d when called on a CUDA tensor that requires grad
- torch.nn.ReflectionPad2d when called on a CUDA tensor that requires grad
- torch.nn.ReplicationPad1d when called on a CUDA tensor that requires grad
- torch.nn.ReplicationPad2d when called on a CUDA tensor that requires grad
- torch.nn.ReplicationPad3d when called on a CUDA tensor that requires grad
- torch.nn.NLLLoss when called on a CUDA tensor that requires grad
- torch.nn.CTCLoss when called on a CUDA tensor that requires grad
- torch.nn.EmbeddingBag when called on a CUDA tensor that requires grad
- torch.scatter_add_() when called on a CUDA tensor
- torch.index_add_() when called on a CUDA tensor
- torch.index_copy()
- torch.index_select() when called on a CUDA tensor that requires grad
- torch.repeat_interleave() when called on a CUDA tensor that requires grad
- torch.histc() when called on a CUDA tensor
- torch.bincount() when called on a CUDA tensor
- torch.kthvalue() when called on a CUDA tensor
- torch.median() with indices output when called on a CUDA tensor
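As a sketch of the failure mode, calling one of these operations with the flag enabled raises a RuntimeError rather than silently falling back to a nondeterministic kernel (torch.bincount() on a CUDA tensor is used here as the example operation; a CUDA device is assumed):

    import torch

    torch.use_deterministic_algorithms(True)

    if torch.cuda.is_available():
        x = torch.randint(0, 10, (100,), device="cuda")
        try:
            torch.bincount(x)  # listed above as nondeterministic on CUDA
        except RuntimeError as err:
            print("nondeterministic op rejected:", err)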
A handful of CUDA operations are nondeterministic if the CUDA version is 10.2 or greater, unless the environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8 is set. See the CUDA documentation for more details: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
If one of these environment variable configurations is not set, a RuntimeError will be raised from these operations when called with CUDA tensors.
Note that deterministic operations tend to have worse performance than nondeterministic operations.
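One way to apply this configuration from Python is sketched below; note that the variable must be set before the first cuBLAS call in the process, so exporting it in the shell that launches the script is the more robust option:

    import os

    # Set the cuBLAS workspace size before torch initializes CUDA/cuBLAS.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    import torch

    torch.use_deterministic_algorithms(True)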
Parameters
    d (bool) – If True, force operations to be deterministic. If False, allow non-deterministic operations.
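A minimal usage sketch, toggling the flag around a reproducibility-sensitive region (the previous setting is tracked manually here, since the name of the function that queries the current state varies between releases):

    import torch

    torch.use_deterministic_algorithms(True)   # force deterministic algorithms
    # ... run the code that must be reproducible ...
    torch.use_deterministic_algorithms(False)  # allow faster nondeterministic algorithms again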