Skip to content

Commit 3233a05

Browse files
zou3519facebook-github-bot
authored andcommitted
Add TensorNames::checkUnique, operator<< (TensorName) (#29124)
Summary: Pull Request resolved: #29124 TensorNames::checkUnique gives a nice error message if there are duplicate names. Adding operator<< on TensorName cleans up some code. A TensorName gets printed out as: "'H' (index 2 of ['N', 'C', 'H', 'W'])" for example. Test Plan: - New c++ tests. test with `build/bin/NamedTensor_test`. Differential Revision: D18311868 Pulled By: zou3519 fbshipit-source-id: 5be197dba227f0328b40d7f66e78fffefe4dbd00
1 parent 2c3c702 commit 3233a05

File tree

3 files changed

+78
-9
lines changed

3 files changed

+78
-9
lines changed

aten/src/ATen/TensorNames.cpp

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@ const TensorName& TensorName::unify(const TensorName& other, const char* op_name
2626
const auto it = std::find(other.origin_.begin(), other.origin_.end(), name_);
2727
TORCH_CHECK(it == other.origin_.end(),
2828
op_name, ":",
29-
" Cannot match ", name_,
30-
" (at index ", origin_idx_, " of ", origin_, ")",
31-
" with ", other.name_,
32-
" (at index ", other.origin_idx_, " of ", other.origin_, ")",
33-
" because the latter names already has ", name_, ".",
29+
" Cannot match ", *this, " with ", other,
30+
" because the latter names already have ", name_, ".",
3431
" Are your tensors misaligned?");
3532
return *this;
3633
}
@@ -43,10 +40,8 @@ const TensorName& TensorName::unify(const TensorName& other, const char* op_name
4340
// unify(A, B)
4441
TORCH_CHECK(name_ == other.name_,
4542
op_name, ":",
46-
" Expected ", name_,
47-
" (at index ", origin_idx_, " of ", origin_, ")",
48-
" to match ", other.name_,
49-
" (at index ", other.origin_idx_, " of ", other.origin_, ")",
43+
" Expected ", *this,
44+
" to match ", other,
5045
" but they do not match.");
5146
return *this;
5247
}
@@ -89,6 +84,34 @@ void TensorNames::append(TensorName&& name) {
8984
names_.emplace_back(name);
9085
}
9186

87+
void TensorNames::checkUnique(const char* op_name) const {
88+
// O(N^2), but named tensors can have at most N = 64 dimensions, so this
89+
// doesn't matter unless benchmarking tells us it does. The alternative is
90+
// to create some sort of set data structure but the overhead of that
91+
// might dominate for small sizes.
92+
for (auto it = names_.begin(); it != names_.end(); ++it) {
93+
const auto name = it->toDimname();
94+
if (name.isWildcard()) continue;
95+
96+
auto dup = std::find_if(it + 1, names_.end(),
97+
[&](const TensorName& other) { return other.toDimname() == name; });
98+
TORCH_CHECK(dup == names_.end(),
99+
op_name, ": ",
100+
"Attempted to propagate dims ", *it, " and ", dup, " to the output, ",
101+
"but that would create a tensor with duplicate names ", toDimnameVec(),
102+
". Please rename your inputs with Tensor.rename to prevent this.");
103+
}
104+
}
105+
106+
// Let's say the TensorName represents 'C' in ['N', 'C', 'H, 'W'].
107+
// It should print like:
108+
// 'C' (index 1 of ['N', 'C', 'H', 'W'])
109+
std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) {
110+
out << tensorname.name_ << " (index " << tensorname.origin_idx_ << " of ";
111+
out << tensorname.origin_ << ")";
112+
return out;
113+
}
114+
92115
std::vector<Dimname> TensorNames::toDimnameVec() const {
93116
std::vector<Dimname> result;
94117
result.reserve(names_.size());

aten/src/ATen/TensorNames.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,17 @@ struct CAFFE2_API TensorName {
3434
origin_idx_(origin_idx),
3535
name_(origin[maybe_wrap_dim(origin_idx, origin.size())]) {}
3636

37+
// op_name is only used for error reporting.
3738
const TensorName& unify(const TensorName& other, const char* op_name) const;
3839
Dimname toDimname() const;
3940

4041
private:
4142
ArrayRef<Dimname> origin_;
4243
int origin_idx_;
4344
Dimname name_;
45+
CAFFE2_API friend std::ostream& operator<<(
46+
std::ostream& out,
47+
const TensorName& tensorname);
4448
};
4549

4650
using TensorNameVec = SmallVector<TensorName, 10>;
@@ -52,7 +56,9 @@ struct CAFFE2_API TensorNames {
5256
// `names`, NOT names[start:end], because the original tensor's names are `names`.
5357
explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end);
5458

59+
// op_name is only used for error reporting.
5560
TensorNames unifyFromRight(const TensorNames& other, const char* op_name) const;
61+
void checkUnique(const char* op_name) const;
5662

5763
void append(TensorName&& name);
5864
std::vector<Dimname> toDimnameVec() const;
@@ -62,6 +68,7 @@ struct CAFFE2_API TensorNames {
6268

6369
TensorNameVec names_;
6470
};
71+
6572
#endif
6673

6774
}} // namespace at::namedinference

aten/src/ATen/test/NamedTensor_test.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ using at::Dimname;
1212
using at::DimnameList;
1313
using at::NamedTensorMeta;
1414
using at::Symbol;
15+
using at::namedinference::TensorName;
16+
using at::namedinference::TensorNames;
1517
using c10::guts::make_unique;
1618

1719
TEST(NamedTensorTest, defaultMetadata) {
@@ -209,5 +211,42 @@ TEST(NamedTensorTest, NoNamesGuard) {
209211
ASSERT_TRUE(at::NamesMode::is_enabled());
210212
}
211213

214+
static std::vector<Dimname> nchw() {
215+
auto N = dimnameFromString("N");
216+
auto C = dimnameFromString("C");
217+
auto H = dimnameFromString("H");
218+
auto W = dimnameFromString("W");
219+
return { N, C, H, W };
220+
}
221+
222+
TEST(NamedTensorTest, TensorNamePrint) {
223+
auto names = nchw();
224+
{
225+
auto N = TensorName(names, 0);
226+
ASSERT_EQ(
227+
c10::str(N),
228+
"'N' (index 0 of ['N', 'C', 'H', 'W'])");
229+
}
230+
{
231+
auto H = TensorName(names, 2);
232+
ASSERT_EQ(
233+
c10::str(H),
234+
"'H' (index 2 of ['N', 'C', 'H', 'W'])");
235+
}
236+
}
237+
238+
TEST(NamedTensorTest, TensorNamesCheckUnique) {
239+
auto names = nchw();
240+
{
241+
// smoke test to check that this doesn't throw
242+
TensorNames(names).checkUnique("op_name");
243+
}
244+
{
245+
std::vector<Dimname> nchh = { names[0], names[1], names[2], names[2] };
246+
auto tensornames = TensorNames(nchh);
247+
ASSERT_THROW(tensornames.checkUnique("op_name"), c10::Error);
248+
}
249+
}
250+
212251

213252
#endif

0 commit comments

Comments
 (0)