
Conversation

Collaborator

@sgugger commented Feb 1, 2022

What does this PR do?

This PR standardizes the model outputs for semantic segmentation models and creates an AutoModelForSemanticSegmentation class. As discussed internally, models for semantic segmentation should return logits of shape (batch_size, num_labels, height, width), i.e. one logit per label for each pixel.

Breaking change: The BeitForSemanticSegmentation and SegformerForSemanticSegmentation models have logits with the same height and width as the input after this PR (instead of height/4 and width/4).

To maintain some level of backward compatibility, the SemanticSegmentationModelOutput has a legacy_logits field that users can read to get the old logits values.
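
As a rough illustration of the new API (not part of this PR's diff; the checkpoint name, image path, and field names are only examples taken from the description above), here is how the standardized output could be consumed:

import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForSemanticSegmentation

# Hypothetical usage sketch; shapes follow the description above.
feature_extractor = AutoFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

inputs = feature_extractor(images=Image.open("scene.jpg"), return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One logit per label per pixel, at the input resolution after this PR.
print(outputs.logits.shape)  # (1, num_labels, height, width)
# Per-pixel predicted classes.
segmentation_map = outputs.logits.argmax(dim=1)
# The pre-interpolation values stay available for old code paths.
old_logits = outputs.legacy_logits  # (1, num_labels, height / 4, width / 4)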

Another possible road that would be less breaking is to create new classes BeitForPixelClassification and SegformerForPixelClassification while deprecating the current ones, then rename the task to "pixel classification" everywhere. This has the benefit of being more understandable to beginners and of mirroring our "ForTokenClassification" classes, but it might surprise an expert more used to the "semantic segmentation" name.

@HuggingFaceDocBuilder commented Feb 1, 2022

The documentation is not available anymore as the PR was closed or merged.

Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
Contributor

Do we also want to tackle regression tasks with this output class?

Depth estimation is an example of per-pixel regression, but that's a different task.

Collaborator Author

We can add it to the docstring once it's supported. For now, the models don't use problem_type to infer the proper loss function.
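
For context, a hedged sketch of the problem_type-style loss selection pattern used by the sequence classification heads elsewhere in the library; nothing below is part of this PR, and the segmentation models simply use a cross-entropy loss for now:

import torch.nn as nn

def compute_loss(logits, labels, config):
    # Sketch of the problem_type dispatch pattern; the segmentation heads in this PR
    # do not implement it and always use CrossEntropyLoss on the per-pixel logits.
    if config.problem_type == "regression":
        return nn.MSELoss()(logits.squeeze(), labels.squeeze())
    if config.problem_type == "single_label_classification":
        return nn.CrossEntropyLoss()(logits.view(-1, config.num_labels), labels.view(-1))
    if config.problem_type == "multi_label_classification":
        return nn.BCEWithLogitsLoss()(logits, labels)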

Contributor

@NielsRogge left a comment

Looks good to me (wondering why the PR is called "standardize instance segmentation models outputs"?).

  • xxxForPixelClassification => both semantic and panoptic segmentation are pixel classification tasks, right? In semantic segmentation one only predicts a single label per pixel, whereas in panoptic segmentation one predicts two labels per pixel. If we're going with xxxForPixelClassification and xxxForPanopticSegmentation, people would be confused, and that's not really consistent. So I would stick to xxxForSemanticSegmentation.

There's also depth estimation, which is per-pixel regression (if I understand correctly, I'm not entirely familiar with it). Would we then have xxxForPixelRegression? However, currently we also tackle both regression and classification with xxxForSequenceClassification models.

Comment on lines 827 to 829
legacy_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, height/4, width/4)`):
For backward compatibility, old logits returned by `BeitForSemanticSegmentation` and
`SegformerForSemanticSegmentation`.
Contributor

Not sure if we should keep this. We could also call it raw_logits (which are the raw logits coming out of the model without any interpolation).

Contributor

@FrancescoSaverioZuppichini Feb 3, 2022

I have a doubt about the naming. logits is usually used to indicate the un-normalized vectors used to compute a downstream task (e.g. classification after a softmax layer). If that is the case, returning this output class is unfeasible for some models, e.g. MaskFormer.

In MaskFormer we compute the final segmentation masks using two tensors:

def forward(self, *args, **kwargs):
    outputs: MaskFormerOutput = self.model(*args, **kwargs)
    # mask classes have shape [BATCH, QUERIES, CLASSES + 1]
    # remove the null class with `[..., :-1]`
    masks_classes: Tensor = outputs.preds_logits.softmax(dim=-1)[..., :-1]
    # mask probs have shape [BATCH, QUERIES, HEIGHT, WIDTH]
    masks_probs: Tensor = outputs.preds_masks.sigmoid()
    # now we want to sum over the queries,
    # $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $
    # where $ softmax(p) \in R^{q, c} $ is the mask classes
    # and $ sigmoid(m) \in R^{q, h, w} $ is the mask probabilities
    # b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth)
    segmentation: Tensor = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)

    return MaskFormerForSemanticSegmentationOutput(segmentation=segmentation, **outputs)

Returning a logits tensor of shape (batch_size, config.num_labels, height, width) is therefore not possible for MaskFormer. I see two possible solutions.

@NielsRogge opened an issue to ask the MaskFormer author if it is possible to return logits from their model.

Collaborator Author

We might have to treat MaskFormer as an exception (in the sense that we will need the raw output to have both logits and prediction masks to be able to compute the segmentation), the same way we handle the model for QA with beam search in XLNet. That class would then not be in AutoModelForSemanticSegmentation as it doesn't have the same API, but we would need one version of MaskFormerForSemanticSegmentation with the same API (it might perform as well as this fancy one with some fine-tuning).

Member

Regarding legacy_logits, I wonder if it wouldn't be better, API-wise, to have the "old" logits returned according to a flag legacy_output or something similar. I understand the need for the breaking change, but I think it would be simpler for users if they didn't have to update their model outputs, only their model instantiation.

I think it's easier as it doesn't require diving into their previously working code to see what needs to be changed, they would just need to update the config init/model init.

The model would then return a SequenceClassifierOutput with the previous values when instantiated with legacy_output, and would return the new, better SemanticSegmentationModelOutput when nothing is set.

I would also emit a FutureWarning, either when users instantiate the model or during their first forward pass. I think the latter is best: if they have not updated their code and it crashes, the warning will be right above the crash.
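
A rough sketch of what that escape hatch could look like (the flag name legacy_output, the helper name, and the warning text are illustrative, not the final implementation):

import warnings

import torch.nn.functional as F

def select_logits(raw_logits, input_size, config):
    # Illustrative only: dispatch between the old and new output format
    # based on a hypothetical `legacy_output` config flag.
    if getattr(config, "legacy_output", False):
        warnings.warn(
            "`legacy_output=True` returns the pre-interpolation logits and will be "
            "removed in a future version of Transformers.",
            FutureWarning,
        )
        return raw_logits  # shape (batch_size, num_labels, height / 4, width / 4)
    # New behaviour: upsample the logits to the input resolution.
    return F.interpolate(raw_logits, size=input_size, mode="bilinear", align_corners=False)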

Contributor

Thanks! I've added a comment about logits in the output and why it may be challenging for MaskFormer to use it.

@sgugger changed the title from "Standardize instance segmentation models outputs" to "Standardize semantic segmentation models outputs" on Feb 3, 2022
Member

@LysandreJik left a comment

I don't have sufficient experience with computer vision and these specific tasks to review the content of the PR, so my review is on the:

  • coherence of this auto model with the rest of the auto models
  • backwards compatibility

I think this PR is great as it's a good opportunity to observe whether our "experimental" approach works or not. I think it shows that the approach lacks quite a few aspects that would make it robust:

  • We do not systematically mention that models are experimental on their modeling page, which is an issue. It should be very visible to users.
  • This is an urgent refactor, but ideally this should go through a correct deprecation cycle. I think we need to be clear that this is an exceptional instance where we do not respect backwards compatibility for experimental reasons.
  • In order to patch this, I think we'll need to:
    • Tweet (can be from personal accounts) about the update
    • Create a forum post
    • Create a pinned GitHub issue

Thanks for working on this, @sgugger!

from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
from ...modeling_outputs import (
Member

Unfortunately, BEIT does not have any "experimental" mention on its model page, so users are not aware that its API is in an experimental state.

I would advocate adding it to the templates by default (I can take care of that), and setting a reminder two months after merging a model to consider removing the experimental mention.

from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
from ...modeling_outputs import BaseModelOutput, SemanticSegmentationModelOutput, SequenceClassifierOutput
Member

Same for SegFormer: it lacks an experimental mention on its model page.

Comment on lines +287 to +294
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
# Model for Semantic Segmentation mapping
("beit", "BeitForSemanticSegmentation"),
("segformer", "SegformerForSemanticSegmentation"),
]
)

Member

LGTM
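
For illustration, a small sketch of how this mapping lets the auto class dispatch on the config type (the checkpoint name is only an example, not part of the PR):

from transformers import AutoConfig, AutoModelForSemanticSegmentation

# The auto class looks up the config class in the mapping above and
# instantiates the matching *ForSemanticSegmentation model.
config = AutoConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = AutoModelForSemanticSegmentation.from_config(config)
print(type(model).__name__)  # SegformerForSemanticSegmentation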

Member

@LysandreJik left a comment

Thank you for iterating; it now looks good to me from a backwards-compatibility perspective.

@sgugger merged commit ac6aa10 into master Feb 4, 2022
@sgugger deleted the pixel_classification_models branch February 4, 2022 19:52
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Feb 18, 2022
* Standardize instance segmentation models outputs

* Rename output

* Update src/transformers/modeling_outputs.py

Co-authored-by: NielsRogge <[email protected]>

* Add legacy argument to the config and model forward

* Update src/transformers/models/beit/modeling_beit.py

Co-authored-by: Lysandre Debut <[email protected]>

* Copy fix in Segformer

Co-authored-by: NielsRogge <[email protected]>
Co-authored-by: Lysandre Debut <[email protected]>