torch.cuda

This package adds support for CUDA tensor types, which implement the same functions as CPU tensors but use GPUs for computation.

It is lazily initialized, so you can always import it, and use is_available() to determine if your system supports CUDA.

CUDA semantics has more details about working with CUDA.
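As a minimal sketch of the lazy-initialization point above, the snippet below picks a device with is_available() and falls back to the CPU. The helper name pick_device is hypothetical and not part of torch.cuda.

```python
import torch


def pick_device() -> torch.device:
    """Return the current CUDA device if CUDA is usable, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")


device = pick_device()
x = torch.ones(2, 3, device=device)  # allocated on whichever device was picked
total = (x * 2).sum().item()         # six elements, each equal to 2
print(device, total)
```

Because the check happens at run time, the same script works on machines with and without a GPU.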

torch.cuda.can_device_access_peer(device, peer_device)[source]

Checks if peer access between two devices is possible.

torch.cuda.current_blas_handle()[source]

Returns a cublasHandle_t pointer to the current cuBLAS handle.

torch.cuda.current_device()[source]

Returns the index of the currently selected device.

torch.cuda.current_stream(device=None)[source]

Returns the currently selected Stream for a given device.

Parameters

device (torch.device or int, optional) – selected device. Returns the currently selected Stream for the current device, given by current_device(), if device is None (default).

torch.cuda.default_stream(device=None)[source]

Returns the default Stream for a given device.

Parameters

device (torch.device or int, optional) – selected device. Returns the default Stream for the current device, given by current_device(), if device is None (default).

class torch.cuda.device(device)[source]

Context-manager that changes the selected device.

Parameters

device (torch.device or int) – device index to select. It’s a no-op if this argument is a negative integer or None.
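A short sketch of how the context manager might be used: a tensor created with a bare "cuda" device inside the block lands on the selected GPU. The helper allocate_on and its None fallback are assumptions made for this illustration.

```python
import torch


def allocate_on(index: int):
    """Allocate a tensor while GPU `index` is the selected device.

    Returns None when CUDA is unavailable (hypothetical fallback for this sketch).
    """
    if not torch.cuda.is_available():
        return None
    with torch.cuda.device(index):
        # A bare "cuda" device resolves to the device selected by the context.
        return torch.zeros(4, device="cuda")


t = allocate_on(0)
print(None if t is None else t.device)
```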

torch.cuda.device_count()[source]

Returns the number of GPUs available.

class torch.cuda.device_of(obj)[source]

Context-manager that changes the current device to that of the given object.

You can use both tensors and storages as arguments. If a given object is not allocated on a GPU, this is a no-op.

Parameters

obj (Tensor or Storage) – object allocated on the selected device.

torch.cuda.get_arch_list()[source]

Returns the list of CUDA architectures this library was compiled for.

torch.cuda.get_device_capability(device=None)[source]

Gets the CUDA capability of a device.

Parameters

device (torch.device or int, optional) – device for which to return the device capability. This function is a no-op if this argument is a negative integer. It uses the current device, given by current_device(), if device is None (default).

Returns

the major and minor CUDA capability of the device

Return type

tuple(int, int)

torch.cuda.get_device_name(device=None)[source]

Gets the name of a device.

Parameters

device (torch.device or int, optional) – device for which to return the name. This function is a no-op if this argument is a negative integer. It uses the current device, given by current_device(), if device is None (default).

Returns

the name of the device

Return type

str

torch.cuda.get_device_properties(device)[source]

Gets the properties of a device.

Parameters

device (torch.device or int or str) – device for which to return the properties of the device.

Returns

the properties of the device

Return type

_CudaDeviceProperties
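The device query functions above can be combined to list every visible GPU. This is a sketch only: describe_devices and the dictionary keys are hypothetical names, and total_memory is assumed to be exposed on the returned properties object.

```python
import torch


def describe_devices() -> list:
    """Return one summary dict per visible GPU (an empty list without CUDA)."""
    summaries = []
    if not torch.cuda.is_available():
        return summaries
    for index in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(index)
        major, minor = torch.cuda.get_device_capability(index)
        summaries.append({
            "index": index,
            "name": torch.cuda.get_device_name(index),
            "capability": f"{major}.{minor}",
            "total_memory_bytes": props.total_memory,
        })
    return summaries


devices = describe_devices()
for entry in devices:
    print(entry)
```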

torch.cuda.get_gencode_flags()[source]

Returns the NVCC gencode flags this library was compiled with.

torch.cuda.init()[source]

Initialize PyTorch’s CUDA state. You may need to call this explicitly if you are interacting with PyTorch via its C API, as Python bindings for CUDA functionality will not be available until this initialization takes place. Ordinary users should not need this, as all of PyTorch’s CUDA methods automatically initialize CUDA state on-demand.

Does nothing if the CUDA state is already initialized.

torch.cuda.ipc_collect()[source]

Force collects GPU memory after it has been released by CUDA IPC.

Note

Checks if any sent CUDA tensors could be cleaned from the memory. Force closes the shared memory file used for reference counting if there are no active counters. Useful when the producer process has stopped actively sending tensors and wants to release unused memory.

torch.cuda.is_available()[source]

Returns a bool indicating if CUDA is currently available.

torch.cuda.is_initialized()[source]

Returns whether PyTorch’s CUDA state has been initialized.

torch.cuda.set_device(device)[source]

Sets the current device.

Usage of this function is discouraged in favor of device. In most cases it’s better to use the CUDA_VISIBLE_DEVICES environment variable.

Parameters

device (torch.device or int) – selected device. This function is a no-op if this argument is negative.

torch.cuda.stream(stream)[source]

Context-manager that selects a given stream.

All CUDA kernels queued within its context will be enqueued on the selected stream.

Parameters

stream (Stream) – selected stream. This manager is a no-op if it’s None.

Note

Streams are per-device. If the selected stream is not on the current device, this function will also change the current device to match the stream.
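The following sketch shows one way to run work on a side stream and hand the result back to the current stream. The function name matmul_on_side_stream is hypothetical. It relies on Stream.wait_stream() for ordering and on Tensor.record_stream(), which is documented outside this page, to inform the caching allocator.

```python
import torch


def matmul_on_side_stream(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute a @ b on a separate stream when the inputs live on a GPU."""
    if not a.is_cuda:
        return a @ b  # CPU fallback: streams do not apply
    current = torch.cuda.current_stream(a.device)
    side = torch.cuda.Stream(device=a.device)
    # The inputs were produced on the current stream, so the side stream waits for it.
    side.wait_stream(current)
    with torch.cuda.stream(side):
        out = a @ b  # enqueued on `side`
    # Later work on the current stream must not read `out` before `side` finishes.
    current.wait_stream(side)
    # Tell the caching allocator that `out` is also used on the current stream.
    out.record_stream(current)
    return out


dev = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.eye(3, device=dev)
b = torch.arange(9.0, device=dev).reshape(3, 3)
result = matmul_on_side_stream(a, b)
print(result)
```

Neither wait_stream() call blocks the CPU; both only order work between the two streams.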

torch.cuda.synchronize(device=None)[source]

Waits for all kernels in all streams on a CUDA device to complete.

Parameters

device (torch.device or int, optional) – device for which to synchronize. It uses the current device, given by current_device(), if device is None (default).

Random Number Generator

torch.cuda.get_rng_state(device='cuda')[source]

Returns the random number generator state of the specified GPU as a ByteTensor.

Parameters

device (torch.device or int, optional) – The device to return the RNG state of. Default: 'cuda' (i.e., torch.device('cuda'), the current CUDA device).

Warning

This function eagerly initializes CUDA.

torch.cuda.get_rng_state_all()[source]

Returns a list of ByteTensors representing the random number generator states of all devices.

torch.cuda.set_rng_state(new_state, device='cuda')[source]

Sets the random number generator state of the specified GPU.

Parameters
  • new_state (torch.ByteTensor) – The desired state

  • device (torch.device or int, optional) – The device for which to set the RNG state. Default: 'cuda' (i.e., torch.device('cuda'), the current CUDA device).

torch.cuda.set_rng_state_all(new_states)[source]

Sets the random number generator state of all devices.

Parameters

new_states (Iterable of torch.ByteTensor) – The desired state for each device
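A minimal sketch of saving and restoring generator state with get_rng_state() and set_rng_state(): restoring the snapshot should make the next draw repeat the previous one. The helper replay_random and its None fallback are hypothetical.

```python
import torch


def replay_random(device_index: int = 0):
    """Draw the same random numbers twice by restoring the GPU RNG state.

    Returns None when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return None
    target = f"cuda:{device_index}"
    state = torch.cuda.get_rng_state(device_index)  # ByteTensor snapshot
    first = torch.rand(3, device=target)
    torch.cuda.set_rng_state(state, device_index)   # rewind the generator
    second = torch.rand(3, device=target)
    return first, second


pair = replay_random()
print(pair)
```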

torch.cuda.manual_seed(seed)[source]

Sets the seed for generating random numbers for the current GPU. It’s safe to call this function if CUDA is not available; in that case, it is silently ignored.

Parameters

seed (int) – The desired seed.

Warning

If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all().

torch.cuda.manual_seed_all(seed)[source]

Sets the seed for generating random numbers on all GPUs. It’s safe to call this function if CUDA is not available; in that case, it is silently ignored.

Parameters

seed (int) – The desired seed.
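As an illustration of seeding, the sketch below seeds every GPU with manual_seed_all() and checks that two draws with the same seed match. The helper seeded_sample is hypothetical; it also calls torch.manual_seed() so the CPU fallback path is reproducible.

```python
import torch


def seeded_sample(seed: int, device: str) -> torch.Tensor:
    """Seed the generators, then draw three uniform random numbers on `device`."""
    torch.manual_seed(seed)           # CPU generator, used by the fallback path
    torch.cuda.manual_seed_all(seed)  # every GPU; silently ignored without CUDA
    return torch.rand(3, device=device)


dev = "cuda" if torch.cuda.is_available() else "cpu"
first = seeded_sample(1234, dev)
second = seeded_sample(1234, dev)
print(torch.equal(first, second))
```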

torch.cuda.seed()[source]

Sets the seed for generating random numbers to a random number for the current GPU. It’s safe to call this function if CUDA is not available; in that case, it is silently ignored.

Warning

If you are working with a multi-GPU model, this function will only initialize the seed on one GPU. To initialize all GPUs, use seed_all().

torch.cuda.seed_all()[source]

Sets the seed for generating random numbers to a random number on all GPUs. It’s safe to call this function if CUDA is not available; in that case, it is silently ignored.

torch.cuda.initial_seed()[source]

Returns the current random seed of the current GPU.

Warning

This function eagerly initializes CUDA.

Communication collectives

torch.cuda.comm.broadcast(tensor, devices=None, *, out=None)[source]

Broadcasts a tensor to specified GPU devices.

Parameters
  • tensor (Tensor) – tensor to broadcast. Can be on CPU or GPU.

  • devices (Iterable[torch.device, str or int], optional) – an iterable of GPU devices, among which to broadcast.

  • out (Sequence[Tensor], optional, keyword-only) – the GPU tensors to store output results.

Note

Exactly one of devices and out must be specified.

Returns

  • If devices is specified,

    a tuple containing copies of tensor, placed on devices.

  • If out is specified,

    a tuple containing out tensors, each containing a copy of tensor.

torch.cuda.comm.broadcast_coalesced(tensors, devices, buffer_size=10485760)[source]

Broadcasts a sequence of tensors to the specified GPUs. Small tensors are first coalesced into a buffer to reduce the number of synchronizations.

Parameters
  • tensors (sequence) – tensors to broadcast. Must be on the same device, either CPU or GPU.

  • devices (Iterable[torch.device, str or int]) – an iterable of GPU devices, among which to broadcast.

  • buffer_size (int) – maximum size of the buffer used for coalescing

Returns

A tuple containing copies of tensor, placed on devices.

torch.cuda.comm.reduce_add(inputs, destination=None)[source]

Sums tensors from multiple GPUs.

All inputs should have matching shapes, dtype, and layout. The output tensor will be of the same shape, dtype, and layout.

Parameters
  • inputs (Iterable[Tensor]) – an iterable of tensors to add.

  • destination (int, optional) – a device on which the output will be placed (default: current device).

Returns

A tensor containing an elementwise sum of all inputs, placed on the destination device.

torch.cuda.comm.scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None)[source]

Scatters tensor across multiple GPUs.

Parameters
  • tensor (Tensor) – tensor to scatter. Can be on CPU or GPU.

  • devices (Iterable[torch.device, str or int], optional) – an iterable of GPU devices, among which to scatter.

  • chunk_sizes (Iterable[int], optional) – sizes of chunks to be placed on each device. It should match devices in length and sum to tensor.size(dim). If not specified, tensor will be divided into equal chunks.

  • dim (int, optional) – A dimension along which to chunk tensor. Default: 0.

  • streams (Iterable[Stream], optional) – an iterable of Streams, among which to execute the scatter. If not specified, the default stream will be utilized.

  • out (Sequence[Tensor], optional, keyword-only) – the GPU tensors to store output results. Sizes of these tensors must match that of tensor, except for dim, where the total size must sum to tensor.size(dim).

Note

Exactly one of devices and out must be specified. When out is specified, chunk_sizes must not be specified and will be inferred from sizes of out.

Returns

  • If devices is specified,

    a tuple containing chunks of tensor, placed on devices.

  • If out is specified,

    a tuple containing out tensors, each containing a chunk of tensor.

torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)[source]

Gathers tensors from multiple GPU devices.

Parameters
  • tensors (Iterable[Tensor]) – an iterable of tensors to gather. Tensor sizes in all dimensions other than dim have to match.

  • dim (int, optional) – a dimension along which the tensors will be concatenated. Default: 0.

  • destination (torch.device, str, or int, optional) – the output device. Can be CPU or CUDA. Default: the current CUDA device.

  • out (Tensor, optional, keyword-only) – the tensor to store gather result. Its sizes must match those of tensors, except for dim, where the size must equal sum(tensor.size(dim) for tensor in tensors). Can be on CPU or CUDA.

Note

destination must not be specified when out is specified.

Returns

  • If destination is specified,

    a tensor located on destination device, that is a result of concatenating tensors along dim.

  • If out is specified,

    the out tensor, now containing results of concatenating tensors along dim.
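A sketch of a scatter and gather round trip, assuming the signatures listed above. The helper scatter_then_gather is hypothetical. It sizes the input as a multiple of the GPU count so the chunks are equal, and it skips the collectives when CUDA is unavailable.

```python
import torch
import torch.cuda.comm as comm


def scatter_then_gather(rows_per_device: int = 2):
    """Split a CPU tensor across all GPUs along dim 0, then gather it back to the CPU."""
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    rows = max(n_gpus, 1) * rows_per_device
    data = torch.arange(float(rows * 3)).reshape(rows, 3)
    if n_gpus == 0:
        return data, data.clone()  # nothing to scatter to
    # One equally sized chunk per device, because rows is a multiple of n_gpus.
    chunks = comm.scatter(data, devices=list(range(n_gpus)), dim=0)
    gathered = comm.gather(chunks, dim=0, destination=torch.device("cpu"))
    return data, gathered


original, round_trip = scatter_then_gather()
print(torch.equal(original, round_trip))
```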

Streams and events

class torch.cuda.Stream(device=None, priority=0, **kwargs)[source]

Wrapper around a CUDA stream.

A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See CUDA semantics for details.

Parameters
  • device (torch.device or int, optional) – a device on which to allocate the stream. If device is None (default) or a negative integer, this will use the current device.

  • priority (int, optional) – priority of the stream. Can be either -1 (high priority) or 0 (low priority). By default, streams have priority 0.

Note

Although CUDA versions >= 11 support more than two levels of priorities, in PyTorch, we only support two levels of priorities.

query()[source]

Checks if all the work submitted has been completed.

Returns

A boolean indicating if all kernels in this stream are completed.

record_event(event=None)[source]

Records an event.

Parameters

event (Event, optional) – event to record. If not given, a new one will be allocated.

Returns

Recorded event.

synchronize()[source]

Wait for all the kernels in this stream to complete.

Note

This is a wrapper around cudaStreamSynchronize(): see CUDA Stream documentation for more info.

wait_event(event)[source]

Makes all future work submitted to the stream wait for an event.

Parameters

event (Event) – an event to wait for.

Note

This is a wrapper around cudaStreamWaitEvent(): see CUDA Stream documentation for more info.

This function returns without waiting for event: only future operations are affected.

wait_stream(stream)[source]

Synchronizes with another stream.

All future work submitted to this stream will wait until all kernels submitted to a given stream at the time of call complete.

Parameters

stream (Stream) – a stream to synchronize.

Note

This function returns without waiting for currently enqueued kernels in stream: only future operations are affected.

class torch.cuda.Event(enable_timing=False, blocking=False, interprocess=False)[source]

Wrapper around a CUDA event.

CUDA events are synchronization markers that can be used to monitor the device’s progress, to accurately measure timing, and to synchronize CUDA streams.

The underlying CUDA events are lazily initialized when the event is first recorded or exported to another process. After creation, only streams on the same device may record the event. However, streams on any device can wait on the event.

Parameters
  • enable_timing (bool, optional) – indicates if the event should measure time (default: False)

  • blocking (bool, optional) – if True, wait() will be blocking (default: False)

  • interprocess (bool) – if True, the event can be shared between processes (default: False)

elapsed_time(end_event)[source]

Returns the time elapsed in milliseconds after the event was recorded and before the end_event was recorded.

classmethod from_ipc_handle(device, handle)[source]

Reconstruct an event from an IPC handle on the given device.

ipc_handle()[source]

Returns an IPC handle of this event. If not recorded yet, the event will use the current device.

query()[source]

Checks if all work currently captured by event has completed.

Returns

A boolean indicating if all work currently captured by event has completed.

record(stream=None)[source]

Records the event in a given stream.

Uses torch.cuda.current_stream() if no stream is specified. The stream’s device must match the event’s device.

synchronize()[source]

Waits for the event to complete.

Waits until the completion of all work currently captured in this event. This prevents the CPU thread from proceeding until the event completes.

Note

This is a wrapper around cudaEventSynchronize(): see CUDA Event documentation for more info.

wait(stream=None)[source]

Makes all future work submitted to the given stream wait for this event.

Uses torch.cuda.current_stream() if no stream is specified.
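The Event methods above can be combined to time GPU work. The sketch below records two events around a matrix multiplication and reads elapsed_time() after synchronizing on the second event. The helper time_matmul_ms and its CPU fallback based on time.perf_counter() are assumptions made for this example.

```python
import time

import torch


def time_matmul_ms(size: int = 256) -> float:
    """Return the time in milliseconds taken by one size x size matmul."""
    if not torch.cuda.is_available():
        a = torch.randn(size, size)
        begin = time.perf_counter()
        _ = a @ a
        return (time.perf_counter() - begin) * 1000.0
    a = torch.randn(size, size, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()     # recorded on the current stream
    _ = a @ a
    end.record()
    end.synchronize()  # block the CPU until `end` has completed
    return start.elapsed_time(end)


elapsed_ms = time_matmul_ms()
print(f"{elapsed_ms:.3f} ms")
```

Both events need enable_timing=True, since the default is False.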

Memory management

torch.cuda.empty_cache()[source]

Releases all unoccupied cached memory currently held by the caching allocator so that it can be used by other GPU applications and is visible in nvidia-smi.

Note

empty_cache() doesn’t increase the amount of GPU memory available for PyTorch. However, it may help reduce fragmentation of GPU memory in certain cases. See Memory management for more details about GPU memory management.
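To illustrate the effect, this sketch compares memory_reserved() before and after empty_cache(). The helper release_cached_bytes is hypothetical, and the number of bytes actually returned depends on the allocator state.

```python
import torch


def release_cached_bytes() -> int:
    """Return how many reserved bytes empty_cache() handed back (0 without CUDA)."""
    if not torch.cuda.is_available():
        return 0
    x = torch.empty(1024, 1024, device="cuda")
    del x  # the block goes back to the allocator's cache, not to the driver
    before = torch.cuda.memory_reserved()
    torch.cuda.empty_cache()
    after = torch.cuda.memory_reserved()
    return before - after


freed = release_cached_bytes()
print(freed)
```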

torch.cuda.list_gpu_processes(device=None)[source]

Returns a human-readable printout of the running processes and their GPU memory use for a given device.

This can be useful to display periodically during training, or when handling out-of-memory exceptions.

Parameters

device (torch.device or int, optional) – selected device. Returns printout for the current device, given by current_device(), if device is None (default).

torch.cuda.memory_stats(device=None)[source]

Returns a dictionary of CUDA memory allocator statistics for a given device.

The return value of this function is a dictionary of statistics, each of which is a non-negative integer.

Core statistics:

  • "allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of allocation requests received by the memory allocator.

  • "allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of allocated memory.

  • "segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of reserved segments from cudaMalloc().

  • "reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of reserved memory.

  • "active.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of active memory blocks.

  • "active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of active memory.

  • "inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of inactive, non-releasable memory blocks.

  • "inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of inactive, non-releasable memory.

For these core statistics, values are broken down as follows.

Pool type:

  • all: combined statistics across all memory pools.

  • large_pool: statistics for the large allocation pool (as of October 2019, for size >= 1MB allocations).

  • small_pool: statistics for the small allocation pool (as of October 2019, for size < 1MB allocations).

Metric type:

  • current: current value of this metric.

  • peak: maximum value of this metric.

  • allocated: historical total increase in this metric.

  • freed: historical total decrease in this metric.

In addition to the core statistics, we also provide some simple event counters:

  • "num_alloc_retries": number of failed cudaMalloc calls that result in a cache flush and retry.

  • "num_ooms": number of out-of-memory errors thrown.

Parameters

device (torch.device or int, optional) – selected device. Returns statistics for the current device, given by current_device(), if device is None (default).

Note

See Memory management for more details about GPU memory management.
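A small sketch of reading a few of the statistics listed above. The helper allocator_overview is hypothetical. Keys missing from the returned dictionary are reported as 0.

```python
import torch


def allocator_overview(device=None) -> dict:
    """Pick a few allocator statistics for `device` (empty dict without CUDA)."""
    if not torch.cuda.is_available():
        return {}
    stats = torch.cuda.memory_stats(device)
    keys = (
        "allocated_bytes.all.current",
        "allocated_bytes.all.peak",
        "reserved_bytes.all.current",
        "num_alloc_retries",
        "num_ooms",
    )
    return {key: stats.get(key, 0) for key in keys}


overview = allocator_overview()
for key, value in overview.items():
    print(f"{key}: {value}")
```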

torch.cuda.memory_summary(device=None, abbreviated=False)[source]

Returns a human-readable printout of the current memory allocator statistics for a given device.

This can be useful to display periodically during training, or when handling out-of-memory exceptions.

Parameters
  • device (torch.device or int, optional) – selected device. Returns printout for the current device, given by current_device(), if device is None (default).

  • abbreviated (bool, optional) – whether to return an abbreviated summary (default: False).

Note

See Memory management for more details about GPU memory management.

torch.cuda.memory_snapshot()[source]

Returns a snapshot of the CUDA memory allocator state across all devices.

Interpreting the output of this function requires familiarity with the memory allocator internals.

Note

See Memory management for more details about GPU memory management.

torch.cuda.memory_allocated(device=None)[source]

Returns the current GPU memory occupied by tensors in bytes for a given device.

Parameters

device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).

Note

This is likely less than the amount shown in nvidia-smi since some unused memory can be held by the caching allocator and some context needs to be created on GPU. See Memory management for more details about GPU memory management.

torch.cuda.max_memory_allocated(device=None)[source]

Returns the maximum GPU memory occupied by tensors in bytes for a given device.

By default, this returns the peak allocated memory since the beginning of this program. reset_peak_memory_stats() can be used to reset the starting point in tracking this metric. For example, these two functions can measure the peak allocated memory usage of each iteration in a training loop.

Parameters

device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).

Note

See Memory management for more details about GPU memory management.
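The per-iteration measurement described above might look like the following sketch, which calls torch.cuda.reset_peak_memory_stats() before each step and reads max_memory_allocated() afterwards. The helper peak_bytes_per_step and the matmul workload are hypothetical.

```python
import torch


def peak_bytes_per_step(steps: int = 3, size: int = 512) -> list:
    """Record the peak allocated bytes of each step (empty list without CUDA)."""
    peaks = []
    if not torch.cuda.is_available():
        return peaks
    for _ in range(steps):
        torch.cuda.reset_peak_memory_stats()  # restart peak tracking for this step
        x = torch.randn(size, size, device="cuda")
        y = x @ x
        torch.cuda.synchronize()              # make sure the step has finished
        peaks.append(torch.cuda.max_memory_allocated())
        del x, y
    return peaks


peaks = peak_bytes_per_step()
print(peaks)
```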

torch.cuda.reset_max_memory_allocated(device=None)[source]

Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.

See max_memory_allocated() for details.

Parameters

device (torch.device or int, optional) – selected device. Resets the statistic for the current device, given by current_device(), if device is None (default).

Warning

This function now calls reset_peak_memory_stats(), which resets all peak memory stats.

Note

See Memory management for more details about GPU memory management.

torch.cuda.memory_reserved(device=None)[source]

Returns the current GPU memory managed by the caching allocator in bytes for a given device.

Parameters

device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).

Note

See Memory management for more details about GPU memory management.

torch.cuda.max_memory_reserved(device=None)[source]

Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.

By default, this returns the peak cached memory since the beginning of this program. reset_peak_memory_stats() can be used to reset the starting point in tracking this metric. For example, these two functions can measure the peak cached memory amount of each iteration in a training loop.

Parameters

device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).

Note

See Memory management for more details about GPU memory management.

torch.cuda.set_per_process_memory_fraction(fraction, device=None)[source]

Sets the memory fraction for a process. The fraction limits the amount of memory the caching allocator may allocate on a CUDA device. The allowed value equals the total visible memory multiplied by fraction. If the process tries to allocate more than the allowed value, the allocator raises an out-of-memory error.

Parameters
  • fraction (float) – Range: 0~1. Allowed memory equals total_memory * fraction.

  • device (torch.device or int, optional) – selected device. If it is None the default CUDA device is used.

Note

In general, the total available free memory is less than the total capacity.
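A sketch of applying a cap and computing the resulting byte limit from the rule above (total_memory * fraction). The helper cap_memory is hypothetical, and total_memory is assumed to be available on the object returned by get_device_properties().

```python
import torch


def cap_memory(fraction: float = 0.5, device: int = 0):
    """Limit this process to `fraction` of the device memory.

    Returns the allowed number of bytes, or None when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return None
    torch.cuda.set_per_process_memory_fraction(fraction, device)
    total = torch.cuda.get_device_properties(device).total_memory
    return int(total * fraction)


allowed = cap_memory(0.5, 0)
print(allowed)
```

The cap stays in effect for the rest of the process unless the function is called again with a different fraction.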

torch.cuda.memory_cached(device=None)[source]

Deprecated; see memory_reserved().

torch.cuda.max_memory_cached(device=None)[source]

Deprecated; see max_memory_reserved().

torch.cuda.reset_max_memory_cached(device=None)[source]

Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.

See max_memory_cached() for details.

Parameters

device (torch.device or int, optional) – selected device. Resets the statistic for the current device, given by current_device(), if device is None (default).

Warning

This function now calls reset_peak_memory_stats(), which resets all peak memory stats.

Note

See Memory management for more details about GPU memory management.

NVIDIA Tools Extension (NVTX)

torch.cuda.nvtx.mark(msg)[source]

Describe an instantaneous event that occurred at some point.

Parameters

msg (string) – ASCII message to associate with the event.

torch.cuda.nvtx.range_push(msg)[source]

Pushes a range onto a stack of nested range spans. Returns the zero-based depth of the range that is started.

Parameters

msg (string) – ASCII message to associate with the range.

torch.cuda.nvtx.range_pop()[source]

Pops a range off of a stack of nested range spans. Returns the zero-based depth of the range that is ended.
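A sketch of pairing range_push() with range_pop() so that the range is closed even if the wrapped code raises. The helper annotated_sum and the range name are hypothetical, and the NVTX calls are skipped when CUDA is unavailable.

```python
import torch


def annotated_sum(x: torch.Tensor) -> torch.Tensor:
    """Sum a tensor inside an NVTX range when CUDA is available."""
    use_nvtx = torch.cuda.is_available()
    if use_nvtx:
        torch.cuda.nvtx.range_push("annotated_sum")  # opens a nested range
    try:
        return x.sum()
    finally:
        if use_nvtx:
            torch.cuda.nvtx.range_pop()  # always close the range that was opened


value = annotated_sum(torch.arange(5.0)).item()
print(value)
```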
