@@ -19,7 +19,7 @@ namespace at {
1919// NOTE: [When should I add a batching rule?]
2020// When you are adding a new operator, you'll need to add a batching rule so
2121// that vmap can work efficiently with said operator. If you do not, we'll attempt
22- // to generate a slow fallback for the batching rule (this is not yet implemented) .
22+ // to generate a slow fallback for the batching rule.
2323
2424// NOTE: [How to write batching rules?]
2525// The signature of a batching rule should look exactly like the C++ signature
@@ -223,13 +223,33 @@ Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) {
223223 return self_physical.newLogicalFromPhysical (result);
224224}
225225
226+ static int64_t getGradInputPhysicalDim (int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
227+ return maybe_wrap_dim (dim, input_sizes.size ()) + num_batch_dims;
228+ }
229+
230+ Tensor select_backward_batching_rule (const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
231+ auto grad_physical = MultiBatchVmapTransform::logicalToPhysical (grad);
232+ auto grad_input = at::zeros (grad_physical.getPhysicalShape (input_sizes), grad.options ());
233+ auto physical_dim = getGradInputPhysicalDim (dim, input_sizes, grad_physical.numBatchDims ());
234+ grad_input.select (physical_dim, index).copy_ (grad_physical.tensor ());
235+ return grad_physical.newLogicalFromPhysical (grad_input);
236+ }
237+
226238Tensor slice_batching_rule (const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
227239 auto self_physical = MultiBatchVmapTransform::logicalToPhysical (self);
228240 auto dim_physical = self_physical.getPhysicalDim (dim);
229241 auto result = self_physical.tensor ().slice (dim_physical, start, end, step);
230242 return self_physical.newLogicalFromPhysical (result);
231243}
232244
245+ Tensor slice_backward_batching_rule (const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
246+ auto grad_physical = MultiBatchVmapTransform::logicalToPhysical (grad);
247+ auto grad_input = at::zeros (grad_physical.getPhysicalShape (input_sizes), grad.options ());
248+ auto physical_dim = getGradInputPhysicalDim (dim, input_sizes, grad_physical.numBatchDims ());
249+ grad_input.slice (physical_dim, start, end, step).copy_ (grad_physical.tensor ());
250+ return grad_physical.newLogicalFromPhysical (grad_input);
251+ }
252+
233253Tensor diagonal_batching_rule (const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
234254 auto self_physical = MultiBatchVmapTransform::logicalToPhysical (self);
235255 auto dim1_physical = self_physical.getPhysicalDim (dim1);
@@ -238,6 +258,15 @@ Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1,
238258 return self_physical.newLogicalFromPhysical (result);
239259}
240260
261+ Tensor diagonal_backward_batching_rule (const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
262+ auto grad_physical = MultiBatchVmapTransform::logicalToPhysical (grad);
263+ auto grad_input = at::zeros (grad_physical.getPhysicalShape (input_sizes), grad.options ());
264+ auto dim1_physical = getGradInputPhysicalDim (dim1, input_sizes, grad_physical.numBatchDims ());
265+ auto dim2_physical = getGradInputPhysicalDim (dim2, input_sizes, grad_physical.numBatchDims ());
266+ grad_input.diagonal (offset, dim1_physical, dim2_physical).copy_ (grad_physical.tensor ());
267+ return grad_physical.newLogicalFromPhysical (grad_input);
268+ }
269+
241270Tensor movedim_batching_rule (const Tensor& self, IntArrayRef source, IntArrayRef destination) {
242271 auto self_physical = MultiBatchVmapTransform::logicalToPhysical (self);
243272 auto source_physical = self_physical.getPhysicalDims (source);
@@ -614,6 +643,11 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
614643 // cat/stack
615644 m.impl (" cat" , cat_batching_rule);
616645 m.impl (" stack" , stack_batching_rule);
646+
647+ // backward operators
648+ m.impl (" select_backward" , select_backward_batching_rule);
649+ m.impl (" slice_backward" , slice_backward_batching_rule);
650+ m.impl (" diagonal_backward" , diagonal_backward_batching_rule);
617651}
618652
619653} // namespace at
0 commit comments