feat: Added model_fn to support non-standard model function in create_trainer (#3055)#3074
feat: Added model_fn to support non-standard model function in create_trainer (#3055)#3074vfdev-5 merged 6 commits intopytorch:masterfrom
Conversation
|
@invoker-bot thanks for the PR, please add also a test for this feature into https://github.com/pytorch/ignite/blob/master/tests/ignite/engine/test_create_supervised.py |
I have made these changes, please check it. |
vfdev-5
left a comment
There was a problem hiding this comment.
Thanks for the update @invoker-bot
Few improvements to add and it can be good to be merged
|
|
||
| loss[0] = mse_loss(_y_pred, _y).item() | ||
|
|
||
| # loss[0] = mse_loss(model(_x), _y).item() |
There was a problem hiding this comment.
Let's remove commented code
| theta[0] -= accumulation[0] / gradient_accumulation_steps | ||
| assert pytest.approx(model.fc.weight.data[0, 0].item(), abs=1.0e-5) == theta[0] | ||
| assert pytest.approx(trainer.state.output[-1], abs=1e-5) == loss[0] | ||
| print("loss:", loss[0], "theta:", theta[0]) |
|
@invoker-bot please run code style formatting script to fix CI issues: bash ./tests/run_code_style.sh install
bash ./tests/run_code_style.sh fmt |
|
@invoker-bot can you please address the comment such that the PR can be merged and will be included to the next release? |
I have fixed this issue now, please check it. |
|
@invoker-bot thanks, please also check above comments:
-> let me fix them myself to accelerate the review process |
vfdev-5
left a comment
There was a problem hiding this comment.
LGTM, thanks @invoker-bot
Fixes #3055
Description:
Now we can define our custom
model_fnincreate_supervised_trainerandcreate_supervised_evaluator.Check list: