Conversation
Timoeller
left a comment
There was a problem hiding this comment.
I made some suggestions in the code
| logit_input = logits is not None | ||
| preds_input = preds is not None | ||
|
|
||
| if logit_input and preds_input: |
There was a problem hiding this comment.
syntax seems confusing:
how about
if logits and preds: logger.warning("Both logits and preds have been passed as input to the TextClassificationHead")
if (logits is None) and (preds is None): logger.error("Neither logits nor preds have been passed as input to the TextClassificationHead")
| (f"Label_tensor_names are missing inside the {head.task_name} Prediction Head. Did you connect the model" | ||
| " with the processor through either 'model.connect_heads_with_processor(processor.tasks)'" | ||
| " or by passing the processor to the Adaptive Model?") | ||
| if not hasattr(head, "label_tensor_name"): |
There was a problem hiding this comment.
Not sure about this one. Can the model work without this? e.g. if it is in inference mode.
If so we should remove the assert, if not and the functionality must break further downstream, then lets keep the assert here
| if logit_input: | ||
| logger.warning("QuestionAnsweringHead.formatted_preds() received logit input when it only expects pred input") | ||
| if not preds_input: | ||
| logger.warning("QuestionAnsweringHead.formatted_preds() did not receive the preds input it expects") |
There was a problem hiding this comment.
this should be at least an logger.error or even kept as assert, since the app breaks without preds
| # are prediction spans | ||
| preds_d = self.aggregate_preds(preds, passage_start_t, ids, seq_2_start_t) | ||
|
|
||
| assert len(preds_d) == len(baskets) |
There was a problem hiding this comment.
we should through a logger error here, since we need as many preds_d (on document level) as we have baskets.
If that is not the case it isnt necessarily breaking a cosuming app. Example: haystack retrieves 10 docs, one of those documents is malformatted so it cannot contain preds. Still we want to give back preds_d on all other 9 docs so the reader produces some answer
| else: | ||
| start_of_word.append(False) | ||
|
|
||
| assert len(tokens) == len(token_offsets) == len(start_of_word) |
There was a problem hiding this comment.
here we either need tests or keep the assert?
|
|
||
| # This fn is used to align QA output of len=n_docs and Classification output of len=n_passages | ||
| def chunk(iterable, lengths): | ||
| assert sum(lengths) == len(iterable) |
There was a problem hiding this comment.
I do not know what this assert does... lets discuss in detail or throw a logger.error?
|
See #468 for how these changes were actually implemented |
The idea of this PR generally is to remove assertion statement so that running systems do not get interrupted by errors. What used to be asserts will, where appropriate, be converted into logging.error() or logging.warning().
This PR focuses only on QA related components. Asserts may still be present if they can be caught by try catch statements such as Processor._featurize_samples() or Processor._init_samples_in_baskets().