Cannot run the lung_nodule_ct_detection TensorRT model end-to-end #6330

@binliunls

Description

Is your feature request related to a problem? Please describe.
When converting a PyTorch model to a TensorRT model, there are two ways: the Torch-TensorRT way and the ONNX-TensorRT way. With the ONNX-TensorRT way, either a TorchScript module or a plain PyTorch module can be used as the input.

The Torch-TensorRT way of conversion currently suffers from a slowdown issue.

The output of the lung_nodule_ct_detection model is a Python dict, which is not fully supported by the ONNX-TensorRT way. This way first exports the PyTorch model to an ONNX model, then converts the ONNX model to a TensorRT engine, and finally wraps the TensorRT engine back into a TorchScript module.

For the first step, exporting the PyTorch model to an ONNX model, there are two options. The first exports the PyTorch model directly with the torch.onnx.export API. The second first exports the PyTorch model to a TorchScript model using torch.jit.script or torch.jit.trace and then passes that TorchScript model to torch.onnx.export to get the ONNX model; this second option is the one used by trt_export.

With the first option, the output of the converted model becomes a list of tensors. With the second option, the model either fails during the torch.jit.trace phase, or torch.onnx.export reports an error when the model is scripted with torch.jit.script. Neither option lets the model run an end-to-end inference, since its output needs to be fed to a detector head that takes a dict as input.
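The core mismatch can be sketched without TensorRT at all. In this minimal illustration (plain Python, with strings standing in for tensors; the key names come from the issue text, the level names are made up), the exported model hands back a flat sequence while downstream code expects a keyed dict:

```python
# Output of the detection head before export: a dict keyed by task,
# with one feature map per pyramid level (strings stand in for tensors).
head_output_dict = {
    "classification": ["cls_level0", "cls_level1", "cls_level2"],
    "box_regression": ["box_level0", "box_level1", "box_level2"],
}

# The ONNX export path flattens dict outputs, so the converted model
# effectively returns only the values, as a flat list of tensors:
flattened = [t for maps in head_output_dict.values() for t in maps]
print(flattened)

# The keys are gone, so the detector head, which looks up
# "classification" and "box_regression", can no longer consume it.
```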

Describe the solution you'd like

  1. Add a control parameter like use_tuple_output to the __init__ function of RetinaNet to control whether the output is a dict or a tuple of tensors. Please note to add it to the __init__ function and use it in the forward function so that it is convenient for the bundle.
  2. The output of the forward function should be either a tuple of tensors or a dict, decided by the control parameter above.
  3. Remove this check and add "spatial_dims", "num_classes", "cls_key", "box_reg_key", "num_anchors", "size_divisible" as input parameters to the __init__ function of RetinaNetDetector, and assign the class attributes from the input values when self.network doesn't have these attributes.
  4. Rewrite this logic to parse the output list like the code below (just for reference):

```python
head_outputs = network(images)
if isinstance(head_outputs, (tuple, list)):
    tmp_dict = {}
    tmp_dict["classification"] = head_outputs[:3]
    tmp_dict["box_regression"] = head_outputs[3:]
    head_outputs = tmp_dict
else:
    ensure_dict_value_to_list_(head_outputs, keys)
```

Describe alternatives you've considered

Additional context
