
Conversation

@awgu (Collaborator) commented Jan 17, 2023

Stack from ghstack:

Overview

  • This PR refactors the summon_full_params() unit tests to prepare for unshard_params(), consolidating redundant tests and improving others.
  • This PR enables CPUOffload(offload_params=True) + NO_SHARD + writeback=True.
  • This PR improves the error message raised when summon_full_params() is called from an invalid context (i.e. from forward, from backward, or from within summon_full_params() itself). A usage sketch of the context manager follows.
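
For reference, a minimal usage sketch of the context manager under discussion (assumes a process group is already initialized, e.g. under torchrun; the wrapped model is illustrative):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(nn.Linear(8, 8).cuda())

# Inside the context, each rank sees the full, unsharded original parameters;
# on exit, parameters are re-sharded, and writes persist iff writeback=True.
with FSDP.summon_full_params(model, writeback=True, rank0_only=False):
    with torch.no_grad():
        for p in model.parameters():
            p.mul_(2.0)  # persists after exit because writeback=True
```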

Details

Existing Unit Tests

test_summon_full_param_writeback() with world_size=1
test_summon_full_param_writeback() with world_size=2

  • Tests that writeback=True persists writes and that writeback=False does not, when modifying a root FSDP instance's flat_param (modify_outer=True) or a non-root FSDP instance's flat_param (modify_outer=False); additionally parameterized over mixed_precision and use_orig_params (see the sketch below)
  • CPUOffload(offload_params=True) + world_size=1 is not tested because it is not supported.
  • The write inside summon_full_params() is on the flat_param itself, which is not the expected usage.
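
Roughly, the persistence check looks as follows (a condensed sketch, not the actual test code; the real test also parameterizes modify_outer, mixed_precision, and use_orig_params):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def check_writeback(fsdp_model: FSDP, writeback: bool) -> None:
    before = next(fsdp_model.parameters()).detach().clone()
    with FSDP.summon_full_params(fsdp_model, writeback=writeback):
        with torch.no_grad():
            # The existing test writes to the flat_param itself, which is not
            # the expected usage; typical usage writes to the original
            # (unflattened) parameters exposed inside the context.
            next(fsdp_model.parameters()).fill_(42.0)
    after = next(fsdp_model.parameters()).detach().clone()
    changed = not torch.equal(before, after)
    assert changed == writeback  # the write persists iff writeback=True
```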

test_summon_full_param_shard_value()

  • Tests that reconstructing the flat_param (by re-flattening and chunking the parameters) inside summon_full_params() yields the same values as the originally constructed flat_param when using a single FSDP instance (see the sketch below)
  • This test exercises the FSDP sharding algorithm rather than the specification of summon_full_params(). The only relevant behavior implicitly tested is that the model.parameters() order is preserved.
  • This test assumes the current FSDP sharding algorithm.
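
The reconstruction amounts to something like the following sketch, which hard-codes the current flatten-pad-chunk sharding algorithm (the helper name is illustrative):

```python
import torch
import torch.nn.functional as F

def rebuild_local_shard(params, rank: int, world_size: int) -> torch.Tensor:
    # Re-flatten parameters in model.parameters() order, pad so the length
    # divides evenly, then take this rank's chunk, mirroring how FSDP
    # currently constructs each rank's flat_param shard.
    flat = torch.cat([p.detach().reshape(-1) for p in params])
    pad = (world_size - flat.numel() % world_size) % world_size
    return F.pad(flat, (0, pad)).chunk(world_size)[rank]
```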

test_summon_full_param_recursive()

  • Tests that recurse=True recursively applies to all FSDP instances and that recurse=False does not
  • This test assumes the current FSDP sharding algorithm.

test_cannot_summon_full_params_from_forward()
test_cannot_summon_full_params_from_backward()

  • Tests that calling summon_full_params() from inside forward or backward raises an error (see the sketch below)
  • The previous error message leaked FlatParamHandle to the user; this PR provides a better one.
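
A minimal sketch of the invalid pattern these tests exercise (the module is illustrative; the exact error type and message come from this PR and are not reproduced here):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class CallsSummonInForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        self.fsdp_root = None  # set after FSDP wrapping, see below

    def forward(self, x):
        # Unsharding mid-forward is invalid and should raise a clear error.
        with FSDP.summon_full_params(self.fsdp_root):
            pass
        return self.lin(x)

inner = CallsSummonInForward()
model = FSDP(inner.cuda())
inner.fsdp_root = model
# model(torch.randn(2, 4, device="cuda"))  # expected to raise
```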

test_summon_full_params_respects_reshard_after_forward()

  • Tests that calling summon_full_params() after forward preserves whether the padded unsharded flat_param data is freed, matching the reshard_after_forward behavior
  • This test depends on FSDP internals (flat_param._full_param_padded.storage().size()); see the sketch below.
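
Concretely, the internals dependency looks like this (a sketch only; the attribute names are FSDP internals quoted from the existing test and are version-specific):

```python
# With use_orig_params=False, iterating parameters yields the FlatParameter.
flat_param = next(fsdp_model.parameters())
# The test checks whether the padded unsharded storage has been freed:
unsharded_is_freed = flat_param._full_param_padded.storage().size() == 0
```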

test_summon_single_param()

  • Tests that writing to padding with writeback=True does not persist those writes, using a singleton (1, 1) parameter that is flattened and padded to (2,) (see the sketch below)
  • This test name is misleading.
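
For example, with world_size=2 (a sketch; the wrapped module is illustrative):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# A single (1, 1) parameter flattens to 1 element and is padded to 2 so
# that each of the 2 ranks holds one element of the flat_param.
lin = nn.Linear(1, 1, bias=False)
model = FSDP(lin.cuda())
with FSDP.summon_full_params(model, writeback=True):
    with torch.no_grad():
        lin.weight.fill_(123.0)  # a write to the real element persists
# A write landing in the padding element (not backed by any original
# parameter) must not persist, which is what the test verifies.
```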

test_summon_full_params_equivalence()

  • Tests writeback, rank0_only, and offload_to_cpu with writeback=not rank0_only, using CPUOffload(offload_params=True) and including a torch.cuda._sleep(int(1e6)) after the write in summon_full_params()
  • The PR introducing this test said that the torch.cuda._sleep(int(1e6)) exercised the stream synchronization in summon_full_params(), namely that the current stream waits for the all-gather stream after all-gathering the parameters. I did not follow how that works conceptually, since the torch.cuda._sleep() call happens after both the all-gather and the write and runs in the default stream, which seems to be after the relevant ops. If we clarify this, I can re-incorporate it into the unit tests. Doing so is not a high priority, since summon_full_params() now unshards in the default stream and does not require stream synchronization.
  • This unit test overlaps with test_summon_full_param_writeback() and can be coalesced with it.

test_summon_from_non_fsdp()

  • Tests that calling summon_full_params() with default args on a non-FSDP root module exposes the original parameters correctly (see the sketch below)
  • This test actually covers much of the specification, since checking original-parameter equivalence includes checking shape, value, device, etc.
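
A sketch of the setup (module names are illustrative); summon_full_params() is a static method, so it also accepts a non-FSDP root and recurses into the FSDP submodules:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class NonFSDPRoot(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = FSDP(nn.Linear(4, 4).cuda())
        self.b = FSDP(nn.Linear(4, 4).cuda())

model = NonFSDPRoot()
with FSDP.summon_full_params(model):
    # Original parameters should be exposed with the correct
    # shape, value, and device.
    shapes = [p.shape for p in model.parameters()]
```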

test_reshard_outside_forward_backward_iteration()

  • Tests that calling summon_full_params() after forward preserves whether the padded unsharded flat_param data is freed (matching reshard_after_forward) and that calling it after backward finds the padded unsharded flat_param data freed; additionally parameterized over mixed_precision
  • This test strictly subsumes test_summon_full_params_respects_reshard_after_forward(), since it also includes the check after backward.

test_params_are_unflattenned()

  • Tests that original parameters are exposed with their unflattened shapes, factoring in rank0_only (e.g. nonzero ranks reshard early when rank0_only=True), and that with offload_to_cpu=True, the flat_params are moved back to GPU after exiting the context; additionally parameterized over mixed_precision

test_params_count_and_value()

  • Tests that all original parameters are exposed with the correct values, factoring in rank0_only (e.g. nonzero ranks do not expose the original parameters when rank0_only=True), and that with offload_to_cpu=True, the flat_params are moved back to GPU after exiting the context; additionally parameterized over mixed_precision

test_raises_rank0_with_writeback()

  • Tests that rank0_only=True + writeback=True raises an error (see the sketch below)
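
The check itself is small; a sketch follows (the exact exception type is an assumption):

```python
import pytest
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def check_rank0_only_with_writeback_raises(fsdp_model: FSDP) -> None:
    # This combination is unsupported, since nonzero ranks would have
    # nothing to write back.
    with pytest.raises(ValueError):
        with FSDP.summon_full_params(fsdp_model, rank0_only=True, writeback=True):
            pass
```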

test_named_parameters_buffers()

  • Tests that named_parameters() and named_buffers() return clean names (without FSDP prefixes) inside summon_full_params(); see the sketch below
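
A sketch of the name check (quoting the wrapper-prefix string literally here is an assumption):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

with FSDP.summon_full_params(fsdp_model):
    for name, _ in fsdp_model.named_parameters():
        # Names should match the unwrapped module, e.g. "lin.weight",
        # with no FSDP wrapper segments such as "_fsdp_wrapped_module".
        assert "_fsdp_wrapped_module" not in name
```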

test_with_grads_core()

  • Tests with_grads=True by comparing gradients against DDP (see the sketch below)
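
A sketch of the DDP comparison (assumes fsdp_model and ddp_model wrap identical modules and have run the same forward and backward):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# With with_grads=True, each original parameter's .grad is also
# unsharded inside the context and should match DDP's gradients.
with FSDP.summon_full_params(fsdp_model, with_grads=True):
    for fsdp_p, ddp_p in zip(fsdp_model.parameters(), ddp_model.parameters()):
        torch.testing.assert_close(fsdp_p.grad, ddp_p.grad)
```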

test_with_grads_none_grads()

  • Tests with_grads=True when ranks' FlatParameters have None gradient

New Unit Tests

test_unshard_params_writeback_no_shard() (with world_size=1)
test_unshard_params_writeback() (with world_size=2)

  • Tests the writeback argument (using the default value for all others)

test_unshard_params_param_data_no_shard() (with world_size=1)
test_unshard_params_param_data() (with world_size=2)

  • Tests that parameters are exposed correctly for recurse=True across all other argument configs, using a non-FSDP root module

test_unshard_singleton_param_writeback()

  • Tests writeback=True for a singleton parameter, which includes testing that writing to padding does not persist

test_unshard_params_respects_reshard()

  • Tests that unsharding parameters respects the expected reshard behavior between forward and backward as well as after backward

test_unshard_params_recurse()

  • Tests the recurse argument (using default for all others)

test_offload_to_cpu_no_shard_raises()

  • Tests that offload_to_cpu=True with NO_SHARD raises an error

Summary of Unit Test Changes

  • test_summon_full_param_writeback -> test_unshard_params_writeback()
  • test_summon_full_params_equivalence(), test_params_are_unflattenned(), test_params_count_and_value() -> test_unshard_params_param_data()
  • test_summon_full_params_respects_reshard_after_forward(), test_reshard_outside_forward_backward_iteration() -> test_unshard_params_respects_reshard()
  • test_summon_full_param_recursive() -> test_unshard_params_recurse()
  • test_named_parameters_and_buffers() unchanged
  • test_with_grads_core() unchanged
  • test_with_grads_none_grads() unchanged
  • test_cannot_summon_full_params_from_forward(), test_cannot_summon_full_params_from_backward() -> test_unshard_params_from_forward_raises(), test_unshard_params_from_backward_raises()
  • test_raises_rank0_with_writeback() -> test_rank0_only_with_writeback_raises()
  • test_offload_to_cpu_no_shard_raises() new
  • test_summon_full_param_shard_value() removed

@pytorch-bot bot commented Jan 17, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92298

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 33c32ce:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Jan 17, 2023
awgu pushed a commit that referenced this pull request Jan 17, 2023
ghstack-source-id: bb33665
Pull Request resolved: #92298
@awgu awgu added the topic: not user facing topic category label Jan 17, 2023
@awgu awgu marked this pull request as ready for review January 17, 2023 17:04
awgu pushed a commit that referenced this pull request Jan 17, 2023
ghstack-source-id: 17cc6b3
Pull Request resolved: #92298
awgu pushed a commit that referenced this pull request Jan 19, 2023
ghstack-source-id: dcee8e8
Pull Request resolved: #92298
@rohan-varma (Contributor) left a comment

Thanks for enhancing the testing!

awgu pushed a commit that referenced this pull request Feb 2, 2023
ghstack-source-id: d28beaa
Pull Request resolved: #92298
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 2, 2023
@awgu (Collaborator, Author) commented Feb 2, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

ragulpr added a commit to ragulpr/pytorch that referenced this pull request Feb 2, 2023
…n-dev-setup

* origin: (898 commits)
  Move dynamo.optimizations.distributed to backends (pytorch#93408)
  Remove cuda 11.6 from nightly (pytorch#93979)
  Refactor dynamo register_backend/BACKENDS (pytorch#93389)
  Remove cuda 11.6 from CI replace with 11.7 (pytorch#93406)
  [Dynamo] Rename `GuardBuilder.guarded_code` -> `check_fn_manager` (pytorch#93934)
  Revert "Remove CUDA 11.6 from nightly builds (pytorch#93404)"
  Revert "[inductor] fix crash issue when input is a view tensor (pytorch#90150)"
  Basic Validation for FSDP `state_dict` transformations of modules with persistent buffers (pytorch#93396)
  Merge Inductor perf smoke test with other inductor CI tests (pytorch#93395)
  [inductor] Don't import torchvision (pytorch#93027)
  [FSDP][3/N] Refactor `summon_full_params` unit tests (pytorch#92298)
  [FSDP][2/N] `_summon_full_params` -> `_unshard_params` (pytorch#92297)
  Remove CUDA 11.6 from nightly builds (pytorch#93404)
  Mark buffers that reuse other buffers (pytorch#93329)
  Refactor to allow reuse of SchedulerNode.allocate (pytorch#93328)
  retire sparse_mask_helper (pytorch#91714)
  update fbgemm third party (pytorch#93907)
  [inductor] fix crash issue when input is a view tensor (pytorch#90150)
  [Inductor] add config for weight prepacking (pytorch#93811)
  Check for none for NNModuleVariable.__module__ (pytorch#93326)
  ...
@facebook-github-bot facebook-github-bot deleted the gh/awgu/303/head branch June 8, 2023 15:34