Status: Closed
Labels: module: aotinductor, module: floatx (formerly float8; for torch.float8_e5m2, torch.float8_e4m3, and other sub-8-bit float types), oncall: export, oncall: pt2
🐛 Describe the bug
Hi, I am trying to use the AOTI minifier to debug an FP8 problem, but I got an error because `_TORCH_TO_SERIALIZE_DTYPE` does not support FP8 dtypes.
Repro:
```python
import torch
from torch._inductor import config as inductor_config


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, y):
        x = self.fc1(x)
        x = self.relu(x)
        y = y.to(torch.float8_e4m3fn)
        x = self.sigmoid(x)
        return x, y


inductor_config.aot_inductor.dump_aoti_minifier = True
torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error"

with torch.no_grad():
    model = Model().to("cuda")
    example_inputs = (torch.randn(8, 10).to("cuda"), torch.randn(8, 10).to("cuda"))
    ep = torch.export.export(model, example_inputs)
    package_path = torch._inductor.aoti_compile_and_package(ep)
    compiled_model = torch._inductor.aoti_load_package(package_path)
    result = compiled_model(*example_inputs)
```
Error:
```
Traceback (most recent call last):
  /torch/_export/serde/serialize.py", line 1321, in serialize_graph
    getattr(self, f"handle_{node.op}")(node)
  /torch/_export/serde/serialize.py", line 528, in handle_call_function
    inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
  /torch/_export/serde/serialize.py", line 663, in serialize_inputs
    arg=self.serialize_input(kwargs[schema_arg.name], schema_arg.type),
  /torch/_export/serde/serialize.py", line 903, in serialize_input
    return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg])
KeyError: torch.float8_e4m3fn
```
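The `KeyError` indicates the FP8 dtypes are simply absent from the serializer's dtype mapping. A minimal sketch to confirm this locally, assuming the mapping shown in the traceback is importable as `torch._export.serde.serialize._TORCH_TO_SERIALIZE_DTYPE`:
```python
# Sketch: check whether the export serializer's dtype table knows about FP8.
# Assumption: the dict from the traceback is importable from torch._export.serde.serialize.
import torch
from torch._export.serde.serialize import _TORCH_TO_SERIALIZE_DTYPE

for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    # Prints False for dtypes the serializer cannot currently handle.
    print(dtype, dtype in _TORCH_TO_SERIALIZE_DTYPE)
```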
Versions
trunk
cc @chauhang @penguinwu @yanbing-j @vkuzo @albanD @kadeng @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @desertfire @chenyang78