Skip to content

Commit 6703587

Browse files
zou3519facebook-github-bot
authored andcommitted
Delete tagged names
Summary: Pull Request resolved: #26365 Test Plan: - [namedtensor ci] Differential Revision: D17484759 Pulled By: zou3519 fbshipit-source-id: 44068c1e9d84adf36c5ab5e7006a153b948914d6
1 parent 858cf76 commit 6703587

File tree

8 files changed

+25
-93
lines changed

8 files changed

+25
-93
lines changed

aten/src/ATen/NamedTensorUtils.cpp

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,10 @@ int64_t dimname_to_position(const Tensor& tensor, Dimname dim) {
2020
"Name ", dim, " not found in ", toDimnameRepr(tensor), ".");
2121
const auto names = tensor.names();
2222

23-
const auto it = std::find_if(
24-
names.begin(), names.end(),
25-
[&dim](const Dimname& candidate) { return dim.can_refer_to(candidate); });
23+
const auto it = std::find(names.begin(), names.end(), dim);
2624
TORCH_CHECK(it != names.end(),
2725
"Name ", dim, " not found in ", toDimnameRepr(tensor), ".");
2826

29-
// Check that it can't refer to another dimension
30-
const auto dup = std::find_if(
31-
it + 1, names.end(),
32-
[&dim](const Dimname& candidate) { return dim.can_refer_to(candidate); });
33-
TORCH_CHECK(
34-
dup == names.end(),
35-
"Name ", dim, " could refer to multiple dimensions in ",
36-
toDimnameRepr(tensor), ". Please disambiguate by using a more ",
37-
"specific name like ", *it, " or ", dup, ".");
3827
return std::distance(names.begin(), it);
3928
}
4029

