Skip to content

Commit b04c4e0

Browse files
Flush denormals to +/- 0 when converting float to bfloat16.
PiperOrigin-RevId: 301948798 Change-Id: Ic24b699b2e23683d3710d7abb4317833df252af0
1 parent c75cfa9 commit b04c4e0

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

tensorflow/core/framework/bfloat16_test.cc

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,35 @@ limitations under the License.
2323
namespace tensorflow {
2424
namespace {
2525

26+
// Checks how positive and negative zero are represented in bfloat16:
// the two zeros compare equal through operator==, yet their raw `value`
// fields differ (0x0000 for +0.0f, 0x8000 for -0.0f).
TEST(Bfloat16Test, ZeroRepresentations) {
27+
ASSERT_EQ(bfloat16{0.0f}, bfloat16{0.0f});
28+
// -0 and +0 must compare equal even though their stored bits differ.
ASSERT_EQ(bfloat16{-0.0f}, bfloat16{0.0f});
29+
ASSERT_EQ(bfloat16{-0.0f}, bfloat16{-0.0f});
30+
// Raw storage: only the top (sign) bit distinguishes the two zeros.
ASSERT_EQ(bfloat16{0.0f}.value, 0x0000);
31+
ASSERT_EQ(bfloat16{-0.0f}.value, 0x8000);
32+
}
33+
34+
// Checks that both truncate_to_bfloat16 and round_to_bfloat16 turn the
// inputs visited by the loop into a zero whose sign matches the input:
// raw value 0x8000 when std::signbit(input) is set, 0x0000 otherwise.
//
// NOTE(review): the loop starts at -denorm_min() and stops before
// +denorm_min(), stepping one representable float at a time toward 1.0f.
// That half-open range contains only -denorm_min() and zero, so larger
// subnormals are not exercised here. If the intent is to sweep the whole
// subnormal range, the bounds would presumably be
// +/- std::numeric_limits<float>::min() -- confirm with the author.
TEST(Bfloat16Test, FlushDenormalsToZero) {
35+
for (float denorm = -std::numeric_limits<float>::denorm_min();
36+
denorm < std::numeric_limits<float>::denorm_min();
37+
denorm = std::nextafterf(denorm, 1.0f)) {
38+
// Truncating conversion: result must read back as zero ...
bfloat16 bf_trunc = bfloat16::truncate_to_bfloat16(denorm);
39+
ASSERT_EQ(float{bf_trunc}, 0.0f);
40+
// ... and must keep the sign of the input in the raw bits.
if (std::signbit(denorm)) {
41+
ASSERT_EQ(bf_trunc.value, 0x8000) << denorm;
42+
} else {
43+
ASSERT_EQ(bf_trunc.value, 0x0000) << denorm;
44+
}
45+
// Rounding conversion: same expectations as the truncating one.
bfloat16 bf_round = bfloat16::round_to_bfloat16(denorm);
46+
ASSERT_EQ(float{bf_round}, 0.0f);
47+
if (std::signbit(denorm)) {
48+
ASSERT_EQ(bf_round.value, 0x8000) << denorm;
49+
} else {
50+
ASSERT_EQ(bf_round.value, 0x0000) << denorm;
51+
}
52+
}
53+
}
54+
2655
// Checks that a default-constructed bfloat16 converts to the float 0.0f.
TEST(Bfloat16Test, DefaultValueIsZero) {
2756
EXPECT_EQ(0.0f, static_cast<float>(bfloat16()));
2857
}
@@ -65,6 +94,7 @@ TEST_P(Bfloat16Test, TruncateTest) {
6594
EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated)));
6695
return;
6796
}
97+
6898
EXPECT_EQ(GetParam().expected_truncation, float(truncated));
6999

70100
bfloat16 rounded = bfloat16::round_to_bfloat16((GetParam().input));
@@ -114,14 +144,16 @@ INSTANTIATE_TEST_SUITE_P(
114144
BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
115145
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
116146
BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
147+
// The following two floats are denormals and will be flushed
148+
// to zero.
117149
Bfloat16TestParam{
118150
BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000),
119-
BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000),
120-
BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)},
151+
BinaryToFloat(0, 0b00000000, 0b0000000, 0b0000000000000000),
152+
BinaryToFloat(0, 0b00000000, 0b0000000, 0b0000000000000000)},
121153
Bfloat16TestParam{
122154
BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000),
123-
BinaryToFloat(0, 0b00000000, 0b1111111, 0b0000000000000000),
124-
BinaryToFloat(0, 0b00000001, 0b0000000, 0b0000000000000000)}));
155+
BinaryToFloat(0, 0b00000000, 0b0000000, 0b0000000000000000),
156+
BinaryToFloat(0, 0b00000000, 0b0000000, 0b0000000000000000)}));
125157

126158
TEST(Bfloat16Test, Conversion) {
127159
float a[100];

tensorflow/core/lib/bfloat16/bfloat16.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <cmath>
2020
#include <complex>
2121
#include <iostream>
22+
#include <limits>
2223

2324
#include "tensorflow/core/platform/byte_order.h"
2425

@@ -53,6 +54,10 @@ struct bfloat16 {
5354
if (float_isnan(v)) {
5455
output.value = NAN_VALUE;
5556
return output;
57+
} else if (std::fabs(v) < std::numeric_limits<float>::min()) {
58+
// Flush denormal to +/- 0.
59+
output.value = std::signbit(v) ? 0x8000 : 0;
60+
return output;
5661
}
5762
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
5863
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
@@ -196,6 +201,9 @@ struct bfloat16 {
196201
// qNaN magic: All exponent bits set + most significant bit of fraction
197202
// set.
198203
output.value = 0x7fc0;
204+
} else if (std::fabs(v) < std::numeric_limits<float>::min()) {
205+
// Flush denormal to +/- 0.0
206+
output.value = std::signbit(v) ? 0x8000 : 0;
199207
} else {
200208
// Fast rounding algorithm that rounds a half value to nearest even. This
201209
// reduces expected error when we convert a large number of floats. Here

0 commit comments

Comments
 (0)