Description
Is your feature request related to a problem? Please describe.
The existing communication logging API is somewhat limited in its ability to:
- Provide raw data about communication operations
- Do fine-grained profiling of communications
Toward the first point: currently, the only way to get information out of the collected logs is to call `deepspeed.comms.log_all()`, which prints the logs and summary information to stdout when passed `print_log=True`. This setup doesn't make much sense as it stands: if you call `deepspeed.comms.log_all()` without that parameter, the function produces no output at all. The values are calculated but never returned or used in any way unless `print_log=True`. If the function returned the actual data, I could use it more effectively, for example by plotting it directly in TensorBoard or wandb.
Toward the second point: I am currently profiling a model and running into communication bandwidth issues. I would like to plot an epoch-by-epoch analysis of my collective operations. This is technically possible today by plotting the aggregate metrics generated by `deepspeed.comms.log_all()`, but to get the values for a single epoch I would have to subtract the aggregated values logged at the end of the previous epoch. It would be much easier if I could clear the log manually instead.
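The subtraction workaround described above can be sketched as follows. This assumes a hypothetical snapshot shape of `{op_name: {"count": ..., "total_latency_ms": ...}}` built by the user; it is not the structure DeepSpeed actually produces, and `epoch_delta` is a hypothetical helper.

```python
from typing import Dict

# Hypothetical cumulative snapshot shape (not DeepSpeed's format):
# {op_name: {"count": int, "total_latency_ms": float}}
Stats = Dict[str, Dict[str, float]]


def epoch_delta(current: Stats, previous: Stats) -> Stats:
    """Per-epoch stats = current cumulative snapshot minus the previous one."""
    delta: Stats = {}
    for op, metrics in current.items():
        prev = previous.get(op, {})  # ops first seen this epoch start from zero
        delta[op] = {key: value - prev.get(key, 0) for key, value in metrics.items()}
    return delta


epoch1 = {"all_reduce": {"count": 10, "total_latency_ms": 50.0}}
epoch2 = {
    "all_reduce": {"count": 25, "total_latency_ms": 130.0},
    "broadcast": {"count": 2, "total_latency_ms": 4.0},
}
print(epoch_delta(epoch2, epoch1))
# → {'all_reduce': {'count': 15, 'total_latency_ms': 80.0}, 'broadcast': {'count': 2, 'total_latency_ms': 4.0}}
```

Note that this only works for additive metrics such as counts and total latencies. Averages or bandwidth figures cannot be subtracted directly and would have to be recomputed from the deltas, which is part of why a manual reset is simpler.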
Describe the solution you'd like
I would like the communication profiler to return a dictionary of all of its results and to provide a way to clear the logs manually.
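A minimal sketch of the requested behaviour might look like the class below. It is a hypothetical standalone class, not DeepSpeed's implementation: it reuses the names `log_all` and `print_log` from this issue, while `CommsLogger`, `record`, `reset`, and the summary keys are assumptions made for illustration.

```python
from collections import defaultdict
from typing import Dict, List


class CommsLogger:
    """Hypothetical comms logger: returns a summary dict and can be reset.

    Illustrative only; not DeepSpeed's actual comms logger.
    """

    def __init__(self) -> None:
        # op name -> list of observed latencies in milliseconds
        self._latencies_ms: Dict[str, List[float]] = defaultdict(list)

    def record(self, op_name: str, latency_ms: float) -> None:
        """Store one latency sample for a collective operation."""
        self._latencies_ms[op_name].append(latency_ms)

    def log_all(self, print_log: bool = False) -> Dict[str, Dict[str, float]]:
        """Build a summary per op, optionally print it, and always return it."""
        summary: Dict[str, Dict[str, float]] = {}
        for op, lats in self._latencies_ms.items():
            total = sum(lats)
            summary[op] = {
                "count": len(lats),
                "total_latency_ms": total,
                "avg_latency_ms": total / len(lats),
            }
        if print_log:
            for op, stats in summary.items():
                print(f"{op}: {stats}")
        return summary

    def reset(self) -> None:
        """Discard all recorded samples, e.g. at the end of each epoch."""
        self._latencies_ms.clear()
```

With this shape, a training loop could call `log_all()` at the end of each epoch, forward the returned dictionary to TensorBoard or wandb, and then call `reset()`, so each epoch's numbers stand alone without any subtraction.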
Describe alternatives you've considered
- Parsing the data printed to stdout and calculating the values manually after training