
Reinvoke sharding tests and refactor Clock constructor for sharded grids#5389

Merged
glwagner merged 58 commits into main from glw/sharding-fixes-1
Mar 23, 2026

Conversation

@glwagner
Member

No description provided.

glwagner and others added 10 commits March 11, 2026 20:32
- Update CI sharding job to Julia 1.11.9, add XLA_FLAGS for 4 virtual devices
- Add reconstruct_global_field, synchronize_communication!, set_to_field! for sharded fields
- Add complete_communication_and_compute_buffer! and interior_tendency_kernel_parameters
  no-ops for ShardedDistributed to disable pipelining
- Rename DistributedTripolarGridOfSomeKind to MPITripolarGridOfSomeKind to distinguish
  MPI-based distributed from Reactant sharded distributed
- Make synchronize_communication! conditional on AsynchronousDistributed
- Create run_sharding_tests.jl as a centralized test runner for all partition configs
- Rewrite test_sharded_lat_lon.jl and test_sharded_tripolar.jl to use run_sharding_tests.jl
- Use tolerance-based comparisons (sqrt(eps)) instead of strict equality
- Uncomment sharding tests in runtests.jl
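The tolerance-based comparison can be sketched as follows (a minimal illustration assuming Float64 fields; the helper name is hypothetical, not the PR's actual test code):

```julia
# Hypothetical helper: compare two arrays up to sqrt(eps) tolerance.
# sqrt(eps(Float64)) ≈ 1.5e-8 tolerates reordered floating-point
# operations across shards while still catching real divergence.
function fields_match(a::AbstractArray, b::AbstractArray)
    tol = sqrt(eps(Float64))
    return all(isapprox.(a, b; atol=tol, rtol=tol))
end
```

A test would then use `@test fields_match(reconstructed, reference)` in place of a strict `reconstructed == reference`.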

Co-Authored-By: Claude Opus 4.6 <[email protected]>
…nings

Remove --project flag from mpiexec subprocess commands to match existing
MPI test patterns. The subprocess inherits JULIA_LOAD_PATH from Pkg.test()
which already includes MPI. Add include guard to distributed_tests_utils.jl
to prevent method redefinition warnings when included from multiple test files.
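The include guard can be sketched like this (a hypothetical pattern, not necessarily the exact code added to distributed_tests_utils.jl):

```julia
# Guard against re-evaluating method definitions when this file is
# include()d from several test files, which would otherwise print
# method redefinition warnings.
if !@isdefined(DISTRIBUTED_TESTS_UTILS_LOADED)
    const DISTRIBUTED_TESTS_UTILS_LOADED = true

    # ... shared test utilities are defined here ...
end
```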

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Move Reactant-specific run_distributed_simulation (with @compile) into
run_sharding_tests.jl. The shared distributed_tests_utils.jl is also used
by MPI tests which don't have Reactant available.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Define ReactantTestGrid type alias and dispatch run_distributed_simulation
on it, rather than overriding the generic method.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
@navidcy added the extensions 🧬 and reactant ∇ all day I dream about MLIR labels on Mar 12, 2026
@glwagner changed the title from "Refactor Clock function for sharding support" to "Reinvoke sharding tests and add Clock constructor for sharded grids" on Mar 17, 2026
@glwagner changed the title from "Reinvoke sharding tests and add Clock constructor for sharded grids" to "Reinvoke sharding tests and refactor Clock constructor for sharded grids" on Mar 17, 2026
glwagner and others added 4 commits March 17, 2026 20:39
…@test

Remove ShardedDistributedField set_to_function! override that launched a KA
kernel with closures, causing InvalidIRError. Falls through to ReactantField
method which does CPU fallback instead.

Also remove stale @test isnothing(flat_grid.z) — grid.z is now a
StaticVerticalDiscretization even for Flat topologies.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
- Introduce PeriodicFillHalo struct that captures Val{N}, Val{H} as type
  parameters, avoiding grid access inside periodic halo kernels. This
  fixes Reactant MLIR compilation failures from runtime grid field access.

- Rename fill_halo_kernel! → fill_halo_kernel (non-mutating builder)

- Split periodic_size_and_offset into periodic_size and periodic_offset
  with improved comments explaining windowed field offset correction

- Remove stale _fill_north_halo! overrides from test/distributed_tests_utils.jl

- Update OceananigansEnzymeExt for the renamed functions

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Add explicit Tuple{Any, Any} and Tuple{Any} signatures to resolve
method ambiguity with the generic fill_halo_event! methods.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
@glwagner
Member Author

glwagner commented Mar 21, 2026

@simone-silvestri this cleans up halo filling a bit --- there was some code in there that was badly golfed

# Use optimization_barrier to avoid aliasing last_Δt and last_stage_Δt,
# which causes XLA buffer donation errors in loops.
# (Δt + zero(Δt) gets folded by XLA, so we need a barrier.)
(last_stage_Δt,) = Reactant.Ops.optimization_barrier(Δt)
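For context, the barrier pattern quoted above can be sketched in isolation (illustrative only; assumes Reactant is loaded and Δt is a traced number, with made-up surrounding names):

```julia
# Without the barrier, assigning the same traced value to two clock
# fields lets MLIR collapse them into one returned value, which later
# trips XLA buffer donation inside @trace loops.
function split_timesteps(Δt)
    last_Δt = Δt
    # Δt + zero(Δt) is constant-folded away by XLA, so a barrier is
    # needed to force a genuinely distinct value:
    (last_stage_Δt,) = Reactant.Ops.optimization_barrier(Δt)
    return last_Δt, last_stage_Δt
end
```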
Collaborator

this is a bit weird, is there a reason last_Δt and last_stage_Δt shouldn't alias at the start [for ease]

Collaborator

alternatively, we should do something less strong than an optimization barrier [and add a new op]. what happens is that in mlir if we see multiple returns of the same value, we just return one value for both [as they're the same]. Either we shouldn't do that, or alternatively we can add a new op that means "force_buffer_return" or something. The reason being that we actually do still want this to get optimized inside of loops, etc.

cc @Pangoraw

Member Author

they don't alias at the start but I think the issue is something like

function tick!(clock, dt)
    last_dt.mlir_data = dt.mlir_data
    last_stage_dt.mlir_data = dt.mlir_data
    return nothing
end

so they alias after this update in tick!, we need to copy the mlir_data?

Collaborator

Can we make them alias at the start?

Member Author

I think the issue here is that whether or not we actually need last_stage_dt depends on the time-stepping method. For RK3 these are different but for an earlier time-stepping method (which we are using for ocean in these tests and on gb25) they are the same

Member Author

asking claude to figure it out

Member Author

I have to run but Claude is working on this (maybe will get it wrong), feel free to commit here in meantime

Collaborator

can this part be factored out in a follow up and the rest merged [since I know blocking some other tests]

Collaborator

in particular I think you'll want to revert ext/OceananigansReactantExt/OceananigansReactantExt.jl back to the way it was before [aka]

function Oceananigans.TimeSteppers.tick!(clock::Oceananigans.TimeSteppers.Clock{<:Any, <:Any, <:Reactant.TracedRNumber}, Δt; stage=false)
    Oceananigans.TimeSteppers.tick_time!(clock, Δt)

    if stage # tick a stage update
        clock.stage += 1
        clock.last_stage_Δt = Δt
    else # tick an iteration and reset stage
        clock.iteration.mlir_data = (clock.iteration + 1).mlir_data
        clock.stage = 1
    end

    return nothing
end

and we can fix the time stepper generality in a follow up [as this PR otherwise contains many necessary unrelated fixes]

Member Author

I'll see if this gets tests to pass on gb25. But I think the reason we changed it is because the gb25 tests were not passing? One issue is that we have changed this function so there is no longer a stage kwarg. I can figure out how to rewrite it though.

glwagner added a commit that referenced this pull request Mar 21, 2026
materialize_clock! in the constructor doesn't prevent MLIR-level aliasing
that occurs at runtime inside tick! when both last_Δt and last_stage_Δt
are assigned the same Δt value. The optimization_barrier is still needed
until Reactant provides a proper buffer copy op.

See #5389 (comment)

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
@codecov

codecov bot commented Mar 21, 2026

Codecov Report

❌ Patch coverage is 45.64103% with 106 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.91%. Comparing base (04fd2ed) to head (e1dbdb6).
⚠️ Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
src/Operators/spacings_and_areas_and_volumes.jl 1.78% 55 Missing ⚠️
ext/OceananigansReactantExt/TimeSteppers.jl 40.42% 28 Missing ⚠️
src/BoundaryConditions/fill_halo_kernels.jl 82.92% 7 Missing ⚠️
ext/OceananigansReactantExt/Fields.jl 0.00% 3 Missing ⚠️
...c/BoundaryConditions/fill_halo_regions_periodic.jl 0.00% 3 Missing ⚠️
ext/OceananigansEnzymeExt.jl 0.00% 2 Missing ⚠️
ext/OceananigansReactantExt/Models.jl 0.00% 2 Missing ⚠️
...OceananigansReactantExt/OceananigansReactantExt.jl 83.33% 2 Missing ⚠️
ext/OceananigansReactantExt/OutputReaders.jl 0.00% 2 Missing ⚠️
...rthogonalSphericalShellGrids/distributed_zipper.jl 90.90% 1 Missing ⚠️
... and 1 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #5389      +/-   ##
==========================================
+ Coverage   68.84%   73.91%   +5.07%     
==========================================
  Files         398      398              
  Lines       21963    22732     +769     
==========================================
+ Hits        15121    16803    +1682     
+ Misses       6842     5929     -913     
Flag Coverage Δ
buildkite 68.76% <27.60%> (-0.09%) ⬇️
julia 68.76% <27.60%> (-0.09%) ⬇️
reactant_unit 3.79% <15.46%> (?)

Flags with carried forward coverage won't be shown.


clock.stage = 1
clock.last_Δt.mlir_data = Δt.mlir_data
clock.last_stage_Δt.mlir_data = Δt.mlir_data
clock.last_stage_Δt.mlir_data = promote_to_traced(Δt).mlir_data
Member Author

@wsmoses how about this (create a copy of dt and use that)

Collaborator

no this will not work, it'll still be seen as the same value at the end

Member Author

Collaborator

Yeah I would expect it to fail

Collaborator

Can we restore this to the old functionality, merge all the sharding stuff, then come back to this when @Pangoraw adds the attr stuff?

@wsmoses
Collaborator

wsmoses commented Mar 23, 2026

Avoiding the mlir_data assignment won't work either, and an upstream fix is necessary; I think you just need to revert, merge this, and wait for the upstream fix

@glwagner
Member Author

we cleared the reactant tests, and also the issue on gb25 is resolved so I am crossing my fingers this will get merged soon

@glwagner
Member Author

note on the above discussion: the solution was to add a materialize_clock!(clock, timestepper) switch --- for AB2, which has no stages, we force last_dt and last_stage_dt to alias initially and this solves the problem. The fix is in src not the extension, since the fix is innocuous for non-reactant situations. materialize_clock! is called in the constructor for HydrostaticFreeSurfaceModel.
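Pieced together from the commit messages, the materialize_clock! switch can be sketched roughly as follows (a paraphrase, not the exact source; assumes Oceananigans' QuasiAdamsBashforth2TimeStepper type is in scope):

```julia
# Generic fallback: most time steppers need no special clock handling.
materialize_clock!(clock, timestepper) = nothing

# AB2 has no stages, so last_Δt and last_stage_Δt always agree; aliasing
# them up front lets Reactant's tracer see a single buffer and avoids
# XLA buffer donation errors in compiled loops. setfield! bypasses any
# setproperty! override defined on the clock type.
function materialize_clock!(clock, ::QuasiAdamsBashforth2TimeStepper)
    setfield!(clock, :last_stage_Δt, getfield(clock, :last_Δt))
    return nothing
end
```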

@glwagner glwagner merged commit 4dd6137 into main Mar 23, 2026
67 of 68 checks passed
@glwagner glwagner deleted the glw/sharding-fixes-1 branch March 23, 2026 19:20
end

@kernel function _fill_periodic_bottom_and_top_halo!(c, bottom_bc, top_bc, loc, grid, args)
@kernel function _fill_periodic_bottom_and_top_halo!(c, ::Val{N}, ::Val{H}) where {N, H}
Collaborator

why passing Vals here?

Member Author

It seems that knowing the loop size at compile time leads to successful compilation, whereas not knowing it does not. This is actually needed for Julia 1.12 with Reactant.
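The idea is that with ::Val{N}, ::Val{H} arguments the loop bounds become type parameters, so the kernel body needs no runtime grid access. A simplified, non-kernel 1D sketch with hypothetical names:

```julia
# N interior points with H halo points on each side (length N + 2H).
# Because N and H are type parameters, the loop bounds are known at
# compile time and the function specializes fully per (N, H) pair.
function fill_periodic_halo!(c, ::Val{N}, ::Val{H}) where {N, H}
    for j in 1:H
        c[j]         = c[N + j]  # west halo <- east end of interior
        c[N + H + j] = c[H + j]  # east halo <- west end of interior
    end
    return c
end

c = zeros(8); c[3:6] .= 1:4          # N = 4, H = 2
fill_periodic_halo!(c, Val(4), Val(2))
```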

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok, cool

glwagner added a commit that referenced this pull request Mar 26, 2026
* Fix XLA buffer donation error from aliased clock fields in tick!

In tick!, both clock.last_Δt and clock.last_stage_Δt were set to the
same Δt.mlir_data, causing them to alias. Inside @trace for loops,
this triggers "Attempt to donate a buffer which is also used by the
same call to Execute()" because XLA tries to donate one while reading
the other in the while loop carry. Break the alias by creating a copy
via Δt + zero(Δt) for last_stage_Δt.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Remove initialization_update_state! and clean up initialization

Replace the confusing `initialization_update_state!` with a clear separation
of concerns between `reconcile_state!`, `initialize!`, and `update_state!`:

- `reconcile_state!`: ensures auxiliary state (barotropic velocities, vertical
  coordinate scaling) is consistent with prognostic fields. Called in `set!`
  and `initialize!`. Idempotent.

- `initialize!`: one-time setup before first time step. Calls `reconcile_state!`,
  synchronous halo fills, and `initialize_closure_fields!`. Allowed to be
  non-idempotent.

- `update_state!`: idempotent recomputation of diagnostics/derived quantities.
  Called every time step and after `set!`.

Renames:
- `initialize_free_surface!` → `reconcile_free_surface!`
- `initialize_vertical_coordinate!` → `reconcile_vertical_coordinate!`
- `maybe_initialize_state!` → `maybe_initialize!` (now also calls `initialize!`)

The `set!` function gains a `reconcile_state` kwarg (default `true`) for
advanced users who set barotropic velocities independently.

Bump version to 0.106.1.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Remove initialize! call from maybe_initialize!

maybe_initialize! should only call update_state!, not initialize!.
Calling initialize! inside time_step! would cause double-initialization
when used with Simulation (which already calls initialize! in
Simulation.initialize!).

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Rename maybe_initialize! to maybe_prepare_first_time_step!

Also move reconcile_state! definition to TimeSteppers (where update_state!
lives) and call it from maybe_prepare_first_time_step! so that bare
time_step!(model, Δt) reconciles state on the first step.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Use optimization_barrier instead of +zero to break clock field alias

XLA folds Δt + zero(Δt) → Δt, so the aliasing between last_Δt and
last_stage_Δt persists. Use Reactant.Ops.optimization_barrier which
XLA cannot fold, ensuring distinct buffers for the while loop carry.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Add materialize_clock! and remove optimization_barrier from tick!

Add `materialize_clock!(clock, timestepper)` called in HFSM and NH model
constructors. For QuasiAdamsBashforth2TimeStepper, this sets
`clock.last_Δt = clock.last_stage_Δt` to ensure they are distinct objects.
This is needed for Reactant, where aliased ConcreteRNumber fields cause
XLA buffer donation errors in compiled loops.

With materialize_clock! breaking the alias at construction time, the
optimization_barrier workaround in the Reactant tick! override is no
longer needed and is removed.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Restore optimization_barrier in Reactant tick!

materialize_clock! in the constructor doesn't prevent MLIR-level aliasing
that occurs at runtime inside tick! when both last_Δt and last_stage_Δt
are assigned the same Δt value. The optimization_barrier is still needed
until Reactant provides a proper buffer copy op.

See #5389 (comment)

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Remove optimization_barrier from Reactant tick!

materialize_clock! in the model constructor ensures last_Δt and
last_stage_Δt are distinct ConcreteRNumber objects. Assigning the same
value to both in tick! via .mlir_data does not re-alias the objects.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Move materialize_clock! QAB2 method to Reactant extension

For QAB2, last_Δt and last_stage_Δt always hold the same value after
tick!. In the Reactant extension, materialize_clock! aliases them via
setfield! (bypassing the ReactantClock setproperty! override) so that
Reactant's tracer sees one buffer, avoiding XLA buffer donation errors.

The src/ QAB2 method is removed since aliasing is only needed for
Reactant's ConcreteRNumber fields (normal Float64 fields are immutable
values with no aliasing concern).

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* rm stale import

* fill free surface halos before timestepping

* remove stale imports

* add fill halos in reconcile_state!

* now just need to fix multi-region

* cubed sphere implementation

* fix the dispatch

---------

Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>
Co-authored-by: Simone Silvestri <[email protected]>

Labels

extensions 🧬 reactant ∇ all day I dream about MLIR

6 participants