ChunkDataset checkpoint support #21889
Conversation
apaszke
left a comment
It seems very surprising to me that resuming from a file requires passing an argument to the constructor. That doesn't seem like a good design decision. The interface we should be aiming for is:

chunk_dataset.save("...");
// then, in another part of the code, this is the only thing you have to do
torch::load(chunk_dataset, "...");

We have no other class that follows this interface, so I'd rather not make an exception here, especially since it doesn't seem to be a step in a good direction.
// After the checkpoint is loaded, mark the boolean to false to prevent future loading.
load_checkpoint_ = false;
}
Uh why do we need this deferred loading business? Because when we start iterating in a data loader it always calls reset() at the beginning of an epoch?
> Because when we start iterating in a data loader it always calls reset() at the beginning of an epoch?
That's exactly why we need to load here --
If the 'load' option is set, the chunk_sampler_ needs to be resumed from the file for the first epoch, when reset() is called by the data loader. This way, ChunkDataset reads chunks starting from the last saved point instead of from the beginning. This is the core load logic. After the first epoch, every following epoch should just reset chunk_sampler_ as normal.
Please let me know if you have any questions about this functionality.
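To make the mechanism concrete, here is a minimal, self-contained sketch of the deferred-loading idea. It is not the actual PR code: the class name and members (load_checkpoint_, checkpoint_path_, chunk_sampler_) are assumptions for illustration, and SequentialSampler stands in for whatever chunk sampler is used.

```cpp
#include <string>
#include <utility>

#include <torch/torch.h>

// Illustrative sketch, not the PR implementation.
class CheckpointedChunkSource {
 public:
  explicit CheckpointedChunkSource(std::string checkpoint_path = "")
      : checkpoint_path_(std::move(checkpoint_path)),
        load_checkpoint_(!checkpoint_path_.empty()),
        chunk_sampler_(/*size=*/1000) {}

  // Called by the data loader at the beginning of every epoch.
  void reset() {
    if (load_checkpoint_) {
      // First epoch of a resumed job: restore the sampler position that a
      // previous run serialized, instead of starting from the first chunk.
      torch::load(chunk_sampler_, checkpoint_path_);
      // Apply the checkpoint only once; later epochs reset as usual.
      load_checkpoint_ = false;
    } else {
      chunk_sampler_.reset();
    }
  }

  // Persist the current sampler position so a later job can resume from it.
  void save(const std::string& path) const {
    torch::save(chunk_sampler_, path);
  }

 private:
  std::string checkpoint_path_;
  bool load_checkpoint_;
  torch::data::samplers::SequentialSampler chunk_sampler_;
};
```

The point being made above is that the checkpoint is applied exactly once, on the first reset() issued by the data loader; every later epoch falls back to the normal sampler reset.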
This case is actually different from any existing case, and also different from the example you mentioned above. In practice, when save/load is used, users usually save at a fixed frequency (every fixed number of mini-batches). The training job keeps running, and some unexpected crash/hang terminates it. Then, to recover, the user submits a new job with the 'load' option, constructs a new ChunkDataset from the previously saved file, and continues. That is the practical use case here. Given that the load happens at construction time, instead of in the middle of training, I think this is a reasonable design.
I disagree. This is not the API we have in Python, and following this logic we should have a "load-from-file" argument in the constructor of every single object. If you support restarting the training from a checkpoint, then you should have a conditional along the lines of:

if (!checkpoint_path.empty()) {
  dataset = torch::load(checkpoint_path);
} else {
  dataset = ChunkDataset(...);
}

Otherwise it will start creeping over the whole interface.
What's inside torch::load(checkpoint_path)? It is a constructor with exactly the same code change as this PR, plus a whole lot of redundant type checks and parameter extraction for things the user already knows. So I am not convinced that this is the right way to do it.
This is not true. ChunkDataset inherits from StatefulDataset, so it is very natural for a StatefulDataset to resume its state at construction. The remaining objects should be totally stateless and thus have no need for a 'load-from-file' argument.
As for your proposed approach:
To make it work here, we need to re-construct a ChunkDataset from a serialized file. We need to deserialize the constructor parameters, which are the information below:
Both 2 and 3 are problematic. For example, without knowing the sampler's constructor, we have no idea how to construct an instance and pass it to ChunkDataset's constructor. This is the deal breaker here. The chunk sampler, for example, can be any customized sampler with a special constructor. That information cannot be serialized out and re-used. We could add a restriction saying torch::load only supports samplers known to PyTorch, but that would basically chase away users who want to use their own sampler. Having such an imposition doesn't seem like the right way to design this feature.
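For concreteness, a hypothetical custom sampler like the sketch below illustrates the argument: its constructor takes a user-supplied callback, which no generic load-from-file routine could reconstruct, so the caller has to provide the sampler instance and only its position can be restored from disk. The class is invented for illustration and is not part of the PR.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// A user-defined chunk sampler whose constructor arguments cannot be
// recovered from a serialized file.
class PrioritySampler {
 public:
  // The priority function is arbitrary user code; a generic torch::load(path)
  // would have no way to reconstruct it, so only the cursor below can be
  // checkpointed.
  PrioritySampler(std::size_t num_chunks,
                  std::function<double(std::size_t)> priority)
      : priorities_(num_chunks), cursor_(0) {
    for (std::size_t i = 0; i < num_chunks; ++i) {
      priorities_[i] = priority(i);
    }
  }

  // Serializable state: the position within the chunk ordering.
  std::size_t cursor() const { return cursor_; }
  void set_cursor(std::size_t cursor) { cursor_ = cursor; }

 private:
  std::vector<double> priorities_;
  std::size_t cursor_;
};
```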
Well, ok. For me it's a problem with the design of [...]. We add another instance of the [...].
Having a static Value::fromFile has the same problem as I mentioned before -- if Value is of type ChunkDataset, we cannot construct an instance of ChunkDataset from a file, because we don't know the sampler's and ChunkReader's constructors. I think this is a deal breaker for using any static ::load method.
@apaszke if needed, could both of you talk on a high-bandwidth channel (Slack, Messenger)? Please resolve this soon.
Why don't we know the sampler? You're both saving and loading the chunk sampler as part of the newly added code, so I'm sorry but I don't understand the problem. Yeah, we should have a VC (let's talk on Slack).
Synced on Slack. Based on the discussion, I have pushed a new iteration to the PR.
apaszke
left a comment
This looks good, but I have three more comments:
- Can we please make the save and load private methods, and operator<< and operator>> friends of this class? Alternatively, just make them friends and remove the methods.
- I don't think there's a single test that both saves and loads a dataset from the same file. You only keep saving the dataset and loading a sampler from the same tempfile, or the other way around. This doesn't seem like a good test, because it's not the desired usage. The fact that those objects have the same on-disk serialization format is a coincidence to some degree and should not be depended on.
- There's other state in the ChunkDataset, like the batch size. Why isn't this saved?
This is the existing pattern for other classes too, so let's keep it consistent.
Updated the test with a comment explaining that this is for test purposes only.
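For reference, the operator-based serialization pattern requested in the first point usually looks something like the sketch below. The ChunkProgress class and its single field are invented for illustration; the real ChunkDataset serializes its chunk sampler state.

```cpp
#include <cstddef>
#include <cstdint>

#include <torch/torch.h>

// Illustrative type whose "state" is just a counter of chunks read.
class ChunkProgress {
 public:
  explicit ChunkProgress(std::size_t chunks_read = 0)
      : chunks_read_(chunks_read) {}

  std::size_t chunks_read() const { return chunks_read_; }

  // Friend operators give torch::save/torch::load a hook without exposing
  // public save/load member functions.
  friend torch::serialize::OutputArchive& operator<<(
      torch::serialize::OutputArchive& archive, const ChunkProgress& p) {
    archive.write("chunks_read",
                  torch::tensor(static_cast<int64_t>(p.chunks_read_)));
    return archive;
  }

  friend torch::serialize::InputArchive& operator>>(
      torch::serialize::InputArchive& archive, ChunkProgress& p) {
    torch::Tensor value;
    archive.read("chunks_read", value);
    p.chunks_read_ = static_cast<std::size_t>(value.item<int64_t>());
    return archive;
  }

 private:
  std::size_t chunks_read_;
};

// With the operators in place, a save/load round trip through a file is:
//   ChunkProgress progress(42);
//   torch::save(progress, "progress.pt");
//   ChunkProgress restored;
//   torch::load(restored, "progress.pt");
```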
jaliyae
left a comment
Looks good to me.
@pytorchbot merge this please
facebook-github-bot
left a comment
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
When dealing with large-scale datasets, it is handy to be able to save the dataset state and resume later. Especially when an unexpected crash happens, users don't need to start the whole dataset over from the beginning; instead, they can reload it from the last checkpoint.
This change adds support for checkpoint save/load logic in ChunkDataset.
On ChunkDataset construction, the user can specify a file name from which to load a checkpoint. If it is empty, the dataset starts fresh by default; otherwise, the ChunkDataset 'fast-forwards' the chunk sampler to the corresponding checkpoint.
The user can also call ChunkDataset::save() to serialize the current state to a file, which can be loaded later.
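Putting the described pieces together, the intended workflow is roughly the following. The sketch below is a self-contained toy that uses only the standard library; the class, file name, and checkpoint frequency are illustrative assumptions, not the ChunkDataset API itself.

```cpp
#include <cstddef>
#include <fstream>
#include <iostream>
#include <string>

// Toy stand-in for a checkpointable chunk dataset: its only state is the
// index of the next chunk to read.
struct ToyChunkDataset {
  std::size_t next_chunk = 0;

  // Mirrors the described constructor behavior: an empty file name starts
  // fresh; otherwise the position is fast-forwarded to the saved checkpoint.
  explicit ToyChunkDataset(const std::string& checkpoint_file = "") {
    if (!checkpoint_file.empty()) {
      std::ifstream in(checkpoint_file);
      in >> next_chunk;
    }
  }

  // Mirrors the described save(): persist the current position to a file.
  void save(const std::string& checkpoint_file) const {
    std::ofstream out(checkpoint_file);
    out << next_chunk;
  }
};

int main() {
  const std::string checkpoint_file = "chunk_checkpoint.txt";

  // Fresh training run: checkpoint every 100 "mini-batches".
  ToyChunkDataset dataset;
  for (std::size_t batch = 1; batch <= 1000; ++batch) {
    ++dataset.next_chunk;  // pretend a chunk was consumed
    if (batch % 100 == 0) {
      dataset.save(checkpoint_file);
    }
  }

  // A resubmitted job after a crash resumes from the last saved position.
  ToyChunkDataset resumed(checkpoint_file);
  std::cout << "resumed at chunk " << resumed.next_chunk << "\n";
  return 0;
}
```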