
[ONNX] STFT Support#92087

Closed
urinieto wants to merge 12 commits into pytorch:master from urinieto:stft_onnx2

Conversation

@urinieto
Contributor

This PR addresses issue #81075, making torch.stft compatible with ONNX Opset 17's STFT operator.

The conversion works for most of torch.stft functionality:

  • Batched or unbatched inputs
  • Normalization
  • Pre-computed windows
  • Rectangular windows
  • One-sided returns
  • Window centering (implicitly supported)

What is currently not supported is complex types, due to the lack of conversion functionality between PyTorch and ONNX (#86746).

Regardless, this is easy to bypass by setting return_complex=False when using torch.stft.

Note that there is already a draft PR to address this (#83944), but it is currently closed and it only partially addresses the conversion (i.e., most of torch.stft functionality is lacking, and unit tests are missing).
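For reference, the real-pair layout that this conversion relies on can be seen from `torch.stft` directly; a minimal sketch (module name and shapes here are illustrative, not part of the PR):

```python
import torch

class STFTModule(torch.nn.Module):
    def forward(self, audio):
        # return_complex=False yields a real tensor whose trailing
        # dimension of size 2 holds the (real, imaginary) parts,
        # which maps onto ONNX Opset 17's STFT output layout
        return torch.stft(audio, n_fft=400, hop_length=160,
                          return_complex=False)

out = STFTModule()(torch.randn(1, 16000))
# (batch, n_fft // 2 + 1 one-sided bins, frames, real/imag)
print(out.shape)  # torch.Size([1, 201, 101, 2])
```

With the default `center=True`, the input is padded by `n_fft // 2` on each side, so a 16000-sample signal at hop 160 produces 101 frames.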

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Jan 12, 2023
@linux-foundation-easycla

linux-foundation-easycla bot commented Jan 12, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: urinieto / name: Oriol Nieto (a95ffd7)

@pytorch-bot

pytorch-bot bot commented Jan 12, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92087

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 Failures

As of commit b88508c:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@BowenBao BowenBao requested a review from justinchuby January 12, 2023 18:28
@justinchuby justinchuby self-assigned this Jan 12, 2023
@drisspg drisspg added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 13, 2023
@justinchuby
Collaborator

Thanks for your contribution! Could you sign the CLA following the instructions in the bot comment above, and fix lint issues by running lintrunner -a -m master? I will add comments after the CLA is signed.

https://github.com/pytorch/pytorch/wiki/lintrunner

@justinchuby justinchuby added the module: onnx Related to torch.onnx label Jan 13, 2023
@urinieto
Contributor Author

Thanks for reviewing this, @justinchuby! I just signed the CLA, and associated my last commits with the right email address. Moreover, I have passed the linter and fixed a problem with the unit tests regarding the return_complex parameter. Should be good to be reviewed :)

@justinchuby
Collaborator

Looks like there are still errors in the CLA. Squashing the commits may work?

@urinieto
Contributor Author

Apologies, I realized that the CLA should be signed under a "company contribution", since most of this work was done using my company's resources (Adobe). This contribution was internally approved before I submitted this PR, so I might need a working day or two to have someone from the Open Source Office sign the CLA. Will report asap.

@milesial
Contributor

Not sure if this is related to this PR or the underlying onnxruntime, but there seems to be an off-by-one error somewhere when using graph optimization with STFT.

import torch
import torch.nn as nn
import onnxruntime as rt

class Test(nn.Module):
    def forward(self, audio):
        stft = torch.stft(audio, 400, 160, return_complex=False)
        return stft

m = Test()
print('torch', m(torch.randn(1, 16000 * 30)).shape)

torch.onnx.export(m,
                  torch.randn(1, 16000 * 30),
                  'test.onnx',
                  export_params=True,
                  input_names=['in'],
                  output_names=['out'],
                  opset_version=17, verbose=False)

# Run once with all graph optimizations disabled...
sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_options.optimized_model_filepath = 'test.onnx'
session = rt.InferenceSession('test.onnx', sess_options)
print('onnx unoptimized', session.run(None, {'in': torch.randn(1, 16000 * 30).numpy()})[0].shape)

# ...and once with basic optimizations enabled
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.optimized_model_filepath = 'test.onnx'
session = rt.InferenceSession('test.onnx', sess_options)
print('onnx optimized', session.run(None, {'in': torch.randn(1, 16000 * 30).numpy()})[0].shape)

torch torch.Size([1, 201, 4801, 2])
onnx unoptimized (1, 201, 4801, 2)
onnx optimized (1, 201, 4801, 2)

2023-01-15 19:38:25.256968539 [W:onnxruntime:, execution_frame.cc:828 VerifyOutputSizes] Expected shape from model of {1,201,4802,2} does not match actual shape of {1,201,4801,2} for output out

The actual output shape matches PyTorch, but onnxruntime's inferred (expected) shape is off by one, so when we add nodes on top of that it causes shape-mismatch errors.


@peterbell10
Collaborator

return_complex=False is deprecated so I don't like the idea of forcing people to use it. I see that complex64 and complex128 are mentioned as types in this document, so is it possible to implement view_as_real/view_as_complex in onnx?

@justinchuby
Collaborator

I like the idea of having view_as_real and view_as_complex ops in onnx. Would you be willing to open an issue for a new operator? https://github.com/onnx/onnx/issues/new?assignees=&labels=operator&template=operator.md&title=

@urinieto
Contributor Author

@peterbell10 I also don't like the idea of forcing return_complex=False, but it is an easy workaround if we don't want to wait to have complex conversion support on ONNX, which seems like it might take quite some time (see below).

@justinchuby Seems like view_as_complex was already requested here: #49793 Unfortunately, there was not enough interest for this when it was reported. I just commented on the issue, to try to bring it back.

@peterbell10
Collaborator

Am I right in saying onnx Cast doesn't support complex types? If it did then you could do something like:

def _as_real(z):
    # "cast" a complex value to its real and imaginary parts
    return float(z.real), float(z.imag)

def _as_complex(real, imag):
    return complex(real) + 1j * complex(imag)

@urinieto
Contributor Author

urinieto commented Jan 17, 2023

That is correct: ONNX doesn't currently support casting for complex types (https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast).

If it did, we could simply use Cast on the result of STFT to make sure the result is returned as complex if return_complex=True.

That being said, would it be possible to obtain the output of STFT, put it in a PyTorch tensor, convert it to complex using the _as_complex function above, and then put it back into the ONNX graph?

Sorry, just realized that that's not gonna work, because Cast from/to complex is not supported, which is exactly what @peterbell10 said (😅).
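On the PyTorch side, the round-trip peterbell10 sketches already exists as `view_as_real`/`view_as_complex`; a small, purely illustrative demonstration of the real-pair representation the exporter would need ONNX to support:

```python
import torch

z = torch.complex(torch.randn(3, 5), torch.randn(3, 5))

# complex tensor -> real tensor with a trailing dim of 2 (real, imag)
as_real = torch.view_as_real(z)   # shape (3, 5, 2), no copy

# ...and back again, losslessly
z_back = torch.view_as_complex(as_real)

print(as_real.shape, torch.equal(z, z_back))
```

This round-trip is what the exporter uses internally when it emits the real-valued STFT output; what is missing on the ONNX side is an op to reinterpret that pair as a complex tensor.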

Collaborator

Yeah, otherwise there will be a graph break in onnx.


@urinieto
Contributor Author

Not sure if this is related to this PR or the underlying onnxruntime, but there seems to be an off-by-one error somewhere when using graph optimization with STFT.

@milesial I noticed that as well. The fact that it actually returns the correct shape in all the tests I wrote makes me think it might be a bug in the onnxruntime API (or the ONNX STFT definition?). I couldn't find an issue reported on the onnxruntime repo.

I was thinking of reporting it once this PR is merged, so that it might be easier to reproduce. Or if you wanna report it now, by all means do :)

@milesial
Contributor

Got it, good you caught it too. I opened microsoft/onnxruntime#14316 since I don't know if this PR is going to be merged soon.

@urinieto
Contributor Author

urinieto commented Feb 9, 2023

The CLA is finally signed, apologies it took so long. I'll work on the comments and feedback asap.

@justinchuby
Collaborator

I notice the CLA bot is still having errors. Perhaps squash or rebase some commits?

@justinchuby
Collaborator

There are some build errors I believe happened because this branch is too old. Could you rebase with master?

@justinchuby justinchuby added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 9, 2023
@justinchuby
Collaborator

Thanks so much for creating this!

@justinchuby
Collaborator

@pytorchbot merge -g

@justinchuby
Collaborator

@pytorchbot merge -g

@justinchuby justinchuby removed the ciflow/trunk Trigger trunk jobs on your pull request label Mar 9, 2023
@justinchuby
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 10, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@justinchuby
Collaborator

@pytorchbot merge -f "unrelated cuda and other failures"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 12, 2023
Pull Request resolved: pytorch/pytorch#92087
Approved by: https://github.com/justinchuby
@urinieto
Contributor Author

I am trying to export a PyTorch model with STFT into ONNX, but it seems the main torch branch does not support this (any more? yet?). Can anyone explain what happened to this PR? I basically had to reapply this PR as a patch to make the export work again. /cc @justinchuby

@justinchuby
Collaborator

It should be merged properly? Was the code gone?

@urinieto
Contributor Author

urinieto commented Jun 20, 2024

This PR should be able to export to ONNX models with STFT-related methods like the following:

import torch
import torchaudio
 
ONNX_MODEL = "out_melspec.onnx"
N_SAMPLES = 16000
N_FFT = 1024
OPSET_VERSION = 17
 
class CustomMelSpec(torch.nn.Module):
    def forward(self, x):
        return torchaudio.transforms.MelSpectrogram(n_fft=N_FFT)(x)

print("PyTorch Version:", torch.__version__)
print("Torchaudio Version:", torchaudio.__version__)

x = torch.randn(1, 1, N_SAMPLES)
with open(ONNX_MODEL, "wb") as f:
    torch.onnx.export(CustomMelSpec(), (x,), f, opset_version=OPSET_VERSION)

But running this script returns the following error (see pytorch and torchaudio versions used below):

PyTorch Version: 2.3.1+cu121
Torchaudio Version: 2.3.1+cu121
/opt/conda/envs/audiotagger2/lib/python3.9/site-packages/torch/onnx/_internal/jit_utils.py:307: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at ../torch/csrc/jit/passes/onnx/constant_fold.cpp:179.)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
Traceback (most recent call last):
  File "/home/urinieto/export_onnx_test.py", line 20, in <module>
    torch.onnx.export(CustomMelSpec(), (x), f, opset_version=OPSET_VERSION)
  File "/opt/conda/envs/audiotagger2/lib/python3.9/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/opt/conda/envs/audiotagger2/lib/python3.9/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/opt/conda/envs/audiotagger2/lib/python3.9/site-packages/torch/onnx/utils.py", line 1138, in _model_to_graph
    graph = _optimize_graph(
  File "/opt/conda/envs/audiotagger2/lib/python3.9/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/opt/conda/envs/audiotagger2/lib/python3.9/site-packages/torch/onnx/utils.py", line 1956, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/opt/conda/envs/audiotagger2/lib/python3.9/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
    return fn(g, *args, **kwargs)
  File "/opt/conda/envs/audiotagger2/lib/python3.9/site-packages/torch/onnx/symbolic_opset17.py", line 115, in stft
    raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: STFT does not currently support complex types  [Caused by the value '73 defined in (%73 : Float(*, *, strides=[17024, 1], requires_grad=0, device=cpu) = onnx::Reshape[allowzero=0](%63, %72), scope: __main__.CustomMelSpec:: # /opt/conda/envs/audiotagger2/lib/python3.9/site-packages/torch/functional.py:664:0

After applying this PR as a patch again (to torch v1.13.0), the code runs fine and I can successfully export the model. So I've no idea what happened to the merged code 🤔 Maybe the way to export ONNX models has changed?

@urinieto
Contributor Author

Any updates on this, @justinchuby ?

@justinchuby
Collaborator

I see the same code from this pr in the current codebase. I am wondering why reapplying as a patch would change anything 🤔

@urinieto
Contributor Author

Yeah, I have no idea why this is not working right off the bat with the latest torch version (or basically any version after this patch was merged). For whatever reason, the STFT models are not converted to ONNX any more :/

@justinchuby
Collaborator

The error message is saying the inputs cannot be complex; we can only run them in real-valued mode. Was this what the model was doing?

@urinieto
Contributor Author

urinieto commented Jun 24, 2024

After some more digging, it turns out the torch code is totally fine; the issue was with torchaudio!

Everything works fine after applying this patch to the latest version of torchaudio (basically forcing torchaudio spectrogram to use float values instead of complex types):

diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index af34e707..4eb842c7 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -133,9 +134,12 @@ def spectrogram(
         pad_mode=pad_mode,
         normalized=frame_length_norm,
         onesided=onesided,
-        return_complex=True,
+        return_complex=False,
     )

+    # From imaginary and real values to absolute value
+    spec_f = torch.sqrt(torch.pow(spec_f[:, :, :, 0], 2.0) + torch.pow(spec_f[:, :, :, 1], 2.0))
+
     # unpack batch
     spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])

Apologies for the confusion, and thanks for helping me figure this out!

Note: the patch above will make non-power spectrograms (i.e., power=None) not work correctly.
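The patch above boils down to computing the magnitude from the real-valued STFT output instead of calling .abs() on a complex one; a standalone sketch of that equivalence (illustrative, not the torchaudio code):

```python
import torch

audio = torch.randn(1, 16000)

# Real-valued STFT: (..., freq, frames, 2), real/imag in the last dim
spec = torch.stft(audio, n_fft=1024, return_complex=False)

# Magnitude, as in the patch: sqrt(re^2 + im^2)
mag = torch.sqrt(spec[..., 0] ** 2 + spec[..., 1] ** 2)

# Matches abs() of the complex-valued STFT
ref = torch.stft(audio, n_fft=1024, return_complex=True).abs()
print(torch.allclose(mag, ref, atol=1e-5))
```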


Labels

  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • module: onnx (Related to torch.onnx)
  • open source
  • release notes: onnx (torch.onnx related changes that should show up in the release notes)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants