Prefer DirectML for Windows ONNX transcription models #985
ferologics wants to merge 2 commits into cjpais:main
Conversation
Force-pushed from 3b27bd4 to 32b0150
🧪 Test Build Ready: build artifacts for PR #985 are available for testing. Download artifacts from the workflow run. Artifacts expire after 30 days.
@ferologics can you see if this still helps the inference speed for you? I'd be quite curious whether it just falls back to CPU or works out of the box. Since it's DirectML I kind of think it might just work on Win 11; would be curious about Win 10 too.
Tested the CI-built Windows artifact locally and it looks good. What I checked:
Result:
So on my Windows 11 machine the CI-built artifact is still taking the intended GPU path, not silently dropping back to CPU.
Solid. This is amazing news. I will test on my Windows machine when I can and see how it goes as well. I'm curious how this will play with integrated GPUs; I'm slightly wondering if we will need to provide an option to disable this, just in case CPU is faster for someone. I know another PR in transcribe-rs had something like this. Might be worth considering.
Good callout. I agree an opt-out could be useful just in case CPU ends up better for some setups. I'm not sure it needs to be tackled in this PR unless you think it's important for landing it; happy to add it if you feel it's essential, otherwise we can keep this one focused and follow up separately.
Let me think about it. I want to give this a test myself on my machine and go from there. I will probably be able to test tomorrow.
Okay, I gave this a quick run. We definitely need a toggle before shipping this, and the default should be off. It probably belongs in experimental settings. Possibly it should be a dropdown, since we may add CUDA etc. to it in the future. Not sure exactly how we are going to handle this generically, but we will cross that bridge when we get there.

The reason: DirectML is 4x slower than CPU on my test machine with an integrated GPU (testing with Parakeet v3). I suspect a lot of users have integrated GPUs, and we cannot impact their performance.

Maybe #958 is relevant here and worth combining efforts. Pinging @andrewleech for thoughts and opinions. I know there is also #1023, which we need to do. Also cc @intech: I am not ready to move to 0.3.0 of transcribe-rs quite yet, mostly because we need a solid design for supporting acceleration in the app. Basically all these PRs are interrelated, so I would love any help thinking about this. Opinions and thoughts welcome. I will likely be making some changes to
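The toggle/dropdown idea could be sketched as a plain settings enum that defaults to CPU (i.e. acceleration off), parsed from whatever string the dropdown stores. `AccelBackend` and the variant names below are hypothetical, not anything that exists in Handy or transcribe-rs today:

```rust
use std::str::FromStr;

/// Hypothetical acceleration-backend setting for the proposed experimental
/// dropdown. The names are illustrative assumptions, not the app's real API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AccelBackend {
    /// Default off, per the review discussion: CPU unless the user opts in.
    #[default]
    Cpu,
    DirectMl,
    Cuda,
}

impl FromStr for AccelBackend {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Case-insensitive match against the persisted settings value.
        match s.to_ascii_lowercase().as_str() {
            "cpu" => Ok(Self::Cpu),
            "directml" => Ok(Self::DirectMl),
            "cuda" => Ok(Self::Cuda),
            other => Err(format!("unknown acceleration backend: {other}")),
        }
    }
}
```

A dropdown backed by an enum like this stays extensible if WebGPU or other providers get added later, and unknown values from an old config fail loudly instead of silently picking a backend.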
In #958 I was testing DirectML as well as WebGPU, and generally found that on my iGPU they gave worse performance on most models, depending on model architecture and format/quantization. I didn't keep DirectML as the default because it's in maintenance mode / effectively deprecated; as far as I could tell, WebGPU is the best-supported framework that covers both Nvidia and AMD. However, WebGPU was slightly slower for me than DirectML in my tests, iirc, though there are some WebGPU settings needed to ensure it's not using the "default browser settings" restrictions. CUDA would likely give better performance on compatible hardware, but I think ORT then bundles the ~100M binaries, so you probably don't want it included/enabled by default. My PR adds compile flags to select which GPU frameworks to include, along with a drop-down setting to choose what's enabled.
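The compile-flag idea could look roughly like the Cargo feature layout below. This is a sketch: the feature names, and the assumption that they map onto `ort` crate features, are mine, not the actual flags from #958:

```toml
# Hypothetical Cargo feature layout: each GPU backend is opt-in at build
# time, so the default binary ships with no GPU execution provider.
[features]
default = []
directml = ["ort/directml"]
webgpu = ["ort/webgpu"]
cuda = ["ort/cuda"]
```

Keeping `default` empty matches the discussion above: CPU-only unless a backend is explicitly compiled in, and the heavy CUDA binaries never ship by default.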
Closing because I will be submitting a PR for this and pulling it in |
Summary
- Pins the `transcribe-rs` dependency to a forked git revision with Windows DirectML support for ONNX models
- Registers `DirectMLExecutionProvider` on Windows, with explicit CPU fallback if provider registration fails

Validation
- `cargo check`
- `cargo check --release`
- `handy.log` shows successful DirectML registration for the Parakeet ONNX sessions
- 244.38s audio: 35.368s before vs 6.99s after (~5.1x faster, ~35x realtime)

Dependency patch
- Pins `transcribe-rs` to this git revision: ferologics/transcribe-rs@c56480687127070f456ae462d73c5defe964d807
- `transcribe-rs` mainline:
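In `Cargo.toml`, pinning to that revision might look like the fragment below; the repository URL is inferred from the `ferologics/transcribe-rs` reference and is an assumption:

```toml
# Temporary pin to the fork carrying Windows DirectML support,
# pending the change landing in transcribe-rs mainline.
[dependencies]
transcribe-rs = { git = "https://github.com/ferologics/transcribe-rs", rev = "c56480687127070f456ae462d73c5defe964d807" }
```

Pinning a full `rev` (rather than a branch) keeps builds reproducible while the fork is in flux.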