torch.profiler
Overview
PyTorch Profiler is a tool that collects performance metrics during training and inference. The profiler's context manager API can be used to find out which model operators are the most expensive, examine their input shapes and stack traces, study device kernel activity, and visualize the execution trace.
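A minimal CPU-only sketch of the context manager, assuming a single matrix multiply stands in for a model (the tensor sizes and variable names are illustrative). It profiles the operation and prints the recorded operators sorted by total CPU time.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Illustrative workload: one matrix multiply on CPU tensors.
a = torch.randn(64, 128)
b = torch.randn(128, 32)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    c = torch.matmul(a, b)

# Aggregate events by operator name; most expensive operators first.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

# Names of the operators recorded inside the profiled region.
op_names = {evt.key for evt in prof.key_averages()}
```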
Note
An earlier version of the API in the torch.autograd module is considered legacy and will be deprecated.
API Reference
class torch.profiler.profile(*, activities=None, schedule=None, on_trace_ready=None, record_shapes=False, profile_memory=False, with_stack=False, use_gpu=None)
Profiler context manager.
Args:
activities - list of activity groups (CPU, CUDA) to use in profiling; supported values: torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA;
schedule - callable that takes step (int) as a single parameter and returns a ProfilerAction value that specifies the profiler action to perform at each step;
on_trace_ready - callable that is called at each step when schedule returns ProfilerAction.RECORD_AND_SAVE during profiling;
record_shapes - save information about operators' input shapes;
profile_memory - track tensor memory allocation/deallocation;
with_stack - record source information (file and line number) for the ops;
use_gpu - (deprecated, use activities).
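A sketch of record_shapes on CPU, assuming a matrix multiply as the workload. Passing group_by_input_shape=True to key_averages() keeps calls with different input shapes in separate rows, so the recorded shapes of the matrix-multiply operator can be read back. The operator name aten::mm is what CPU builds typically record for a 2-D matmul; treat it as an assumption for other builds.

```python
import torch
from torch.profiler import profile, ProfilerActivity

a = torch.randn(64, 128)
b = torch.randn(128, 32)

# record_shapes=True stores the input shapes of every recorded operator.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    torch.matmul(a, b)

# Group the aggregated rows by (operator name, input shapes).
averages = prof.key_averages(group_by_input_shape=True)
print(averages.table(sort_by="cpu_time_total", row_limit=10))

# Input shapes recorded for the matrix-multiply operator.
mm_shapes = [evt.input_shapes for evt in averages if evt.key == "aten::mm"]
```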
Note
Use torch.profiler.schedule to generate the callable schedule. Non-default schedules are useful when profiling long training jobs and allow the user to obtain multiple traces at different iterations of the training process. The default schedule simply records all events continuously for the duration of the context manager.
Note
Enabling shape and stack tracing results in additional overhead.
Examples:
```python
import torch

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA]
) as p:
    code_to_profile()  # placeholder for the code being profiled

print(p.key_averages().table(
    sort_by="self_cuda_time_total", row_limit=-1))
```
Using the profiler's schedule, on_trace_ready and step functions:
```python
import torch

# Non-default profiler schedule allows user to turn profiler on and off
# on different iterations of the training loop;
# trace_handler is called every time a new trace becomes available
def trace_handler(prof):
    print(prof.key_averages().table(
        sort_by="self_cuda_time_total", row_limit=-1))
    # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA],

    # In this example with wait=1, warmup=1, active=2,
    # profiler will skip the first step/iteration,
    # start warming up on the second, record
    # the third and the fourth iterations,
    # after which the trace will become available
    # and on_trace_ready (when set) is called;
    # the cycle repeats starting with the next step

    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=2),
    on_trace_ready=trace_handler
) as p:
    # N (number of iterations) and code_iteration_to_profile are placeholders
    for step_idx in range(N):
        code_iteration_to_profile(step_idx)
        # send a signal to the profiler that the next iteration has started
        p.step()
```
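A CPU-only variant of the scheduled example that runs without a GPU, assuming a tanh of a matrix product as a hypothetical stand-in for one training iteration. With 8 iterations and a 4-step cycle (wait=1, warmup=1, active=2), the trace handler should fire twice; the trace_steps list is an illustrative addition that records prof.step_num each time a trace becomes ready.

```python
import torch
from torch.profiler import profile, schedule, ProfilerActivity

trace_steps = []  # illustrative: profiler step number at each ready trace

def trace_handler(prof):
    # Called once per completed active window with that window's events.
    trace_steps.append(prof.step_num)
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))

x = torch.randn(256, 256)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2),
    on_trace_ready=trace_handler,
) as p:
    for step_idx in range(8):
        # Hypothetical per-iteration work standing in for a training step.
        x = torch.tanh(x @ x.T / 256)
        # Tell the profiler that the next iteration has started.
        p.step()
```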
events()
Returns the list of unaggregated profiler events, to be used in the trace callback or after profiling has finished.
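A short sketch of reading the unaggregated events, assuming three element-wise additions as the workload. Unlike key_averages(), events() yields one entry per operator call, so the three additions show up as three separate aten::add events.

```python
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.ones(4)
y = torch.ones(4)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(3):
        z = x + y

# One FunctionEvent per recorded call; nothing is averaged here.
add_events = [evt for evt in prof.events() if evt.name == "aten::add"]
for evt in add_events:
    print(evt.name, evt.cpu_time_total)
```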
export_stacks(path, metric='self_cpu_time_total')
Save stack traces in a file in a format suitable for visualization.
Args:
path - save stacks file to this location;
metric - metric to use: "self_cpu_time_total" or "self_cuda_time_total".
Note
Example of using the FlameGraph tool:
```shell
cd FlameGraph
./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg
```
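A sketch that produces the profiler.stacks file used in the FlameGraph command above, assuming a temporary directory as the output location. with_stack=True is set because export_stacks works from the recorded stack information. Depending on the PyTorch version and build, the file may contain few or no stack lines, so its contents are worth checking before running flamegraph.pl.

```python
import os
import tempfile

import torch
from torch.profiler import profile, ProfilerActivity

a = torch.randn(128, 128)

# Stack (file and line) information is required for export_stacks.
with profile(activities=[ProfilerActivity.CPU], with_stack=True) as prof:
    for _ in range(10):
        a = torch.sin(a @ a)

# Illustrative output location; any writable path works.
stacks_path = os.path.join(tempfile.mkdtemp(), "profiler.stacks")
prof.export_stacks(stacks_path, "self_cpu_time_total")
```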
torch.profiler.schedule(*, wait, warmup, active)
Returns a callable that can be used as the profiler schedule argument. The profiler will wait for wait steps, then do the warmup for the next warmup steps, then do the active recording for the next active steps, and then repeat the cycle starting with the next step.
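The callable returned by torch.profiler.schedule can be called directly to see which action it assigns to each step. With wait=1, warmup=1, active=2, steps 0 to 3 map to ProfilerAction.NONE, WARMUP, RECORD and RECORD_AND_SAVE, and the cycle then repeats; the range of 8 steps is an arbitrary choice for illustration.

```python
from torch.profiler import ProfilerAction, schedule

# One cycle = 1 wait step + 1 warmup step + 2 active steps.
my_schedule = schedule(wait=1, warmup=1, active=2)

# Ask the schedule which action it would take at each step number.
actions = [my_schedule(step) for step in range(8)]
for step, action in enumerate(actions):
    print(step, action)
```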