
Tensor.sum() over multiple axes #2006

@vlasenkov

While writing `einsum`, I found that `sum` can only reduce over a single axis. I'm going to add a simple implementation that sums over several axes:

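```python
def sum(input, axes, keepdim=False):
    ndim = input.dim()
    # Validate axes, map negative ones to positive indices, and require uniqueness
    if any(not -ndim <= ax < ndim for ax in axes):
        raise IndexError(f"axes {axes} out of range for a {ndim}-d tensor")
    axes = [ax % ndim for ax in axes]
    if len(set(axes)) != len(axes):
        raise ValueError("axes must be unique")
    if keepdim:
        # Dimensions are kept, so indices never shift and order does not matter
        for ax in axes:
            input = input.sum(ax, keepdim=True)
    else:
        # Reduce the highest axis first so the remaining indices stay valid
        for ax in sorted(axes, reverse=True):
            input = input.sum(ax)
    return input
```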

1. Is this approach OK?
2. Where would be the best place to put it?
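To show why the non-`keepdim` branch visits the axes in descending order, here is a minimal sketch that only tracks shapes (plain Python, no tensors). `reduced_shape` is a hypothetical helper for illustration, and it assumes the axes are already non-negative and unique. Each reduction without `keepdim` deletes one dimension and shifts every later index down by one, so removing the highest axis first keeps the remaining indices valid, while ascending order ends up removing the wrong dimension:

```python
def reduced_shape(shape, axes, descending=True):
    """Hypothetical helper: shape left after removing each axis in turn,
    mimicking repeated .sum(ax) without keepdim (axes assumed non-negative, unique)."""
    shape = list(shape)
    for ax in sorted(axes, reverse=descending):
        # Deleting an entry shifts the index of every later dimension down by one
        del shape[ax]
    return tuple(shape)


# Summing axes 0 and 2 of a (2, 3, 4, 5) tensor should leave shape (3, 5)
print(reduced_shape((2, 3, 4, 5), [0, 2]))                    # (3, 5)
print(reduced_shape((2, 3, 4, 5), [0, 2], descending=False))  # (3, 4): wrong axis removed
```

In the ascending case, removing axis 0 first turns the original axis 3 (size 5) into index 2, so the second deletion drops it instead of the intended size-4 axis.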
