Skip to content

Conversation

@jamesr66a
Copy link
Collaborator

No description provided.

@jamesr66a jamesr66a force-pushed the script_type_annotations2 branch from f88188c to 8a01c6b Compare July 13, 2018 00:32
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool -- I have a bunch of organization comments to keep the complexity of this change down. Right now it adds too many different pathways to the compiler and I have some suggestions about how to reduce that.

DefAndTypes(Def def, std::vector<TypePtr> arg_types, TypePtr return_type)
: def(std::move(def)), arg_types(arg_types), return_type(return_type) {}
Def def;
std::vector<TypePtr> arg_types;

This comment was marked as off-topic.

bool allow_varargs;
};

struct DefAndTypes {

This comment was marked as off-topic.

const std::vector<Resolver>& resolvers, /* determines how we handle free variables in each definition*/
std::shared_ptr<SugaredValue> self /* if non-null, the first argument to each def, is bound to this value */
std::shared_ptr<SugaredValue> self, /* if non-null, the first argument to each def, is bound to this value */
bool pure_func=false /* If true, expect a single def which will become the 'forward' method on the module */

This comment was marked as off-topic.

mod = ScriptModule()
rcb = createResolutionCallback(_frames_up + 1)
ast = get_jit_ast(fn)
arg_types, ret_type = annotations.get_signature(fn, ast.num_params(), 0)

This comment was marked as off-topic.

graph = _script_graph(fn, _frames_up=_frames_up + 1)
mod = ScriptModule()
mod._create_method_from_graph('forward', graph)
_script_pure_function(mod, fn, _frames_up=_frames_up + 1)

This comment was marked as off-topic.

This comment was marked as off-topic.

auto& name = (*it).ident().name();
arguments.push_back({name, DynamicType::get()});
TypePtr arg_type = DynamicType::get();
if (def_and_types.arg_types.size()) {

This comment was marked as off-topic.

ast = get_jit_ast(fn)
arg_types, return_type = annotations.get_signature(fn, ast.num_params() - 1, 0)
# Dumb handling for `self`
if len(arg_types) == ast.num_params():

This comment was marked as off-topic.

torch._C.ScriptModule.__init__(self)
original_init(self, *args, **kwargs)
asts = [m.ast for m in methods]
defs = [torch._C.DefAndTypes(m.ast, m.arg_types, m.return_type) for m in methods]

This comment was marked as off-topic.

for(Def def : definitions) {
const std::string& name = def.name().name();
for(DefAndTypes def : definitions) {
const std::string& name = pure_func ? "forward" : def.def.name().name();

This comment was marked as off-topic.

std::vector<TypePtr> flattened_return_types;
if (def_and_types.return_type) {
if (def_and_types.return_type->kind() == TypeKind::TupleType) {
const auto &tuple_type_elmts = def_and_types.return_type->cast<TupleType>()->elements();

This comment was marked as off-topic.

wrap_list(r, std::move(body)));
}));
}))
.def("num_params", [](Def& self) {

This comment was marked as off-topic.

FunctionTable table;
JIT_ASSERT(definitions.size() == resolvers.size());
if (pure_func) {
JIT_ASSERT(definitions.size() == 1);

This comment was marked as off-topic.

bool allow_varargs;
};

struct DefAndTypes {

This comment was marked as off-topic.

This comment was marked as off-topic.



def _script_graph(fn, _frames_up=0):
def _compile_fn(fn, _frames_up=0):

This comment was marked as off-topic.

_jit_script_compile_pure_fn(mod, torch._C.DefAndTypes(ast, arg_types, ret_type), rcb)


def _script_graph(fn, _frames_up=0):

This comment was marked as off-topic.


m.def("_jit_script_compile", [](Def def, ResolutionCallback rcb) {
return compileFunction(def, pythonResolver(rcb));
m.def("_jit_script_compile_pure_fn", [](Module &m, DefAndTypes def, ResolutionCallback rcb) {

This comment was marked as off-topic.

@jamesr66a
Copy link
Collaborator Author

Superseded by #9547

@jamesr66a jamesr66a closed this Jul 18, 2018
facebook-github-bot pushed a commit that referenced this pull request Jul 26, 2018
Summary:
Supersedes #9405
Pull Request resolved: #9547

Reviewed By: zdevito

Differential Revision: D8900327

Pulled By: jamesr66a

fbshipit-source-id: a00a94615af4fbaec98ee3ede0cb54bcfd9108dd
jramseyer pushed a commit to jramseyer/pytorch that referenced this pull request Jul 30, 2018
Summary:
Supersedes pytorch#9405
Pull Request resolved: pytorch#9547

Reviewed By: zdevito

Differential Revision: D8900327

Pulled By: jamesr66a

fbshipit-source-id: a00a94615af4fbaec98ee3ede0cb54bcfd9108dd
goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
Summary:
Supersedes pytorch#9405
Pull Request resolved: pytorch#9547

Reviewed By: zdevito

Differential Revision: D8900327

Pulled By: jamesr66a

fbshipit-source-id: a00a94615af4fbaec98ee3ede0cb54bcfd9108dd
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants