Skip to content

Commit 3dd9579

Browse files
sayantn authored and Amanieu committed
Added a bf16 type
1 parent fe8f300 commit 3dd9579

File tree

5 files changed

+52
-21
lines changed

5 files changed

+52
-21
lines changed

crates/core_arch/src/x86/avx512bf16.rs

+9-8
Original file line number | Diff line number | Diff line change
@@ -486,9 +486,9 @@ pub unsafe fn _mm_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m128 {
486486
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtsbh_ss)
487487
#[inline]
488488
#[target_feature(enable = "avx512bf16,avx512f")]
489-
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
490-
pub unsafe fn _mm_cvtsbh_ss(a: u16) -> f32 {
491-
f32::from_bits((a as u32) << 16)
489+
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
490+
pub unsafe fn _mm_cvtsbh_ss(a: bf16) -> f32 {
491+
f32::from_bits((a.to_bits() as u32) << 16)
492492
}
493493

494494
/// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
@@ -558,9 +558,10 @@ pub unsafe fn _mm_maskz_cvtneps_pbh(k: __mmask8, a: __m128) -> __m128bh {
558558
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
559559
#[inline]
560560
#[target_feature(enable = "avx512bf16,avx512vl")]
561-
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
562-
pub unsafe fn _mm_cvtness_sbh(a: f32) -> u16 {
563-
simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0)
561+
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
562+
pub unsafe fn _mm_cvtness_sbh(a: f32) -> bf16 {
563+
let value: u16 = simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0);
564+
bf16::from_bits(value)
564565
}
565566

566567
#[cfg(test)]
@@ -1910,7 +1911,7 @@ mod tests {
19101911

19111912
#[simd_test(enable = "avx512bf16")]
19121913
unsafe fn test_mm_cvtsbh_ss() {
1913-
let r = _mm_cvtsbh_ss(BF16_ONE);
1914+
let r = _mm_cvtsbh_ss(bf16::from_bits(BF16_ONE));
19141915
assert_eq!(r, 1.);
19151916
}
19161917

@@ -1944,6 +1945,6 @@ mod tests {
19441945
#[simd_test(enable = "avx512bf16,avx512vl")]
19451946
unsafe fn test_mm_cvtness_sbh() {
19461947
let r = _mm_cvtness_sbh(1.);
1947-
assert_eq!(r, BF16_ONE);
1948+
assert_eq!(r.to_bits(), BF16_ONE);
19481949
}
19491950
}

crates/core_arch/src/x86/avxneconvert.rs

+11-11
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
use crate::arch::asm;
2-
use crate::core_arch::{simd::*, x86::*};
2+
use crate::core_arch::x86::*;
33

44
#[cfg(test)]
55
use stdarch_test::assert_instr;
@@ -15,9 +15,9 @@ use stdarch_test::assert_instr;
1515
all(test, any(target_os = "linux", target_env = "msvc")),
1616
assert_instr(vbcstnebf162ps)
1717
)]
18-
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
19-
pub unsafe fn _mm_bcstnebf16_ps(a: *const u16) -> __m128 {
20-
transmute(bcstnebf162ps_128(a))
18+
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
19+
pub unsafe fn _mm_bcstnebf16_ps(a: *const bf16) -> __m128 {
20+
bcstnebf162ps_128(a)
2121
}
2222

2323
/// Convert scalar BF16 (16-bit) floating point element stored at memory locations starting at location
@@ -31,9 +31,9 @@ pub unsafe fn _mm_bcstnebf16_ps(a: *const u16) -> __m128 {
3131
all(test, any(target_os = "linux", target_env = "msvc")),
3232
assert_instr(vbcstnebf162ps)
3333
)]
34-
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
35-
pub unsafe fn _mm256_bcstnebf16_ps(a: *const u16) -> __m256 {
36-
transmute(bcstnebf162ps_256(a))
34+
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
35+
pub unsafe fn _mm256_bcstnebf16_ps(a: *const bf16) -> __m256 {
36+
bcstnebf162ps_256(a)
3737
}
3838

3939
/// Convert packed BF16 (16-bit) floating-point even-indexed elements stored at memory locations starting at
@@ -143,9 +143,9 @@ pub unsafe fn _mm256_cvtneps_avx_pbh(a: __m256) -> __m128bh {
143143
#[allow(improper_ctypes)]
144144
extern "C" {
145145
#[link_name = "llvm.x86.vbcstnebf162ps128"]
146-
fn bcstnebf162ps_128(a: *const u16) -> f32x4;
146+
fn bcstnebf162ps_128(a: *const bf16) -> __m128;
147147
#[link_name = "llvm.x86.vbcstnebf162ps256"]
148-
fn bcstnebf162ps_256(a: *const u16) -> f32x8;
148+
fn bcstnebf162ps_256(a: *const bf16) -> __m256;
149149

150150
#[link_name = "llvm.x86.vcvtneebf162ps128"]
151151
fn cvtneebf162ps_128(a: *const __m128bh) -> __m128;
@@ -177,15 +177,15 @@ mod tests {
177177

178178
#[simd_test(enable = "avxneconvert")]
179179
unsafe fn test_mm_bcstnebf16_ps() {
180-
let a = BF16_ONE;
180+
let a = bf16::from_bits(BF16_ONE);
181181
let r = _mm_bcstnebf16_ps(addr_of!(a));
182182
let e = _mm_set_ps(1., 1., 1., 1.);
183183
assert_eq_m128(r, e);
184184
}
185185

186186
#[simd_test(enable = "avxneconvert")]
187187
unsafe fn test_mm256_bcstnebf16_ps() {
188-
let a = BF16_ONE;
188+
let a = bf16::from_bits(BF16_ONE);
189189
let r = _mm256_bcstnebf16_ps(addr_of!(a));
190190
let e = _mm256_set_ps(1., 1., 1., 1., 1., 1., 1., 1.);
191191
assert_eq_m256(r, e);

crates/core_arch/src/x86/mod.rs

+25
Original file line number | Diff line number | Diff line change
@@ -337,6 +337,31 @@ types! {
337337
);
338338
}
339339

340+
/// The BFloat16 type used in AVX-512 intrinsics.
341+
#[repr(transparent)]
342+
#[derive(Copy, Clone, Debug)]
343+
#[allow(non_camel_case_types)]
344+
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
345+
pub struct bf16(u16);
346+
347+
impl bf16 {
348+
/// Raw transmutation from `u16`
349+
#[inline]
350+
#[must_use]
351+
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
352+
pub const fn from_bits(bits: u16) -> bf16 {
353+
bf16(bits)
354+
}
355+
356+
/// Raw transmutation to `u16`
357+
#[inline]
358+
#[must_use = "this returns the result of the operation, without modifying the original"]
359+
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
360+
pub const fn to_bits(self) -> u16 {
361+
self.0
362+
}
363+
}
364+
340365
/// The `__mmask64` type used in AVX-512 intrinsics, a 64-bit integer
341366
#[allow(non_camel_case_types)]
342367
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]

crates/stdarch-verify/src/lib.rs

+1
Original file line number | Diff line number | Diff line change
@@ -197,6 +197,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream {
197197
"_MM_MANTISSA_SIGN_ENUM" => quote! { &MM_MANTISSA_SIGN_ENUM },
198198
"_MM_PERM_ENUM" => quote! { &MM_PERM_ENUM },
199199
"bool" => quote! { &BOOL },
200+
"bf16" => quote! { &BF16 },
200201
"f32" => quote! { &F32 },
201202
"f64" => quote! { &F64 },
202203
"i16" => quote! { &I16 },

crates/stdarch-verify/tests/x86-intel.rs

+6-2
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ struct Function {
2222
has_test: bool,
2323
}
2424

25+
static BF16: Type = Type::BFloat16;
2526
static F32: Type = Type::PrimFloat(32);
2627
static F64: Type = Type::PrimFloat(64);
2728
static I8: Type = Type::PrimSigned(8);
@@ -65,6 +66,7 @@ enum Type {
6566
PrimFloat(u8),
6667
PrimSigned(u8),
6768
PrimUnsigned(u8),
69+
BFloat16,
6870
MutPtr(&'static Type),
6971
ConstPtr(&'static Type),
7072
M128,
@@ -699,7 +701,8 @@ fn equate(
699701
(&Type::PrimSigned(32), "__int32" | "const int" | "int") => {}
700702
(&Type::PrimSigned(64), "__int64" | "long long") => {}
701703
(&Type::PrimUnsigned(8), "unsigned char") => {}
702-
(&Type::PrimUnsigned(16), "unsigned short" | "__bfloat16") => {}
704+
(&Type::PrimUnsigned(16), "unsigned short") => {}
705+
(&Type::BFloat16, "__bfloat16") => {}
703706
(
704707
&Type::PrimUnsigned(32),
705708
"unsigned __int32" | "unsigned int" | "unsigned long" | "const unsigned int",
@@ -758,9 +761,10 @@ fn equate(
758761
(&Type::ConstPtr(&Type::PrimSigned(8)), "char const*") => {}
759762
(&Type::ConstPtr(&Type::PrimSigned(32)), "__int32 const*" | "int const*") => {}
760763
(&Type::ConstPtr(&Type::PrimSigned(64)), "__int64 const*") => {}
761-
(&Type::ConstPtr(&Type::PrimUnsigned(16)), "unsigned short const*" | "__bf16 const*") => {}
764+
(&Type::ConstPtr(&Type::PrimUnsigned(16)), "unsigned short const*") => {}
762765
(&Type::ConstPtr(&Type::PrimUnsigned(32)), "unsigned int const*") => {}
763766
(&Type::ConstPtr(&Type::PrimUnsigned(64)), "unsigned __int64 const*") => {}
767+
(&Type::ConstPtr(&Type::BFloat16), "__bf16 const*") => {}
764768

765769
(&Type::ConstPtr(&Type::M128), "__m128 const*") => {}
766770
(&Type::ConstPtr(&Type::M128BH), "__m128bh const*") => {}

0 commit comments

Comments
 (0)