🚀 Feature
#25799 provides an optimized implementation of cdist for p=2 that uses matrix multiplies instead of specialized (inefficient) kernels. However, that PR did not change the backward logic, so even in the optimized case the backward pass still calls the inefficient kernels. The cdist logic should be restructured so that Euclidean distance becomes a compound operation without an explicitly defined backward (so autograd derives an efficient backward from the matrix-multiply-based forward automatically), while the general case continues to go through the specialized kernels.
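A minimal sketch of what "Euclidean distance as a compound operation" could look like, written at the Python level for illustration only (the function name `euclidean_cdist_mm` is hypothetical and not part of the PyTorch API). It uses the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b, so the cross term is a single matrix multiply, and it is built only from differentiable primitives, so autograd produces the backward pass without any cdist-specific kernel:

```python
import torch


def euclidean_cdist_mm(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: pairwise Euclidean distances between rows of
    x1 (..., P, M) and x2 (..., R, M), returning (..., P, R).

    Uses only sum/matmul/clamp/sqrt, so autograd derives the backward pass
    from these ops instead of calling a dedicated cdist backward kernel.
    """
    # Squared L2 norm of every row, kept as a trailing dim for broadcasting.
    x1_sq = x1.pow(2).sum(dim=-1, keepdim=True)  # (..., P, 1)
    x2_sq = x2.pow(2).sum(dim=-1, keepdim=True)  # (..., R, 1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; the cross term is one GEMM.
    sq = x1_sq + x2_sq.transpose(-2, -1) - 2 * (x1 @ x2.transpose(-2, -1))
    # Floating-point cancellation can make sq slightly negative; clamp first.
    return sq.clamp_min(0).sqrt()


# Usage: the backward pass goes through matmul/sum/sqrt via autograd.
x = torch.randn(8, 16, requires_grad=True)
y = torch.randn(6, 16)
d = euclidean_cdist_mm(x, y)  # shape (8, 6)
d.sum().backward()
```

One caveat with this sketch: where a distance is exactly zero (for example, identical rows in `x1` and `x2`), the gradient of `sqrt` is infinite, so a real implementation would need to special-case zero distances in the backward.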
Motivation
The optimized forward implementation improves performance by up to 50x; we expect similar improvements for backward.
Also, the current implementation suffers from bugs with large sizes (see #31167). Those bugs still have to be fixed for the general case, but the gemm-based implementation seems more robust. The large memory usage reported in #24345 might also be caused by the inefficient implementation.
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @VitalyFedyunin @ngimel