Commit 74bc5c4

Update on "[jit] Implement more of the nn.Module API"

[jit] Implement more of the nn.Module API

This updates torch::script::Module to more closely match the behavior of nn.Module. In particular, it implements the (optionally recursive) iterators that retrieve submodules, parameters, and buffers, and makes their names match the Python versions.

It also removes the individual accessors for Parameter, Module, Buffer, etc., replacing them with a single `attr` function that is equivalent to writing `a.foo` in Python (`setattr` emulates `a.foo = v`). As we build out the user-facing API for TorchScript values, this will end up matching how attributes are accessed on general objects.

This PR preserves the Python bindings for script::Module by emulating the old API at the binding level. A follow-up will clean up the usage to match the C++ API more directly.

gh-metadata: pytorch pytorch 28828 gh/zdevito/130/head
2 parents 2a0ce49 + fff4f16 commit 74bc5c4
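
For context, here is a minimal sketch of how the reworked C++ surface might look in use, based only on the commit message above. The attribute name `"weight"`, the `named_parameters()` spelling, and the exact signatures are illustrative assumptions, not code from this commit:

```cpp
#include <iostream>
#include <torch/script.h>

// Hypothetical walkthrough of the API described in the commit message.
void inspect(torch::jit::script::Module& module) {
  // A single `attr` accessor replaces the per-kind getters; it is the
  // C++ analogue of Python's `module.weight`.
  auto weight = module.attr("weight").toTensor();

  // `setattr` emulates the Python assignment `module.weight = v`.
  module.setattr("weight", torch::ones_like(weight));

  // Optionally recursive iterators whose names mirror nn.Module's
  // Python API (assumed spelling).
  for (const auto& p : module.named_parameters()) {
    std::cout << p.name << ": " << p.value.sizes() << std::endl;
  }
}
```

Collapsing the per-kind getters into `attr`/`setattr` mirrors Python's uniform attribute access, so the same call works whether the attribute holds a parameter, a buffer, or a submodule.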

File tree

298 files changed: 5,541 additions, 4,308 deletions

.circleci/scripts/build_android_gradle.sh

Lines changed: 12 additions & 0 deletions

```diff
@@ -56,4 +56,16 @@ else
   $GRADLE_PATH -p ~/workspace/android/ assembleRelease
 fi
 
+
+find . -type f -name "*.a" -exec ls -lh {} \;
+
+while IFS= read -r -d '' file
+do
+  echo
+  echo "$file"
+  ls -lah "$file"
+  zipinfo -l "$file"
+done < <(find . -type f -name '*.aar' -print0)
+
 find . -type f -name *aar -print | xargs tar cfvz ~/workspace/android/artifacts.tgz
+
```
aten/src/ATen/Context.h

Lines changed: 4 additions & 4 deletions

```diff
@@ -142,22 +142,22 @@ CAFFE2_API Allocator* getCPUAllocator();
 
 static inline DeprecatedTypeProperties& getNonVariableDeprecatedTypeProperties(Backend p, ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
-      p, s);
+      p, s, /*is_variable*/false);
 }
 
 static inline DeprecatedTypeProperties& CPU(ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
-      Backend::CPU, s);
+      Backend::CPU, s, /*is_variable*/false);
 }
 
 static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
-      Backend::CUDA, s);
+      Backend::CUDA, s, /*is_variable*/false);
 }
 
 static inline DeprecatedTypeProperties& HIP(ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
-      Backend::HIP, s);
+      Backend::HIP, s, /*is_variable*/false);
 }
 
 static inline bool hasCUDA() {
```

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 12 deletions

```diff
@@ -651,18 +651,6 @@
     - bool largest
     - bool sorted
 ]]
-[[
-  name: _th_abs
-  cname: abs
-  backends:
-    - CUDA
-  variants: function
-  return: argument 0
-  arguments:
-    - arg: THTensor* result
-      output: True
-    - THTensor* self
-]]
 [[
   name: _th_exp
   cname: exp
```

aten/src/ATen/InitialTensorOptions.h

Lines changed: 1 addition & 1 deletion

```diff
@@ -9,7 +9,7 @@ namespace at {
 // NOTE: this is not a stable API.
 inline TensorOptions initialTensorOptions() {
   return TensorOptions(kCPU).dtype(kFloat).layout(kStrided)
-                            .requires_grad(false);
+                            .requires_grad(false).is_variable(false);
 }
 
 }
```

aten/src/ATen/NamedTensorUtils.cpp

Lines changed: 51 additions & 59 deletions

```diff
@@ -111,7 +111,6 @@ std::vector<Dimname> unify_from_right(
   return result;
 }
 
-
 namespace namedinference {
 
 static std::bitset<dim_bitset_size>
@@ -128,51 +127,41 @@ static void assert_names_equal(DimnameList a, DimnameList b) {
       ". Please rename the out tensor's dims with `Tensor.rename`.");
 }
 
-void propagate_names(TensorImpl* result, optional<DimnameList> names) {
-  if (!impl::has_names(result) && !names.has_value()) {
-    return;
-  }
-  if (!impl::has_names(result)) {
-    impl::internal_set_names_inplace(result, names);
-    return;
-  }
-  assert_names_equal(
-      impl::get_names(result),
-      names.value_or(default_names(result->dim())));
-}
-
-void propagate_names(
-    Tensor& result,
-    optional<std::vector<Dimname>>&& maybe_names,
+Tensor& propagate_names_if_nonempty(Tensor& result,
+    DimnameList maybe_names,
     bool validate_names) {
-  propagate_names(result.unsafeGetTensorImpl(), std::move(maybe_names), validate_names);
+  propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_names, validate_names);
+  return result;
 }
 
-void propagate_names(
-    TensorImpl* result,
-    optional<std::vector<Dimname>>&& maybe_names,
+TensorImpl* propagate_names_if_nonempty(TensorImpl* result,
+    DimnameList maybe_names,
     bool validate_names) {
-  if (!maybe_names) {
-    propagate_names(result, nullopt);
-    return;
-  }
-  propagate_names(result, std::move(maybe_names.value()), validate_names);
-}
-
-void propagate_names(TensorImpl* result, std::vector<Dimname>&& names, bool validate_names) {
-  if (!impl::has_names(result)) {
-    impl::internal_set_names_inplace(result, std::move(names), validate_names);
-    return;
+  if (maybe_names.empty()) {
+    return result;
   }
-  assert_names_equal(impl::get_names(result), names);
+  return propagate_names(result, maybe_names, validate_names);
 }
 
-void propagate_names(Tensor& result, optional<DimnameList> names) {
-  propagate_names(result.unsafeGetTensorImpl(), names);
+Tensor& propagate_names(Tensor& result, DimnameList names, bool validate_names) {
+  propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
+  return result;
 }
 
-void propagate_names(Tensor& result, std::vector<Dimname>&& names, bool validate_names) {
-  propagate_names(result.unsafeGetTensorImpl(), std::move(names), validate_names);
+TensorImpl* propagate_names(TensorImpl* result, DimnameList names, bool validate_names) {
+  if (result->dim() > 0) {
+    TORCH_INTERNAL_ASSERT(
+        !names.empty(),
+        "propagate_names: passed in empty names to propagate to result with",
+        " shape ", result->sizes(), ". Empty names means that name inference did",
+        "not occur; use `propagate_names_if_nonempty` instead of `propagate_names`.");
+  }
+  if (!impl::has_names(result)) {
+    impl::internal_set_names_inplace(result, names, validate_names);
+  } else {
+    assert_names_equal(impl::get_names(result), names);
+  }
+  return result;
 }
 
 void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
@@ -188,7 +177,7 @@ void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef exclu
   if (excluded_idxs.size() == 1) {
     std::vector<Dimname> outnames = src_names.vec();
     outnames.erase(outnames.begin() + maybe_wrap_dim(excluded_idxs[0], src_dim));
-    propagate_names(result, std::move(outnames), /*validate_names=*/false);
+    propagate_names(result, outnames);
     return;
   }
 
@@ -200,7 +189,7 @@ void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef exclu
       outnames.push_back(src_names[dim]);
     }
   }
-  propagate_names(result, std::move(outnames), /*validate_names=*/false);
+  propagate_names(result, outnames);
 }
 
 void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
@@ -223,12 +212,15 @@ void propagate_names(TensorImpl* result, TensorImpl* src) {
   if (result == src) {
     return;
   }
-  propagate_names(result, impl::get_opt_names(src));
+  if (!impl::has_names(result) && !impl::has_names(src)) {
+    return;
+  }
+  propagate_names(result, impl::get_names(src));
 }
 
-optional<std::vector<Dimname>> compute_squeeze_outnames(const Tensor& tensor) {
+std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor) {
   if (!tensor.has_names()) {
-    return nullopt;
+    return {};
   }
   std::vector<Dimname> outnames;
   auto tensor_names = tensor.names();
@@ -371,7 +363,7 @@ void propagate_names_for_addmv(
   }
   auto mv_outnames = compute_matmul_outnames(impl::get_names(mat), impl::get_names(vec));
   auto add_outnames = unify_from_right(mv_outnames, impl::get_names(bias));
-  propagate_names(result, std::move(add_outnames), /*validate_names=*/false);
+  propagate_names(result, add_outnames);
 }
 
 void propagate_names_for_addmm(
@@ -385,7 +377,7 @@ void propagate_names_for_addmm(
   }
   auto mm_outnames = compute_matmul_outnames(impl::get_names(m1), impl::get_names(m2));
   auto add_outnames = unify_from_right(mm_outnames, impl::get_names(bias));
-  propagate_names(result, std::move(add_outnames), /*validate_names=*/false);
+  propagate_names(result, add_outnames);
 }
 
 void check_names_for_dot(
@@ -415,24 +407,24 @@ void propagate_names_for_expand(Tensor& result, const Tensor& self) {
       self.opt_names()->begin(),
       self.opt_names()->end(),
       outnames.begin() + result_dim - self.dim());
-  propagate_names(result, std::move(outnames), /*validate_names=*/false);
+  propagate_names(result, outnames);
 }
 
-optional<std::vector<Dimname>> compute_broadcast_outnames(
+std::vector<Dimname> compute_broadcast_outnames(
     const Tensor& self,
     const Tensor& other) {
   if (!self.has_names() && !other.has_names()) {
-    return nullopt;
+    return {};
   }
   return unify_from_right(self.names(), other.names());
 }
 
-optional<std::vector<Dimname>> broadcast_to_outnames(
+std::vector<Dimname> broadcast_to_outnames(
     const Tensor& tensor,
    const Tensor& reference_tensor,
    const char* op_name) {
   if (!tensor.has_names() && !reference_tensor.has_names()) {
-    return nullopt;
+    return {};
   }
   auto reference_names = reference_tensor.names();
   auto tensor_names = tensor.names();
@@ -445,9 +437,9 @@ optional<std::vector<Dimname>> broadcast_to_outnames(
   return unify_from_right(reference_names, tensor_names);
 }
 
-optional<std::vector<Dimname>> compute_cat_outnames(TensorList tensors) {
+std::vector<Dimname> compute_cat_outnames(TensorList tensors) {
   if (!at::has_names(tensors)) {
-    return nullopt;
+    return {};
   }
   std::vector<Dimname> result;
   for (const auto& tensor : tensors) {
@@ -461,20 +453,20 @@ optional<std::vector<Dimname>> compute_cat_outnames(TensorList tensors) {
   return result;
 }
 
-optional<std::vector<Dimname>> compute_matmul_outnames(
+std::vector<Dimname> compute_matmul_outnames(
    const Tensor& self,
    const Tensor& other) {
   if (!self.has_names() && !other.has_names()) {
-    return nullopt;
+    return {};
   }
   return compute_matmul_outnames(self.names(), other.names());
 }
 
-optional<std::vector<Dimname>> compute_cdist_outnames(
+std::vector<Dimname> compute_cdist_outnames(
    const Tensor& self,
    const Tensor& other) {
   if (!self.has_names() && !other.has_names()) {
-    return nullopt;
+    return {};
   }
   const auto self_names = self.names();
   const auto other_names = other.names();
@@ -496,24 +488,24 @@ optional<std::vector<Dimname>> compute_cdist_outnames(
   return result.toDimnameVec();
 }
 
-optional<std::vector<Dimname>> compute_bmm_outnames(
+std::vector<Dimname> compute_bmm_outnames(
    Tensor& result,
    const Tensor& self,
    const Tensor& other) {
   if (!result.has_names() && !self.has_names() && !other.has_names()) {
-    return nullopt;
+    return {};
   }
   return compute_matmul_outnames(self.names(), other.names());
 }
 
-optional<std::vector<Dimname>> compute_baddbmm_outnames(
+std::vector<Dimname> compute_baddbmm_outnames(
    TensorImpl* result,
    TensorImpl* batch1,
    TensorImpl* batch2,
    TensorImpl* bias) {
   if (!impl::has_names(result) && !impl::has_names(batch1) &&
       !impl::has_names(batch2) && !impl::has_names(bias)) {
-    return nullopt;
+    return {};
  }
   auto bmm_names = compute_matmul_outnames(
       impl::get_names(batch1), impl::get_names(batch2));
```
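
Taken together, the changes in this file replace `optional<std::vector<Dimname>>` with a plain `std::vector<Dimname>` whose emptiness means "no names to propagate", and split propagation into a strict `propagate_names` and a tolerant `propagate_names_if_nonempty`. A standalone sketch of that convention, using `std::string` as a stand-in for ATen's `Dimname`:

```cpp
#include <cassert>
#include <string>
#include <vector>

using Names = std::vector<std::string>;  // stand-in for std::vector<Dimname>

// Strict entry point: empty names indicate a caller bug, mirroring the
// TORCH_INTERNAL_ASSERT added to propagate_names above.
void propagate(Names& dest, const Names& names) {
  assert(!names.empty() && "use propagate_if_nonempty for maybe-empty names");
  dest = names;
}

// Tolerant entry point: an empty list means name inference did not run,
// so propagation is a no-op, mirroring propagate_names_if_nonempty.
void propagate_if_nonempty(Names& dest, const Names& names) {
  if (names.empty()) {
    return;
  }
  propagate(dest, names);
}
```

The empty-vector sentinel avoids an extra optional wrapper while keeping the no-op path explicit at every call site.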