@@ -68,8 +57,7 @@ static void check_for_misalignment(
6857
if (name.is_wildcard()) {
6958
return;
7059
}
71-
auto it = std::find_if(other_names.begin(), other_names.end(),
72-
[&](const Dimname& candidate) { return name.can_refer_to(candidate); });
60+
auto it = std::find(other_names.begin(), other_names.end(), name);
7361
// TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
7462
TORCH_CHECK(it == other_names.end(),
7563
"Misaligned dims when attempting to ", action, " dims ", names, " and dims ",
@@ -94,11 +82,6 @@ std::vector<Dimname> unify_from_right(
9482
const auto& name = names_it == names.rend() ? wildcard : *names_it;
9583
const auto& other_name = other_it == other_names.rend() ? wildcard : *other_it;
9684

97-
// TODO(zou3519): Don't support tagged names for now. They're a little weird.
98-
if (name.is_tagged() || other_name.is_tagged()) {
99-
TORCH_INTERNAL_ASSERT("unify_from_right: NYI: tagged names.");
100-
}
101-
10285
// Step 1: Check that the names match
10386
const auto maybeName = unify(name, other_name);
10487
if (!maybeName) {

aten/src/ATen/core/Dimname.cpp

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ std::ostream& operator<<(std::ostream& out, const Dimname& dimname) {
1010
if (dimname.type() == NameType::WILDCARD) {
1111
out << "None";
1212
} else {
13-
out << "'" << dimname.full_name().toUnqualString() << "'";
13+
out << "'" << dimname.symbol().toUnqualString() << "'";
1414
}
1515
return out;
1616
}
@@ -28,38 +28,24 @@ bool is_valid_identifier(const std::string& name) {
2828
return true;
2929
}
3030

31-
bool Dimname::can_refer_to(const Dimname& other) const {
32-
switch (type()) {
33-
case NameType::WILDCARD:
34-
return false;
35-
36-
// "C" can be used to refer to "C" or "C.in".
37-
case NameType::NORMAL:
38-
return untagged_name() == other.untagged_name();
39-
40-
default:
41-
return full_name() == other.full_name();
42-
}
43-
}
44-
4531
static void check_valid_identifier(const std::string& name) {
4632
TORCH_CHECK(
4733
is_valid_identifier(name),
4834
"Invalid name: a valid identifier must contain alphabetical characters and/or underscore, got: '",
4935
name, "'.");
5036
}
5137

52-
Dimname Dimname::fromSymbol(Symbol full_name) {
53-
TORCH_INTERNAL_ASSERT(full_name.is_dimname());
54-
if (full_name == kWildcard) {
38+
Dimname Dimname::fromSymbol(Symbol name) {
39+
TORCH_INTERNAL_ASSERT(name.is_dimname());
40+
if (name == kWildcard) {
5541
return Dimname::wildcard();
5642
}
57-
check_valid_identifier(full_name.toUnqualString());
58-
return Dimname(full_name);
43+
check_valid_identifier(name.toUnqualString());
44+
return Dimname(name);
5945
}
6046

6147
Dimname Dimname::wildcard() {
62-
static Dimname result(NameType::WILDCARD, kWildcard, kWildcard);
48+
static Dimname result(kWildcard, NameType::WILDCARD);
6349
return result;
6450
}
6551

@@ -70,12 +56,9 @@ optional<Dimname> unify(Dimname dimname, Dimname other) {
7056
if (dimname.type() == NameType::WILDCARD) {
7157
return other;
7258
}
73-
if (dimname.full_name() == other.full_name()) {
59+
if (dimname.symbol() == other.symbol()) {
7460
return dimname;
7561
}
76-
if (dimname.untagged_name() == other.untagged_name()) {
77-
return Dimname::fromSymbol(dimname.untagged_name());
78-
}
7962
return c10::nullopt;
8063
}
8164

aten/src/ATen/core/Dimname.h

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,38 +9,26 @@
99

1010
namespace at {
1111

12-
enum class NameType: uint8_t { NORMAL, WILDCARD, TAGGED };
12+
enum class NameType: uint8_t { NORMAL, WILDCARD };
1313

1414
struct CAFFE2_API Dimname {
1515
static Dimname fromSymbol(Symbol name);
1616
static Dimname wildcard();
1717

1818
NameType type() const { return type_; }
19-
Symbol full_name() const { return full_name_; }
20-
Symbol untagged_name() const { return untagged_name_; }
21-
22-
bool can_refer_to(const Dimname& other) const;
19+
Symbol symbol() const { return name_; }
2320

2421
bool is_normal() const { return type_ == NameType::NORMAL; }
2522
bool is_wildcard() const { return type_ == NameType::WILDCARD; }
26-
bool is_tagged() const { return type_ == NameType::TAGGED; }
2723

2824
private:
2925
Dimname(Symbol name)
30-
: untagged_name_(name), full_name_(name), type_(NameType::NORMAL) {}
31-
Dimname(NameType type, Symbol full_name, Symbol untagged_name)
32-
: untagged_name_(untagged_name), full_name_(full_name), type_(type) {}
26+
: name_(name), type_(NameType::NORMAL) {}
27+
Dimname(Symbol name, NameType type)
28+
: name_(name), type_(type) {}
3329

34-
// [Dimname Terminology]
35-
//
36-
// For "C.in":
37-
// - "C.in" is the "full name"
38-
// - "C" is the "untagged name"
39-
// - "in" is the "tag"
40-
Symbol untagged_name_;
41-
Symbol full_name_;
30+
Symbol name_;
4231
NameType type_;
43-
// Will need more fields for other special name types.
4432
};
4533

4634
using DimnameList = c10::ArrayRef<Dimname>;
@@ -54,7 +42,7 @@ CAFFE2_API bool match(Dimname dimname, Dimname other);
5442
CAFFE2_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname);
5543

5644
inline bool operator==(const Dimname& lhs, const Dimname& rhs) {
57-
return lhs.full_name() == rhs.full_name();
45+
return lhs.symbol() == rhs.symbol();
5846
}
5947

6048
inline bool operator!=(const Dimname& lhs, const Dimname& rhs) {

aten/src/ATen/core/NamedTensor.cpp

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,36 +48,16 @@ void check_names_valid_for(const Tensor& tensor, DimnameList names) {
4848

4949
namespace impl {
5050

51-
// Two Dimnames cannot be in the same Tensor if one of them can refer to the other.
52-
// In practice, this constraint means that a Tensor cannot have duplicate names
53-
// unless they are tagged and the tags are different.
54-
static DimnameList::const_iterator find_incompatible_name(
55-
DimnameList::const_iterator begin,
56-
DimnameList::const_iterator end,
57-
const Dimname& target) {
58-
return std::find_if(begin, end,
59-
[&target](const Dimname& candidate) {
60-
return target.can_refer_to(candidate) || candidate.can_refer_to(target);
61-
});
62-
}
63-
6451
static void check_unique_names(DimnameList names) {
6552
// Strategy: Compare each element with the ones that come after it.
6653
// Although this is O(N^2), in practice N is small (no more than 25).
6754
for (auto it = names.begin(); it != names.end(); ++it) {
68-
auto dup = find_incompatible_name(it + 1, names.end(), *it);
55+
if (it->is_wildcard()) continue;
56+
auto dup = std::find(it + 1, names.end(), *it);
6957
while (dup != names.end()) {
70-
// Simple error message if you're not using tags
71-
TORCH_CHECK(it->type() == NameType::TAGGED || dup->type() == NameType::TAGGED,
58+
TORCH_CHECK(false,
7259
"Cannot construct a tensor with duplicate names. Got names: ",
7360
names, ".");
74-
75-
// Complicated error message if you're using tags
76-
TORCH_CHECK(false,
77-
"Cannot construct a tensor with duplicate names unless they are tagged ",
78-
"and have different tags. Got names: ", names, ", offending names: (",
79-
*it, " and ", *dup, ").");
80-
dup = find_incompatible_name(dup + 1, names.end(), *it);
8161
}
8262
}
8363
}

aten/src/ATen/native/NamedTensor.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ static std::vector<int64_t> aligned_size(
6060
ptrdiff_t dim = (ptrdiff_t)tensor_sizes.size() - 1;
6161
ptrdiff_t idx = (ptrdiff_t)aligned_names.size() - 1;
6262
for (; idx >= 0 && dim >= 0; --idx) {
63-
TORCH_INTERNAL_ASSERT(!tensor_names[dim].is_tagged() && !aligned_names[idx].is_tagged(), "Tagged names NYI");
6463
if (tensor_names[dim] != aligned_names[idx]) {
6564
continue;
6665
}

aten/src/ATen/test/Dimname_test.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ TEST(DimnameTest, isValidIdentifier) {
3030
TEST(DimnameTest, wildcardName) {
3131
Dimname wildcard = Dimname::wildcard();
3232
ASSERT_EQ(wildcard.type(), NameType::WILDCARD);
33-
ASSERT_EQ(wildcard.full_name(), Symbol::dimname("*"));
33+
ASSERT_EQ(wildcard.symbol(), Symbol::dimname("*"));
3434
}
3535

3636
TEST(DimnameTest, createNormalName) {
3737
auto foo = Symbol::dimname("foo");
3838
auto dimname = Dimname::fromSymbol(foo);
3939
ASSERT_EQ(dimname.type(), NameType::NORMAL);
40-
ASSERT_EQ(dimname.full_name(), foo);
40+
ASSERT_EQ(dimname.symbol(), foo);
4141
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("inva.lid")), c10::Error);
4242
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("invalid1")), c10::Error);
4343
}
@@ -51,9 +51,8 @@ static void check_unify_and_match(
5151
auto result = at::unify(dimname1, dimname2);
5252
if (expected) {
5353
auto expected_result = Dimname::fromSymbol(Symbol::dimname(*expected));
54-
ASSERT_EQ(result->full_name(), expected_result.full_name());
54+
ASSERT_EQ(result->symbol(), expected_result.symbol());
5555
ASSERT_EQ(result->type(), expected_result.type());
56-
ASSERT_EQ(result->untagged_name(), expected_result.untagged_name());
5756
ASSERT_TRUE(match(dimname1, dimname2));
5857
} else {
5958
ASSERT_FALSE(result);

aten/src/ATen/test/NamedTensor_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ static bool dimnames_equal(at::DimnameList names, at::DimnameList other) {
5252
for (auto i = 0; i < names.size(); i++) {
5353
const auto& name = names[i];
5454
const auto& other_name = other[i];
55-
if (name.type() != other_name.type() || name.full_name() != other_name.full_name()) {
55+
if (name.type() != other_name.type() || name.symbol() != other_name.symbol()) {
5656
return false;
5757
}
5858
}

torch/csrc/autograd/python_variable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ PyObject *THPVariable_get_names(THPVariable *self, void *unused)
337337
for (size_t i = 0; i < size; ++i) {
338338
PyObject* str = Py_None;
339339
if (dimnames[i].type() != at::NameType::WILDCARD) {
340-
str = THPUtils_packString(dimnames[i].full_name().toUnqualString());
340+
str = THPUtils_packString(dimnames[i].symbol().toUnqualString());
341341
if (!str) throw python_error();
342342
}
343343
PyTuple_SET_ITEM(tuple.get(), i, str);

0 commit comments

Comments
 (0)