
Commit 98df978

arktrail authored and facebook-github-bot committed
Impl for ParameterList (#41259)
Summary: This is a new PR for #40850, #40987 and #41206 unintentionally closed), as I have some issues for rebates for that one. Very sorry about that. And I have fixed the tests failed in that PR. This diff contains the implementation of C++ API for ParameterList from #25883. Refer to the Python API: https://github.com/pytorch/pytorch/blob/bc9e8af21875dafafe9bbd25c8f542b20b2e660f/torch/nn/modules/container.py#L376 Not sure about some naming difference between C++ API and Python API, like `append`, should it be called `push_back` Pull Request resolved: #41259 Test Plan: Add unit tests in this diff Differential Revision: D22495780 Pulled By: glaringlee fbshipit-source-id: 79ea3592db640f35477d445ecdaeafbdad814bec
1 parent fa15318 commit 98df978
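For orientation, here is a minimal usage sketch of the API this commit adds. It is not part of the diff; it uses only names introduced in the files below (`ParameterList`, `append`, `at`, `size`), and the tensor shapes are arbitrary:

// Sketch only; not part of this commit.
#include <iostream>
#include <torch/torch.h>

int main() {
  // The variadic constructor takes tensors directly.
  torch::nn::ParameterList list(
      torch::randn({1, 2}, torch::requires_grad(true)),
      torch::randn({1, 2}));

  // append mirrors Python's nn.ParameterList.append.
  list->append(torch::randn({2, 2}, torch::requires_grad(true)));
  std::cout << list->size() << std::endl; // prints 3

  // at() and operator[] give indexed access and throw
  // "Index out of range" for a bad index.
  torch::Tensor first = list->at(0);

  // Iteration yields OrderedDict items whose keys are stringified indices.
  for (const auto& item : *list) {
    std::cout << item.key() << ": " << item.value().sizes() << std::endl;
  }
}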

File tree

4 files changed: +333 -0 lines changed

test/cpp/api/CMakeLists.txt
test/cpp/api/parameterlist.cpp
torch/csrc/api/include/torch/nn/modules.h
torch/csrc/api/include/torch/nn/modules/container/parameterlist.h


test/cpp/api/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@ set(TORCH_API_TEST_SOURCES
   ${TORCH_API_TEST_DIR}/modulelist.cpp
   ${TORCH_API_TEST_DIR}/modules.cpp
   ${TORCH_API_TEST_DIR}/parameterdict.cpp
+  ${TORCH_API_TEST_DIR}/parameterlist.cpp
   ${TORCH_API_TEST_DIR}/namespace.cpp
   ${TORCH_API_TEST_DIR}/nn_utils.cpp
   ${TORCH_API_TEST_DIR}/optim.cpp

test/cpp/api/parameterlist.cpp

Lines changed: 162 additions & 0 deletions

#include <gtest/gtest.h>

#include <torch/torch.h>

#include <algorithm>
#include <memory>
#include <vector>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;

struct ParameterListTest : torch::test::SeedingFixture {};

TEST_F(ParameterListTest, ConstructsFromSharedPointer) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  ASSERT_TRUE(ta.requires_grad());
  ASSERT_FALSE(tb.requires_grad());
  ParameterList list(ta, tb, tc);
  ASSERT_EQ(list->size(), 3);
}

TEST_F(ParameterListTest, isEmpty) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  ParameterList list;
  ASSERT_TRUE(list->is_empty());
  list->append(ta);
  ASSERT_FALSE(list->is_empty());
  ASSERT_EQ(list->size(), 1);
}

TEST_F(ParameterListTest, PushBackAddsAnElement) {
  ParameterList list;
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  ASSERT_EQ(list->size(), 0);
  ASSERT_TRUE(list->is_empty());
  list->append(ta);
  ASSERT_EQ(list->size(), 1);
  list->append(tb);
  ASSERT_EQ(list->size(), 2);
  list->append(tc);
  ASSERT_EQ(list->size(), 3);
  list->append(td);
  ASSERT_EQ(list->size(), 4);
}

TEST_F(ParameterListTest, ForEachLoop) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  ParameterList list(ta, tb, tc, td);
  std::vector<torch::Tensor> params = {ta, tb, tc, td};
  ASSERT_EQ(list->size(), 4);
  int idx = 0;
  for (const auto& pair : *list) {
    ASSERT_TRUE(
        torch::all(torch::eq(pair.value(), params[idx++])).item<bool>());
  }
}

TEST_F(ParameterListTest, AccessWithAt) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  std::vector<torch::Tensor> params = {ta, tb, tc, td};

  ParameterList list;
  for (auto& param : params) {
    list->append(param);
  }
  ASSERT_EQ(list->size(), 4);

  // returns the correct parameter for a given index
  for (size_t i = 0; i < params.size(); ++i) {
    ASSERT_TRUE(torch::all(torch::eq(list->at(i), params[i])).item<bool>());
  }

  for (size_t i = 0; i < params.size(); ++i) {
    ASSERT_TRUE(torch::all(torch::eq(list[i], params[i])).item<bool>());
  }

  // throws for a bad index
  ASSERT_THROWS_WITH(list->at(params.size() + 100), "Index out of range");
  ASSERT_THROWS_WITH(list->at(params.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(list[params.size() + 1], "Index out of range");
}

TEST_F(ParameterListTest, ExtendPushesParametersFromOtherParameterList) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  torch::Tensor te = torch::randn({1, 2});
  torch::Tensor tf = torch::randn({1, 2, 3});
  ParameterList a(ta, tb);
  ParameterList b(tc, td);
  a->extend(*b);

  ASSERT_EQ(a->size(), 4);
  ASSERT_TRUE(torch::all(torch::eq(a[0], ta)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(a[1], tb)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(a[2], tc)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(a[3], td)).item<bool>());

  ASSERT_EQ(b->size(), 2);
  ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item<bool>());

  std::vector<torch::Tensor> c = {te, tf};
  b->extend(c);

  ASSERT_EQ(b->size(), 4);
  ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(b[2], te)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(b[3], tf)).item<bool>());
}

TEST_F(ParameterListTest, PrettyPrintParameterList) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  ParameterList list(ta, tb, tc);
  ASSERT_EQ(
      c10::str(list),
      "torch::nn::ParameterList(\n"
      "(0): Parameter containing: [Float of size [1, 2]]\n"
      "(1): Parameter containing: [Float of size [1, 2]]\n"
      "(2): Parameter containing: [Float of size [1, 2]]\n"
      ")");
}

TEST_F(ParameterListTest, IncrementAdd) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  torch::Tensor te = torch::randn({1, 2});
  torch::Tensor tf = torch::randn({1, 2, 3});
  ParameterList listA(ta, tb, tc);
  ParameterList listB(td, te, tf);
  std::vector<torch::Tensor> tensors{ta, tb, tc, td, te, tf};
  int idx = 0;
  *listA += *listB;
  ASSERT_TRUE(torch::all(torch::eq(listA[0], ta)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(listA[1], tb)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(listA[2], tc)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(listA[3], td)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(listA[4], te)).item<bool>());
  ASSERT_TRUE(torch::all(torch::eq(listA[5], tf)).item<bool>());
  for (const auto& P : listA->named_parameters(false))
    ASSERT_TRUE(torch::all(torch::eq(P.value(), tensors[idx++])).item<bool>());

  ASSERT_EQ(idx, 6);
}

torch/csrc/api/include/torch/nn/modules.h

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
 #include <torch/nn/modules/container/named_any.h>
 #include <torch/nn/modules/container/sequential.h>
 #include <torch/nn/modules/container/parameterdict.h>
+#include <torch/nn/modules/container/parameterlist.h>

 // Layers
 #include <torch/nn/modules/adaptive.h>
torch/csrc/api/include/torch/nn/modules/container/parameterlist.h

Lines changed: 169 additions & 0 deletions

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>

#include <vector>

namespace torch {
namespace nn {

class ParameterListImpl : public Cloneable<ParameterListImpl> {
 public:
  using Iterator = typename std::vector<
      OrderedDict<std::string, torch::Tensor>::Item>::iterator;
  using ConstIterator = typename std::vector<
      OrderedDict<std::string, torch::Tensor>::Item>::const_iterator;

  ParameterListImpl() = default;

  /// Constructs the `ParameterList` from a variadic list of tensors.
  template <typename... Tensors>
  explicit ParameterListImpl(Tensors&&... params) {
    parameters_.reserve(sizeof...(Tensors));
    push_back_var(std::forward<Tensors>(params)...);
  }

  template <typename... Tensors>
  explicit ParameterListImpl(const Tensors&... params) {
    parameters_.reserve(sizeof...(Tensors));
    push_back_var(std::forward<Tensors>(params)...);
  }

  /// `reset()` is empty for `ParameterList`, since it does not have
  /// parameters of its own.
  void reset() override {}

  /// Pretty prints the `ParameterList` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ParameterList(" << std::endl;
    for (const auto& pair : parameters_) {
      stream << "(" << pair.key() << ")"
             << ": Parameter containing: [" << pair.value().scalar_type()
             << " of size " << pair.value().sizes() << "]";
      stream << std::endl;
    }
    stream << ")";
  }

  /// Pushes a given parameter to the end of the list.
  void append(torch::Tensor&& param) {
    bool requires_grad = param.requires_grad();
    register_parameter(
        c10::to_string(parameters_.size()), std::move(param), requires_grad);
  }

  /// Pushes a given parameter to the end of the list.
  void append(const torch::Tensor& param) {
    bool requires_grad = param.requires_grad();
    register_parameter(
        c10::to_string(parameters_.size()), param, requires_grad);
  }

  /// Pushes the value of a given `OrderedDict` item to the end of the list.
  /// The key of the pair is discarded; only the value is added to the
  /// `ParameterList`.
  void append(const OrderedDict<std::string, torch::Tensor>::Item& pair) {
    register_parameter(
        c10::to_string(parameters_.size()),
        pair.value(),
        pair.value().requires_grad());
  }

  /// Appends all parameters from a container to the end of the list.
  template <typename Container>
  void extend(const Container& container) {
    for (const auto& param : container) {
      append(param);
    }
  }

  /// Returns an iterator to the start of the `ParameterList`. The iterator
  /// dereferences to `OrderedDict<std::string, torch::Tensor>::Item`.
  Iterator begin() {
    return parameters_.begin();
  }

  /// Returns a const iterator to the start of the `ParameterList`. The
  /// iterator dereferences to `OrderedDict<std::string, torch::Tensor>::Item`.
  ConstIterator begin() const {
    return parameters_.begin();
  }

  /// Returns an iterator to the end of the `ParameterList`. The iterator
  /// dereferences to `OrderedDict<std::string, torch::Tensor>::Item`.
  Iterator end() {
    return parameters_.end();
  }

  /// Returns a const iterator to the end of the `ParameterList`. The
  /// iterator dereferences to `OrderedDict<std::string, torch::Tensor>::Item`.
  ConstIterator end() const {
    return parameters_.end();
  }

  /// Returns the parameter at the given `idx`. Throws an exception if the
  /// index is out of range; check `idx` against `size()` first for a
  /// non-throwing way of access.
  at::Tensor& at(size_t idx) {
    TORCH_CHECK(idx < size(), "Index out of range");
    return parameters_[c10::to_string(idx)];
  }

  /// Returns the parameter at the given `idx`. Throws an exception if the
  /// index is out of range; check `idx` against `size()` first for a
  /// non-throwing way of access.
  const at::Tensor& at(size_t idx) const {
    TORCH_CHECK(idx < size(), "Index out of range");
    return parameters_[c10::to_string(idx)];
  }

  /// Returns the parameter at the given `idx`. Equivalent to `at(idx)` and
  /// throws the same exception for an out-of-range index.
  at::Tensor& operator[](size_t idx) {
    return at(idx);
  }

  /// Returns the parameter at the given `idx`. Equivalent to `at(idx)` and
  /// throws the same exception for an out-of-range index.
  const at::Tensor& operator[](size_t idx) const {
    return at(idx);
  }

  /// Returns the number of parameters in the `ParameterList`.
  size_t size() const noexcept {
    return parameters_.size();
  }

  /// True if the `ParameterList` is empty.
  bool is_empty() const noexcept {
    return parameters_.is_empty();
  }

  /// Overloads `+=` so that the contents of another `ParameterList` (or any
  /// container of parameters) can be appended to this one.
  template <typename Container>
  Container& operator+=(const Container& other) {
    extend(other);
    return *this;
  }

 private:
  template <typename Head, typename... Tail>
  void push_back_var(Head&& head, Tail&&... tail) {
    append(std::forward<Head>(head));
    // Recursively calls this method until the parameter pack is empty, at
    // which point the no-argument base case below ends the recursion.
    push_back_var(std::forward<Tail>(tail)...);
  }

  /// The base case, when the list of parameters is empty.
  void push_back_var() {}
};

TORCH_MODULE(ParameterList);

} // namespace nn
} // namespace torch
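Since the container's purpose is to make its tensors visible to `parameters()` (and hence to optimizers), here is a hedged sketch of embedding it in a module; `MyModule` and the shapes are illustrative, not from the PR:

// Hedged sketch, not part of this commit; names and shapes are illustrative.
#include <torch/torch.h>

struct MyModuleImpl : torch::nn::Module {
  MyModuleImpl() {
    for (int i = 0; i < 3; ++i) {
      params->append(torch::randn({10, 10}, torch::requires_grad(true)));
    }
    // Registering the list as a submodule exposes its entries through
    // named_parameters() and parameters().
    register_module("params", params);
  }

  torch::Tensor forward(torch::Tensor x) {
    // Chain a matmul through every stored parameter.
    for (size_t i = 0; i < params->size(); ++i) {
      x = torch::mm(x, params[i]);
    }
    return x;
  }

  torch::nn::ParameterList params;
};
TORCH_MODULE(MyModule);

With this, `MyModule m; torch::optim::SGD opt(m->parameters(), /*lr=*/0.1);` would see all three matrices, matching how Python's `nn.ParameterList` behaves inside an `nn.Module`.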
