[export] Cannot view a tensor with shape torch.Size([1, 512, 32, 128]) and strides (2097152, 128, 65536, 1) as a tensor with shape (1, 512, 4096) #136543

Description (@justinchuby)

Notes

The error happens in the `ExportedProgram.run_decompositions()` call; the message is: Cannot view a tensor with shape torch.Size([1, 512, 32, 128]) and strides (2097152, 128, 65536, 1) as a tensor with shape (1, 512, 4096).
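
The failing pattern also reproduces in plain eager mode; below is a minimal sketch (tensor sizes taken from the error message, not the actual model code) showing why `view` rejects a transposed layout while `reshape` succeeds.

import torch

# Minimal sketch of the failing pattern (sizes from the error message above;
# an illustration, not the model's actual code).
x = torch.empty(1, 32, 512, 128, dtype=torch.float16)  # contiguous [bsz, heads, seq, head_dim]
t = x.transpose(1, 2)  # shape [1, 512, 32, 128], strides (2097152, 128, 65536, 1)

t.reshape(1, 512, -1)            # ok: reshape is allowed to copy
t.contiguous().view(1, 512, -1)  # ok: contiguous() materializes the copy first

try:
    t.view(1, 512, -1)           # fails: dims 2 and 3 are not laid out contiguously
except RuntimeError as e:
    print(e)  # "view size is not compatible with input tensor's size and stride ..."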

@tugsbayasgalan do you know anything about the aten.view.default ref, or whether it has any known issues?

# NOTE: shape is a vararg because Tensor.view can be called as
# Tensor.view(a, b, c) or Tensor.view((a, b, c)); the function call
# torch.view doesn't support unpacked shapes
# TODO: Turn this into a decomposition (currently fails on reshape meta tests)
@register_decomposition(aten.view.default)
def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
return _reshape_view_helper(a, *shape, allow_copy=False)
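
For contrast, the sibling reshape ref goes through the same helper but permits a copy, which is why reshape succeeds on this layout where view fails. A sketch of that ref (from torch/_refs; verify against the installed torch version):

# Sketch of the reshape ref for contrast (check your torch/_refs source):
# identical helper, but copies are allowed.
@register_decomposition(aten.reshape)
def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    return _reshape_view_helper(a, *shape, allow_copy=True)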

PyTorch ONNX Conversion Report

✅ Obtain model graph with `torch.export.export`
⚪ Obtain model graph with `torch.export.export(..., strict=False)`
⚪ Obtain model graph with `torch.jit.trace`
❌ Translate the graph into ONNX
⚪ Run `onnx.checker` on the ONNX model
⚪ Execute the model with ONNX Runtime
⚪ Validate model output accuracy
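
For context, this checklist comes from the torch.export-based ONNX exporter when a conversion report is requested. A hedged sketch of the invocation (`model` and `input_ids` are placeholders, and the `dynamo`/`report` flags should be checked against the installed torch version):

import torch

# Hypothetical invocation producing a conversion report like the one above;
# `model` and `input_ids` are placeholders, "llama.onnx" is an arbitrary name.
torch.onnx.export(model, (input_ids,), "llama.onnx", dynamo=True, report=True)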

Error messages

Traceback (most recent call last):

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 1089, in export
    decomposed_program = _prepare_exported_program_for_export(

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 727, in _prepare_exported_program_for_export
    exported_program = _fx_passes.decompose_with_registry(exported_program, registry)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_fx_passes.py", line 20, in decompose_with_registry
    return exported_program.run_decompositions(decomp_table)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/export/exported_program.py", line 114, in wrapper
    return fn(*args, **kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/export/exported_program.py", line 1080, in run_decompositions
    return _decompose_exported_program(

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/export/exported_program.py", line 628, in _decompose_exported_program
    gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/export/exported_program.py", line 421, in _decompose_and_get_gm_with_new_signature_constants
    gm, graph_signature = aot_export_module(

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1246, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1480, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 623, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 173, in inner
    flat_f_outs = f(*flat_f_args)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 182, in flat_fn
    tree_out = fn(*args, **kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 859, in functional_call
    out = PropagateUnbackedSymInts(mod).run(

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5619, in run_node
    result = super().run_node(n)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/fx/interpreter.py", line 275, in call_function
    return target(*args, **kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_subclasses/functional_tensor.py", line 535, in __torch_dispatch__
    outs_unwrapped = func._op_dk(

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 2039, in _dispatch_impl
    r = func(*args, **kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_refs/__init__.py", line 4592, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)

  File "/home/xadupre/.local/lib/python3.9/site-packages/torch/_refs/__init__.py", line 3755, in _reshape_view_helper
    raise ValueError(msg)

ValueError: Cannot view a tensor with shape torch.Size([1, 512, 32, 128]) and strides (2097152, 128, 65536, 1) as a tensor with shape (1, 512, 4096)!

While executing %view_4 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%transpose_4, [1, 512, -1]), kwargs = {})
Original traceback:
  File "/home/xadupre/github/private_occ/bash_bench/_bash_bench_model_runner.py", line 204, in __call__
    return self.model(*args, **kwargs)
  File "/home/xadupre/.local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "/home/xadupre/.local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
    layer_outputs = decoder_layer(
  File "/home/xadupre/.local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/xadupre/.local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 670, in forward
    attn_output = attn_output.view(bsz, q_len, -1)

Exported program

...
 # File: /home/xadupre/.local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:654 in forward, code: value_states = value_states.contiguous()
clone_2: "f16[1, 32, 512, 128]" = torch.ops.aten.clone.default(transpose_3, memory_format = torch.contiguous_format)

 # File: /home/xadupre/.local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:660 in forward, code: attn_output = torch.nn.functional.scaled_dot_product_attention(
unsqueeze_8: "f16[1, 512, 513]" = torch.ops.aten.unsqueeze.default(mul, 0)
unsqueeze_9: "f16[1, 1, 512, 513]" = torch.ops.aten.unsqueeze.default(unsqueeze_8, 1);  unsqueeze_8 = None
slice_14: "f16[1, 1, 512, 513]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 2, 0, 9223372036854775807);  unsqueeze_9 = None
slice_15: "f16[1, 1, 512, 513]" = torch.ops.aten.slice.Tensor(slice_14, 3, 0, 9223372036854775807);  slice_14 = None
expand_2: "f16[1, 1, 512, 513]" = torch.ops.aten.expand.default(slice_15, [1, 1, -1, -1]);  slice_15 = None
slice_16: "f16[1, 1, 512, 513]" = torch.ops.aten.slice.Tensor(expand_2, 0, 0, 9223372036854775807);  expand_2 = None
slice_17: "f16[1, 1, 512, 513]" = torch.ops.aten.slice.Tensor(slice_16, 1, 0, 9223372036854775807);  slice_16 = None
slice_18: "f16[1, 1, 512, 513]" = torch.ops.aten.slice.Tensor(slice_17, 2, 0, 9223372036854775807);  slice_17 = None
slice_19: "f16[1, 1, 512, 512]" = torch.ops.aten.slice.Tensor(slice_18, 3, 0, 512);  slice_18 = None
scaled_dot_product_attention: "f16[1, 32, 512, 128]" = torch.ops.aten.scaled_dot_product_attention.default(clone, clone_1, clone_2, slice_19);  clone = clone_1 = clone_2 = slice_19 = None

 # File: /home/xadupre/.local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:669 in forward, code: attn_output = attn_output.transpose(1, 2).contiguous()
transpose_4: "f16[1, 512, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None

 # File: /home/xadupre/.local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:670 in forward, code: attn_output = attn_output.view(bsz, q_len, -1)
view_4: "f16[1, 512, 4096]" = torch.ops.aten.view.default(transpose_4, [1, 512, -1]);  transpose_4 = None

 # File: /home/xadupre/.local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:672 in forward, code: attn_output = self.o_proj(attn_output)
linear_3: "f16[1, 512, 4096]" = torch.ops.aten.linear.default(view_4, p_model_model_layers_0_self_attn_o_proj_weight);  view_4 = p_model_model_layers_0_self_attn_o_proj_weight = None
...
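
Worth noting in the excerpt: the source line includes `.contiguous()`, yet no clone is recorded between `transpose_4` and `view_4`, so the view sees the transposed strides. Those strides are exactly what the transpose produces, and flattening then fails the usual copy-free view merge condition. A quick illustrative check:

import torch

# Reconstruct the layout from the report: a contiguous [1, 32, 512, 128]
# tensor transposed at dims 1 and 2.
base = torch.empty(1, 32, 512, 128, dtype=torch.float16)
t = base.transpose(1, 2)
print(t.shape, t.stride())  # torch.Size([1, 512, 32, 128]) (2097152, 128, 65536, 1)

# Viewing as [1, 512, 4096] must merge dims 2 and 3, which requires
# stride[2] == shape[3] * stride[3]; here 65536 != 128 * 1, so no copy-free
# view exists and the ref raises the error above.
print(t.stride(2) == t.size(3) * t.stride(3))  # False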

Analysis

PyTorch ONNX Conversion Analysis

Model Information

The model has 6738546688 parameters and 64 buffers (non-trainable parameters).
Number of parameters per dtype:

defaultdict(<class 'int'>, {torch.float16: 6738546688})

Number of buffers per dtype:

defaultdict(<class 'int'>, {torch.float16: 64})

Of the call_function nodes, the counts of operators used are:

  • aten.slice.Tensor: 323
  • aten.mul.Tensor: 293
  • aten.linear.default: 225
  • aten.add.Tensor: 193
  • aten._to_copy.default: 135
  • aten.unsqueeze.default: 132
  • aten.view.default: 129
  • aten.transpose.int: 128
  • aten.clone.default: 96
  • aten.pow.Tensor_Scalar: 65
  • aten.mean.dim: 65
  • aten.rsqrt.default: 65
  • aten.neg.default: 64
  • aten.cat.default: 64
  • aten.expand.default: 33
  • aten.scaled_dot_product_attention.default: 32
  • aten.silu.default: 32
  • <built-in function getitem>: 2
  • aten.embedding.default: 1
  • aten.arange.start: 1
  • aten.full.default: 1
  • aten.triu.default: 1
  • aten.arange.default: 1
  • aten.gt.Tensor: 1
  • wrap_with_autocast: 1

ONNX Conversion Information

The model contains operators the dispatcher could not find registered ONNX decompositions for. This may be due to missing implementations, decompositions not registered correctly, or a bug in the dispatcher.

Errors grouped by operator:

  • wrap_with_autocast: No decompositions registered for the real-valued input. Example node: %wrap_with_autocast : [num_users=2] = call_function[target=torch.ops.higher_order.wrap_with_autocast](args = (cuda, None, False, None, %submod_3, %expand_1, %_to_copy_1), kwargs = {}). All nodes: [wrap_with_autocast]
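
For reference, wrap_with_autocast is the higher-order op torch.export records for a torch.autocast region. A minimal sketch (assumes a CUDA build; enabled=False matches the node args above, consistent with the autocast(enabled=False) region in the Llama rotary embedding) of a module that exports such a node:

import torch

# Minimal sketch (assumes CUDA): an autocast(enabled=False) region exports as
# a torch.ops.higher_order.wrap_with_autocast node wrapping a submodule.
class M(torch.nn.Module):
    def forward(self, x):
        with torch.autocast(device_type="cuda", enabled=False):
            return x.float() * 2

ep = torch.export.export(M(), (torch.ones(2, device="cuda", dtype=torch.float16),))
print(ep.graph)  # look for %wrap_with_autocast wrapping %submod_*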

Profiling result


  _     ._   __/__   _ _  _  _ _/_   Recorded: 16:08:25  Samples:  24301
 /_//_/// /_\ / //_// / //_'/ //     Duration: 57.754    CPU time: 57.757
/   _/                      v4.7.3

Profile at /home/xadupre/.local/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py:646

57.753 export  torch/onnx/_internal/exporter/_core.py:932
├─ 36.167 save  torch/export/__init__.py:377
│     [60 frames hidden]  torch, <built-in>, copy, <string>, zi...
└─ 20.468 TorchExportStrategy.__call__  torch/onnx/_internal/exporter/_capture_strategies.py:79
      [215 frames hidden]  torch, <string>
         8.183 _fn  torch/_dynamo/eval_frame.py:431
         └─ 1.191 WrappedModelBase.__call__  bash_bench/_bash_bench_model_runner.py:203
            └─ 1.190 _fn  torch/_dynamo/eval_frame.py:628
                  [11 frames hidden]  torch, <eval_with_key>

cc @ezyang @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Metadata

Labels

export-triaged, oncall: export, oncall: pt2
