Skip to content

Commit 01e28fe

Browse files
author
eellison
committed
Update on "[JIT] Add support for named_modules()"
Fix for #28998 Differential Revision: [D18412561](https://our.internmc.facebook.com/intern/diff/D18412561) [ghstack-poisoned]
2 parents 602c255 + 7525356 commit 01e28fe

File tree

2 files changed

+27
-27
lines changed

2 files changed

+27
-27
lines changed

torch/csrc/jit/script/python_sugared_value.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,31 @@ std::shared_ptr<SugaredModuleDict> ModuleValue::getSugaredModuleDict(
284284
std::make_shared<SugaredTupleValue>(values));
285285
}
286286

287+
std::shared_ptr<SugaredValue> SugaredModuleDict::attr(
288+
const SourceRange& loc,
289+
Function& m,
290+
const std::string& field) {
291+
if (field == "keys") {
292+
return std::make_shared<ModuleDictMethod>(keys_, "keys");
293+
} else if (field == "values") {
294+
return std::make_shared<ModuleDictMethod>(modules_, "values");
295+
} else if (field == "items") {
296+
auto iterator = std::make_shared<IterableTree>();
297+
iterator->addChild(loc, m, keys_);
298+
iterator->addChild(loc, m, modules_);
299+
return std::make_shared<ModuleDictMethod>(iterator, "items");
300+
} else if (field == "named_modules") {
301+
auto iterator = std::make_shared<IterableTree>();
302+
std::vector<SugaredValuePtr> keys;
303+
std::vector<SugaredValuePtr> values;
304+
recurseThroughNestedModules(loc, m, keys, values, self_, "");
305+
iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(keys));
306+
iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(values));
307+
return std::make_shared<ModuleDictMethod>(iterator, "named_modules");
308+
};
309+
TORCH_INTERNAL_ASSERT(false);
310+
}
311+
287312
// helper function for instantiating a SugaredValue from an IValue
288313
std::shared_ptr<SugaredValue> toSugaredValue(
289314
const IValue& v,
@@ -335,7 +360,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
335360
// 2. Special case: for module dicts we manually desugar items(), keys(),
336361
// values() calls into the appropriate method.
337362
if (concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
338-
if (field == "items" || field == "keys" || field != "values") {
363+
if (field == "items" || field == "keys" || field == "values") {
339364
return getSugaredModuleDict(loc, m)->attr(loc, m, field);
340365
}
341366
}

torch/csrc/jit/script/python_sugared_value.h

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -206,32 +206,7 @@ struct VISIBILITY_HIDDEN SugaredModuleDict : public SugaredValue {
206206
std::shared_ptr<SugaredValue> attr(
207207
const SourceRange& loc,
208208
Function& m,
209-
const std::string& field) override {
210-
if (field == "keys") {
211-
return std::make_shared<ModuleDictMethod>(keys_, "keys");
212-
} else if (field == "values") {
213-
return std::make_shared<ModuleDictMethod>(modules_, "values");
214-
} else if (field == "items") {
215-
auto iterator = std::make_shared<IterableTree>();
216-
iterator->addChild(loc, m, keys_);
217-
iterator->addChild(loc, m, modules_);
218-
return std::make_shared<ModuleDictMethod>(iterator, "items");
219-
} else if (field == "named_modules") {
220-
auto iterator = std::make_shared<IterableTree>();
221-
std::vector<SugaredValuePtr> keys;
222-
std::vector<SugaredValuePtr> values;
223-
224-
auto key_tuple = std::dynamic_pointer_cast<SugaredTupleValue>(keys_);
225-
auto values_tuple =
226-
std::dynamic_pointer_cast<SugaredTupleValue>(modules_);
227-
228-
recurseThroughNestedModules(loc, m, keys, values, self_, "");
229-
iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(keys));
230-
iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(values));
231-
return std::make_shared<ModuleDictMethod>(iterator, "named_modules");
232-
};
233-
TORCH_INTERNAL_ASSERT(false);
234-
}
209+
const std::string& field) override;
235210

236211
SugaredValuePtr iter(const SourceRange& loc, Function& m) {
237212
return keys_;

0 commit comments

Comments
 (0)