Skip to content

Commit 414a1fd

Browse files
swolchokpytorchmergebot
authored andcommitted
[PyTorch] Add IValue::IValue(std::vector<T>&&) ctors (#117769)
There are two IValue constructors that take `const std::vector<T>&`. Add moving variants to allow callers to save on reference counting. Differential Revision: [D52879065](https://our.internmc.facebook.com/intern/diff/D52879065/) Pull Request resolved: #117769 Approved by: https://github.com/suo, https://github.com/Skylion007
1 parent d45fd68 commit 414a1fd

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

aten/src/ATen/core/ivalue.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,8 @@ struct TORCH_API IValue final {
756756
IValue(at::ArrayRef<T> v);
757757
template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
758758
IValue(const std::vector<T>& v);
759+
template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
760+
IValue(std::vector<T>&& v);
759761
template <class T, size_t N>
760762
IValue(std::array<T, N> v);
761763

@@ -772,6 +774,9 @@ struct TORCH_API IValue final {
772774
IValue(at::OptionalArrayRef<T> v);
773775
template <class T, enable_if_symint<T> = nullptr>
774776
IValue(const std::vector<T>& v);
777+
template <class T, enable_if_symint<T> = nullptr>
778+
IValue(std::vector<T>&& v);
779+
775780

776781
template <class T>
777782
using enable_if_ilist_is_ivalue_constructible = std::enable_if_t<

aten/src/ATen/core/ivalue_inl.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2185,6 +2185,23 @@ template <class T, IValue::enable_if_symint<T>>
21852185
inline IValue::IValue(const std::vector<T>& v) : IValue() {
21862186
*this = IValue(at::ArrayRef<T>(v));
21872187
}
2188+
template <class T, IValue::enable_if_symint<T>>
2189+
inline IValue::IValue(std::vector<T>&& v) : IValue() {
2190+
auto vi = c10::asIntArrayRefSlowOpt(v);
2191+
if (vi.has_value()) {
2192+
// This list is entirely integers; ensure it is typed as
2193+
// an IntList so toIntList works
2194+
*this = IValue(*vi);
2195+
} else {
2196+
// This list has SymInts; type it as a SymInt
2197+
*this = IValue(impl::toList<c10::SymInt>(c10::List<c10::SymInt>()));
2198+
auto list = to<c10::List<c10::SymInt>>();
2199+
list.reserve(v.size());
2200+
for (auto& e : v) {
2201+
list.push_back(std::move(e));
2202+
}
2203+
}
2204+
}
21882205
template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
21892206
inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
21902207
auto list = to<c10::List<T>>();
@@ -2193,6 +2210,22 @@ inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
21932210
list.push_back(e);
21942211
}
21952212
}
2213+
2214+
template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2215+
inline IValue::IValue(std::vector<T>&& v) : IValue(c10::List<T>()) {
2216+
auto list = to<c10::List<T>>();
2217+
list.reserve(v.size());
2218+
if constexpr (std::is_same_v<T, bool>) {
2219+
for (auto e : v) {
2220+
list.push_back(e);
2221+
}
2222+
} else {
2223+
for (auto& e : v) {
2224+
list.push_back(std::move(e));
2225+
}
2226+
}
2227+
}
2228+
21962229
template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
21972230
inline IValue::IValue(c10::OptionalArrayRef<T> v) : IValue() {
21982231
if (v.has_value()) {

aten/src/ATen/test/ivalue_test.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ TEST(IValueTest, Basic) {
4949
ASSERT_TRUE(dlist.isNone());
5050
dlist = IValue(c10::List<double>({3.4}));
5151
ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.4}));
52+
dlist = IValue(std::vector<double>({3.3, 3.2}));
53+
ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.3, 3.2}));
54+
IValue blist(std::vector<bool>{true, false});
55+
ASSERT_TRUE(blist.isList());
56+
const auto blistRef = blist.toListRef();
57+
ASSERT_EQ(blistRef.size(), 2);
58+
ASSERT_TRUE(blistRef[0].toBool());
59+
ASSERT_FALSE(blistRef[1].toBool());
5260
IValue the_list(
5361
at::ivalue::Tuple::create({IValue(3.4), IValue(4), IValue(foo)}));
5462
ASSERT_EQ(foo.use_count(), 3);

0 commit comments

Comments
 (0)