Skip to content

Conversation

@ydshieh
Copy link
Collaborator

@ydshieh ydshieh commented Mar 9, 2022

What does this PR do?

Fix a CI failure for TFDebertaV2Model, caused by the mistake in TFDebertaV2ConvLayer.

Remark

This test test_inference_no_head also fails with the version in #13120. I think this slow test was not run manually to ensure it works before being merged to master.

Code to demonstrate the issue and the effect of this PR

This is adapted from test_inference_no_head

########## Prep ########## 

import numpy as np
import torch
import tensorflow as tf
from transformers import DebertaV2Model, TFDebertaV2Model

input_ids = np.array([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]], dtype=np.int32)
attention_mask = np.array([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=np.int32)

########## PT ########## 

pt_model = DebertaV2Model.from_pretrained("microsoft/deberta-v2-xlarge")

input_ids_pt = torch.from_numpy(input_ids)
attention_mask_pt = torch.from_numpy(attention_mask)
pt_output = pt_model(input_ids_pt, attention_mask=attention_mask_pt)[0]

# compare the actual values for a slice.
pt_expected_slice = torch.tensor(
    [[[0.2356, 0.1948, 0.0369], [-0.1063, 0.3586, -0.5152], [-0.6399, -0.0259, -0.2525]]]
)
pt_output_slice = pt_output[:, 1:4, 1:4]

pt_slice_diff = np.abs(pt_expected_slice.detach().to("cpu").numpy() - pt_output_slice.detach().to("cpu").numpy())
max_pt_slice_diff = np.amax(pt_slice_diff)

print(f"max_pt_slice_diff = {max_pt_slice_diff}")

########## TF ##########

tf_model = TFDebertaV2Model.from_pretrained("microsoft/deberta-v2-xlarge")

input_ids_tf = tf.constant(input_ids)
attention_mask_tf = tf.constant(attention_mask)
tf_output = tf_model(input_ids_tf, attention_mask=attention_mask_tf)[0]

# compare the actual values for a slice.
tf_expected_slice = tf.constant(
    [[[0.2356, 0.1948, 0.0369], [-0.1063, 0.3586, -0.5152], [-0.6399, -0.0259, -0.2525]]]
)
tf_output_slice = tf_output[:, 1:4, 1:4]

tf_slice_diff = tf_expected_slice.numpy() - tf_output_slice.numpy()
max_tf_slice_diff = np.amax(tf_slice_diff)

print(f"max_tf_slice_diff = {max_tf_slice_diff}")

########## PT-TF ########## 

max_pt_tf_diff = np.amax(np.abs(pt_output.detach().to("cpu").numpy() - tf_output.numpy()))
print(f"maximal pt_tf_diff = {max_pt_tf_diff}")

This scripts gives

Before this PR

max_pt_slice_diff = 5.037523806095123e-05
max_tf_slice_diff = 0.5608187317848206
maximal pt_tf_diff = 5.981985092163086

With this PR:

max_pt_slice_diff = 5.037523806095123e-05
max_tf_slice_diff = 4.8374757170677185e-05
maximal pt_tf_diff = 0.000133514404296875

out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out)
out = self.dropout(out, training=training)
hidden_states = self.conv_act(out)
out = self.conv_act(out)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main place being fixed

input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), tf.float32)

output_states = output * mask
output_states = output * input_mask
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think this is the correct logic.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for fixing, LGTM!

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Mar 9, 2022

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔥

@ydshieh ydshieh merged commit 2f463ef into huggingface:master Mar 10, 2022
@ydshieh ydshieh deleted the fix_tf_deberta_v2_conv_layer branch March 10, 2022 11:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants