
Conversation

@yuslepukhin
Member

@yuslepukhin yuslepukhin commented Aug 22, 2025

Description

While memory profiling some models I noticed multiple file mapping failures.
`WindowsEnv::MapFileIntoMemory()` properly checks that the mapping offset is aligned to the allocation granularity, but it actually computes the offset as page aligned.
Also, while saving external tensors we do not need to align big tensors to the Windows allocation granularity or to anything else that is platform dependent. Set the on-disk alignment to 4096 for all platforms.
Granularity matters only for calculating the mapping offset.
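A minimal sketch of the on-disk layout this implies (illustrative only; `kOnDiskAlignment`, `kAlignThreshold`, and `WriteTensor` are hypothetical names, not the actual serialization code):

```cpp
#include <cstdint>
#include <fstream>
#include <vector>

// Illustrative only: pad the file offset of "big" tensors to a fixed 4096-byte
// boundary, independent of the platform's allocation granularity.
constexpr std::streamoff kOnDiskAlignment = 4096;  // assumption: same value on all platforms
constexpr size_t kAlignThreshold = 1024 * 1024;    // assumption: only tensors above ~1MB are aligned

std::streamoff WriteTensor(std::ofstream& out, const std::vector<uint8_t>& bytes) {
  std::streamoff offset = out.tellp();
  if (bytes.size() > kAlignThreshold && offset % kOnDiskAlignment != 0) {
    // Pad with zeros up to the next 4096-byte boundary; small tensors are written as-is.
    std::streamoff padded = (offset / kOnDiskAlignment + 1) * kOnDiskAlignment;
    std::vector<char> pad(static_cast<size_t>(padded - offset), 0);
    out.write(pad.data(), static_cast<std::streamsize>(pad.size()));
    offset = padded;
  }
  out.write(reinterpret_cast<const char*>(bytes.data()), static_cast<std::streamsize>(bytes.size()));
  return offset;  // recorded in the model as the tensor's external-data offset
}
```

Only the recorded file offsets matter here; any granularity-related alignment is handled later, when the region is actually mapped.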

Motivation and Context

Multiple failures for file mapping for certain models.
This saves some hundreds of Mbs for some models.

@justinchuby
Contributor

justinchuby commented Aug 23, 2025

Could you say more on the alignment granularity / why we don't need to align to 64KB on Windows? I suppose https://github.com/onnx/ir-py/blob/ce1d0e63c41104b584271090ecd1b4a8f8bec52f/src/onnx_ir/external_data.py#L33 can be changed the same way too if that's the case? Thanks

@yuslepukhin
Member Author

> Could you say more on the alignment granularity / why we don't need to align to 64KB on Windows? I suppose https://github.com/onnx/ir-py/blob/ce1d0e63c41104b584271090ecd1b4a8f8bec52f/src/onnx_ir/external_data.py#L33 can be changed the same way too if that's the case? Thanks

The allocation granularity requirement applies to the mapping offset into the file, not to the way data is written on disk.
Even today, we only align and pad data above a certain size threshold (>1MB) on disk; everything else is just written sequentially. There is no need to align data on disk to the allocation granularity, which would waste a lot of space on Windows.
What matters is that the mapping offset is aligned, and that was not done.
For that we shift the offset back when mapping, but we return a pointer to the beginning of the data.
I tested it and the mapping now always succeeds (before the fix, it failed most of the time).
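A minimal sketch of that offset adjustment with the Win32 APIs (an illustration of the technique, not the exact `WindowsEnv::MapFileIntoMemory()` code; the caller is assumed to own the file-mapping handle and to unmap using the granularity-aligned base):

```cpp
#include <windows.h>
#include <cstdint>

// Map [offset, offset + length) of a file for reading. The offset passed to
// MapViewOfFile must be a multiple of the allocation granularity (typically 64KB),
// so round it down and return a pointer advanced by the difference.
void* MapRegionForRead(HANDLE file_mapping, uint64_t offset, size_t length) {
  SYSTEM_INFO sysinfo{};
  GetSystemInfo(&sysinfo);
  const uint64_t granularity = sysinfo.dwAllocationGranularity;

  const uint64_t aligned_offset = (offset / granularity) * granularity;  // round down
  const size_t delta = static_cast<size_t>(offset - aligned_offset);

  void* view = MapViewOfFile(file_mapping, FILE_MAP_READ,
                             static_cast<DWORD>(aligned_offset >> 32),
                             static_cast<DWORD>(aligned_offset & 0xFFFFFFFF),
                             length + delta);  // map extra bytes to cover the shift
  if (view == nullptr) {
    return nullptr;  // caller inspects GetLastError()
  }
  // Return the address of the requested data, not the aligned view start.
  // UnmapViewOfFile must later be called with (returned pointer - delta).
  return static_cast<uint8_t*>(view) + delta;
}
```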

@yuslepukhin
Member Author

There are more ways to improve mappings today. Because of the allocation granularity and the default externalizing threshold of 1K, many weights fall into the same mapping region, resulting in the same disk area being mapped multiple times, up to 64 times on Windows. We can do better here and map such regions only once, OR we can decide to map the entire file and ref count it.
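As a sketch of the ref-counting idea (hypothetical types only, not a concrete proposal for the API): a single whole-file view could be shared through aliasing `shared_ptr`s, so it is unmapped only when the last tensor referencing it goes away.

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>

// Hypothetical: represents one mapped view of the whole weights file.
struct MappedFile {
  uint8_t* base = nullptr;
  size_t size = 0;
  ~MappedFile() { /* UnmapViewOfFile(base) / munmap(base, size) would go here */ }
};

// Hand out pointers into the single shared mapping. The aliasing shared_ptr
// keeps the MappedFile alive until the last tensor referencing it is released.
std::shared_ptr<const uint8_t> GetTensorData(const std::shared_ptr<MappedFile>& file,
                                             uint64_t offset) {
  return std::shared_ptr<const uint8_t>(file, file->base + offset);
}
```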

@yuslepukhin yuslepukhin marked this pull request as ready for review August 25, 2025 17:49
jywu-msft previously approved these changes Aug 26, 2025
@yuslepukhin yuslepukhin merged commit 568ad20 into main Aug 27, 2025
99 of 103 checks passed
@yuslepukhin yuslepukhin deleted the yuslepukhin/fix_memapping_windows branch August 27, 2025 17:26
snnn pushed a commit that referenced this pull request Aug 28, 2025
snnn pushed a commit that referenced this pull request Aug 29, 2025
- **Relax WeightBiasQuantization constraint for larger QDQ node group
(#25673)**
- **Add cuda graph implementation for NV TRT RTX EP (#25787)**
- **python GPU IO Bindings for NVIDIA  (#25776)**
- **Fixes for DynamicQuantizeMatMul and Attention3D tests (#25814)**
- **Fix a long standing bug on file memory mapping on windows.
(#25833)**
- **Add API for precompiled model compatibility check using just the
compat info (#25841)**
- **Enable ABSL_FLAGS flag registration for onnxruntime_perf_test for
mobile build (#25849)**
- **Add default constructor to Ort::Status. (#25860)**
- #25871
- #25878
- #25884
- #25886
- #25866
@snnn
Contributor

snnn commented Aug 30, 2025

The change is added to the release branch

gedoensmax pushed a commit to gedoensmax/onnxruntime that referenced this pull request Sep 2, 2025
…#25833)

yuslepukhin added a commit that referenced this pull request Oct 30, 2025
…lues early (#26345)

### Description
Converts weights early and reverts "Properly remove in-memory references (#25652)".
This reverts commit 3ca49d8 and makes
appropriate adjustments for the current state of the code.

This PR is made possible and on the heels of:
#26263
#25833.

Previous history:
#23979
#25320
#25626
#25652

The first change (#26263)
allows us to convert initializers to OrtValues early and save lots of
memory at model loading time.

Specifically, for the Phi-4-mini-instruct-INT4 model the before and after look like this:

**Before**
<img width="1204" height="124" alt="Before change DEBUG 2025-10-16
144819"
src="https://github.com/user-attachments/assets/674ff75b-057f-498a-a906-0140d59d46e6"
/>

**After**

<img width="997" height="114" alt="After change DEBUG 2025-10-16 144819"
src="https://github.com/user-attachments/assets/df1783af-7f50-4cd2-b3ad-6868f23be53f"
/>

The two peaks represent memory usage at optimization time (8.1Gb before) and after weights memory mapping (6.5Gb). After this change the corresponding numbers are 3.5Gb and 4.7Gb respectively.
Most of the savings during the optimization phase come from `ConstantFolding`, where we are able to reuse the resulting OrtValues directly for the new initializers.
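As a toy illustration of the reuse principle (not onnxruntime code; `Tensor` and `FoldAdd` are made-up names): the buffer produced by folding is shared with the new initializer instead of being copied into a separate store.

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Toy example: shared ownership of the folded buffer, similar in spirit to
// handing the folded OrtValue directly to the graph as the new initializer.
struct Tensor {
  std::shared_ptr<std::vector<float>> data;  // shared buffer, no copy on reuse
};

Tensor FoldAdd(const Tensor& a, const Tensor& b) {
  auto out = std::make_shared<std::vector<float>>(a.data->size());
  for (size_t i = 0; i < out->size(); ++i) (*out)[i] = (*a.data)[i] + (*b.data)[i];
  return Tensor{out};
}

int main() {
  Tensor a{std::make_shared<std::vector<float>>(std::vector<float>{1, 2, 3})};
  Tensor b{std::make_shared<std::vector<float>>(std::vector<float>{4, 5, 6})};

  Tensor folded = FoldAdd(a, b);

  // Reuse: the "graph" keeps the same buffer; no second allocation or memcpy.
  std::vector<Tensor> initializers;
  initializers.push_back(folded);

  std::cout << "shared uses of folded buffer: " << folded.data.use_count() << "\n";  // prints 2
}
```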

This PR concludes a series of PRs converting initializers to OrtValues.

Memory consumption before the conversion began was 9.3Gb and 6.7Gb
respectively. We are saving almost 6Gb during optimization and 2Gb for
the steady state.
 
 
<img width="1175" height="139" alt="image"
src="https://github.com/user-attachments/assets/80e7d228-8a8e-4316-8e04-b02c2be30f04"
/>

The model also loads about 12 seconds faster.

Example of `ConstantFolding` being one of the top contributors, where we duplicate memory for a higher peak before `Resolve` takes care of no-longer-used initializers.
<img width="1100" height="558" alt="Sanpshot 3 Peak on ConstantFolding
Transpose Optimizer"
src="https://github.com/user-attachments/assets/95545abd-3f99-46d9-862e-bbf27cbb5b40"
/>

<img width="1060" height="600" alt="Snapshot 4 Peak AddInitializer from
ConstantFolding"
src="https://github.com/user-attachments/assets/dd457ec6-23ee-4efd-8c60-625d5faad61e"
/>

<img width="325" height="160" alt="image"
src="https://github.com/user-attachments/assets/37c1194d-f683-49a7-afb1-073dfbb9bbfc"
/>


### Motivation and Context
Reduce memory usage.