
Conversation

@xzhu1900 (Contributor)

When dealing with large-scale datasets, it is handy to be able to save the dataset status and resume later, especially when an unexpected crash happens: users don't need to start the whole dataset over from the beginning, but can instead reload it from the last checkpoint.

This change adds support for checkpoint save/load logic in ChunkDataset.

On ChunkDataset construction, the user can specify a file name from which to load a checkpoint. If it is empty, the dataset starts fresh by default; otherwise the ChunkDataset 'fast forwards' the chunk sampler to the corresponding checkpoint.

The user can also call ChunkDataset::save() to serialize the current status to a file, which can be loaded later.
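A hedged usage sketch of the interface described above; the reader type, sampler choices, option values, and file name are illustrative, not taken from this PR:

using namespace torch::data;

// Construct the dataset; an empty checkpoint file name means start fresh,
// a non-empty one fast-forwards the chunk sampler to the saved position.
auto dataset = datasets::ChunkDataset<MyChunkReader,
                                      samplers::SequentialSampler,
                                      samplers::SequentialSampler>(
    std::move(reader),
    samplers::SequentialSampler(/*size=*/0),
    samplers::SequentialSampler(/*size=*/0),
    datasets::ChunkDatasetOptions(/*preloader_count=*/2, /*batch_size=*/32),
    /*checkpoint_file=*/"dataset.ckpt");

// Later, during training, periodically serialize the current status:
dataset.save("dataset.ckpt");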

@pytorchbot added the "module: cpp (Related to C++ API)" label Jun 18, 2019
@apaszke (Contributor) left a comment

It seems very surprising to me that resuming from a file requires passing an argument to a constructor. That doesn't seem like a good design decision. The interface we should be aiming for is:

chunk_dataset.save("...");
// then, in other part of the code this is the only thing you have to do
torch::load(chunk_dataset, "...");

We have no other class that follows this interface, so I'd rather not make an exception here, especially since it doesn't seem like a step in a good direction.


// After the checkpoint is loaded, set the flag to false to prevent future loading.
load_checkpoint_ = false;
}
@apaszke (Contributor)

Uh, why do we need this deferred loading business? Is it because, when we start iterating in a data loader, it always calls reset() at the beginning of an epoch?

@xzhu1900 (Contributor Author)

Because when we start iterating in a data loader it always calls reset() at the beginning of an epoch?

That's exactly why we need to load here: if the 'load' option is set, the chunk_sampler_ needs to be resumed from the file for the first epoch, when reset() is called by the data loader. This way, ChunkDataset reads chunks starting from the last saved point instead of from the beginning. This is the core load logic. After the first epoch, every following epoch should just reset chunk_sampler_ as normal.

Please let me know if you have any questions about this functionality.
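A minimal sketch of the deferred-load logic described above. The member names, the checkpoint_file_ field, and the surrounding reset() body are assumptions for illustration (template boilerplate omitted), not the exact PR code:

// Inside ChunkDataset; reset() is called by the data loader at the
// start of every epoch.
void reset() override {
  if (load_checkpoint_) {
    // First epoch after construction: fast-forward the chunk sampler
    // to the state serialized in the checkpoint file.
    torch::serialize::InputArchive archive;
    archive.load_from(checkpoint_file_);
    chunk_sampler_.load(archive);
    // Prevent re-loading on subsequent epochs.
    load_checkpoint_ = false;
  } else {
    // Any following epoch just resets the chunk sampler as normal.
    chunk_sampler_.reset(chunk_reader_.chunk_count());
  }
  // ... remaining per-epoch setup (preloader threads, buffers, etc.)
}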

@xzhu1900 (Contributor Author)

It seems very surprising to me that resuming from a file requires passing an argument to a constructor. That doesn't seem like a good design decision. The interface we should be aiming for is:

chunk_dataset.save("...");
// then, in other part of the code this is the only thing you have to do
torch::load(chunk_dataset, "...");

We have no other class that follows this interface, so I'd rather not make an exception here, especially since it doesn't seem like a step in a good direction.

This case is actually different from any existing case, and also different from the example you mentioned above.

In practice, when save/load is used, users usually save at a fixed frequency (every fixed number of mini-batches). The training job keeps running until some unexpected crash/hang terminates it. Then, to recover, the user submits a new job with the 'load' option, constructs a new ChunkDataset from the previously saved file, and continues. That is the practical use case here.

Given that the 'load' happens at construction time, instead of in the middle of training, I think this is a reasonable design.
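A sketch of that workflow; the save frequency, file name, and train_step helper are hypothetical:

const std::string ckpt = "chunk_dataset.ckpt";
const size_t kSaveEveryNBatches = 100;

size_t batch_index = 0;
for (auto& batch : *data_loader) {
  train_step(batch);  // hypothetical training function
  if (++batch_index % kSaveEveryNBatches == 0) {
    dataset.save(ckpt);  // periodic checkpoint of dataset progress
  }
}
// After a crash, a new job passes 'ckpt' to the ChunkDataset constructor
// and the chunk sampler resumes from the last saved position.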

@apaszke (Contributor) commented Jun 21, 2019

I disagree. This is not the API we have in Python, and following this logic we should have a "load-from-file" argument in a constructor of every single object. If you support restarting the training from a checkpoint, then you should have a conditional along the lines of:

if (!checkpoint_path.empty()) {
  dataset = torch::load(checkpoint_path);
} else {
  dataset = ChunkDataset(...);
}

Otherwise it will start creeping over the whole interface.

@xzhu1900 (Contributor Author)

I disagree. This is not the API we have in Python, and following this logic we should have a "load-from-file" argument in a constructor of every single object. If you support restarting the training from a checkpoint, then you should have a conditional along the lines of:

if (!checkpoint_path.empty()) {
  dataset = torch::load(checkpoint_path);
} else {
  dataset = ChunkDataset(...);
}

Otherwise it will start creeping over the whole interface.

What's inside torch::load(checkpoint_path)? It is a constructor containing exactly the same code change as this PR, plus a lot of redundant type checking and extraction of parameters the user already knows. So I am not convinced that this is the right way to do it.

following this logic we should have a "load-from-file" argument in a constructor of every single object

This is not true. ChunkDataset inherits from StatefulDataset, so it is very natural for a StatefulDataset to resume a state at construction. The rest of the objects should be totally stateless and thus have no need for a 'load-from-file' argument.

@xzhu1900 (Contributor Author) commented Jun 21, 2019

As for your proposed approach:

dataset = torch::load(checkpoint_path);

To make that work here, we would need to re-construct a ChunkDataset from a serialized file.

That means deserializing the constructor parameters, namely:

  1. ChunkDatasetOptions
  2. chunk sampler and example sampler
  3. ChunkReader

Both 2 and 3 are problematic. For example, without knowing the sampler's constructor, we have no idea how to construct an instance and pass it to ChunkDataset's constructor. This is the deal breaker here. The chunk sampler, for example, can be any customized sampler with a special constructor; that information cannot be serialized out and re-used.

We could add a restriction saying torch::load only supports samplers known to PyTorch, but this would basically chase away users who want to use their own samplers. Having such an imposition doesn't seem like the right way to design this feature.
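To make the problem concrete, here is a hypothetical custom sampler (not from this PR): its constructor parameters, the chunk count and a weight table, are not part of its serialized state, so a generic torch::load has no way to construct one.

#include <torch/torch.h>
#include <vector>

class WeightedChunkSampler : public torch::data::samplers::Sampler<> {
 public:
  // Constructor-time knowledge that torch::load cannot recover from a file.
  WeightedChunkSampler(size_t chunk_count, std::vector<double> weights)
      : chunk_count_(chunk_count), weights_(std::move(weights)) {}

  void reset(torch::optional<size_t> new_size = torch::nullopt) override {
    index_ = 0;
  }

  torch::optional<std::vector<size_t>> next(size_t batch_size) override {
    // Simplified sequential order; a real version would draw chunk
    // indices proportionally to weights_.
    std::vector<size_t> batch;
    while (batch.size() < batch_size && index_ < chunk_count_) {
      batch.push_back(index_++);
    }
    if (batch.empty()) {
      return torch::nullopt;
    }
    return batch;
  }

  // Only the cursor is (de)serialized, not the constructor arguments.
  void save(torch::serialize::OutputArchive& archive) const override {
    archive.write("index", torch::tensor(static_cast<int64_t>(index_)));
  }
  void load(torch::serialize::InputArchive& archive) override {
    torch::Tensor t;
    archive.read("index", t);
    index_ = static_cast<size_t>(t.item<int64_t>());
  }

 private:
  size_t chunk_count_;
  std::vector<double> weights_;
  size_t index_ = 0;
};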

@apaszke (Contributor) commented Jun 22, 2019

Well, ok. For me it's a problem with the design of torch::load. Depending on operator>> is annoying because it assumes that you have to be able to construct a value (using the public interface!) before you can even start the loading. How about this:

We add another instance of the torch::load template, which instead of using operator>> will try to do Value::fromFile. Hence, any class can implement a static method that deals with serialization. This way, you can have a private no-argument no-op constructor that lets you preload the state manually based on whatever you read from the file.
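A sketch of what that alternative could look like. This is hypothetical: no such overload exists in libtorch, and the names here are illustrative only.

#include <string>

// Dispatch loading to a static fromFile() when the type provides one,
// instead of default-constructing a value and then using operator>>.
template <typename Value>
auto load_from_checkpoint(const std::string& path)
    -> decltype(Value::fromFile(path)) {
  // The class owns its deserialization; internally it may use a private
  // no-argument constructor and then populate state from the file.
  return Value::fromFile(path);
}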

@xzhu1900 (Contributor Author) commented Jun 24, 2019

Well, ok. For me it's a problem with the design of torch::load. Depending on operator>> is annoying because it assumes that you have to be able to construct a value (using the public interface!) before you can even start the loading. How about this:

We add another instance of the torch::load template, which instead of using operator>> will try to do Value::fromFile. Hence, any class can implement a static method that deals with serialization. This way, you can have a private no-argument no-op constructor that lets you preload the state manually based on whatever you read from the file.

Having a static Value::fromFile has the same problem I mentioned before: if Value is of type ChunkDataset, we cannot construct an instance of ChunkDataset from a file, because we don't know the sampler's and ChunkReader's constructors. I think this is a deal breaker for any static ::load method.

@soumith added the "triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)" label Jun 25, 2019
@soumith (Contributor) commented Jun 25, 2019

@apaszke if needed, you two should talk on a high-bandwidth channel (Slack, Messenger). Please resolve this soon.

@apaszke (Contributor) commented Jun 25, 2019

Why don't we know the sampler? You're both saving and loading the chunk sampler as part of the newly added code, so I'm sorry but I don't understand the problem. Yeah, we should have a VC (let's talk on Slack).

@xzhu1900 (Contributor Author)

Synced on Slack. Based on the discussion, I have pushed a new iteration to the PR.

@apaszke (Contributor) left a comment

This looks good, but I have three more comments:

  1. Can we please make save and load private methods, and make operator<< and operator>> friends of this class? Alternatively, just make them friends and remove the methods (a sketch of this pattern follows below).
  2. I don't think there's a single test that both saves and loads a dataset from a file. You only keep saving the dataset and loading a sampler from the same tempfile, or the other way around. This doesn't seem like a good test, because it's not the desired usage. The fact that those objects have the same on-disk serialization format is a coincidence to some degree and should not be depended on.
  3. There's other state in the ChunkDataset, like the batch size. Why isn't this saved?
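A minimal sketch of the pattern point 1 describes, using a hypothetical class rather than the actual ChunkDataset:

#include <torch/torch.h>

class CheckpointableDataset {
 private:
  // Serialization logic is private; users go through the stream operators
  // (and hence torch::save / torch::load) rather than calling these directly.
  void save(torch::serialize::OutputArchive& archive) const {
    archive.write("index", torch::tensor(static_cast<int64_t>(index_)));
  }
  void load(torch::serialize::InputArchive& archive) {
    torch::Tensor t;
    archive.read("index", t);
    index_ = static_cast<size_t>(t.item<int64_t>());
  }

  friend torch::serialize::OutputArchive& operator<<(
      torch::serialize::OutputArchive& archive,
      const CheckpointableDataset& dataset) {
    dataset.save(archive);
    return archive;
  }
  friend torch::serialize::InputArchive& operator>>(
      torch::serialize::InputArchive& archive,
      CheckpointableDataset& dataset) {
    dataset.load(archive);
    return archive;
  }

  size_t index_ = 0;
};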

@xzhu1900 (Contributor Author)

  1. Can we please make save and load private methods, and make operator<< and operator>> friends of this class? Alternatively, just make them friends and remove the methods.

This is the existing pattern for other classes too, so let's keep it consistent.

  2. I don't think there's a single test that both saves and loads a dataset from a file. You only keep saving the dataset and loading a sampler from the same tempfile, or the other way around. This doesn't seem like a good test, because it's not the desired usage. The fact that those objects have the same on-disk serialization format is a coincidence to some degree and should not be depended on.

Updated the test with a comment explaining that this is for test purposes only. The chunk sampler inside ChunkDataset runs in a separate thread pool, not the main thread, so it is very hard to accurately know its status at the moment ChunkDataset::save/ChunkDataset::load is called. Purely for testing, we control the chunk sampler manually by calling the sampler's save/load methods for value validation. In a real use case, the user should still use the matching ChunkDataset::save and ChunkDataset::load methods.

@jaliyae self-assigned this Jun 26, 2019
@jaliyae self-requested a review June 26, 2019 20:36
@jaliyae removed their assignment Jun 26, 2019
@jaliyae (Contributor) left a comment

Looks good to me.

@soumith (Contributor) commented Jun 26, 2019

@pytorchbot merge this please

@facebook-github-bot (Contributor) left a comment

@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@ailzhang merged this pull request in f39b662.
