Skip to content

Commit c703ea5

Browse files
authored
[HLSL][DirectX][SPIRV] Implement the fma API (#185304)
This PR adds `fma` HLSL intrinsic (with support for matrices) It follows all of the steps from #99117. Closes #99117.
1 parent 3d5a255 commit c703ea5

File tree

10 files changed

+430
-6
lines changed

10 files changed

+430
-6
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13285,6 +13285,9 @@ def err_builtin_invalid_arg_type: Error<
1328513285
"%plural{0:|: }3"
1328613286
"%plural{[0,3]:type|:types}1 (was %4)">;
1328713287

13288+
def err_builtin_requires_double_type: Error<
13289+
"%ordinal0 argument must be a scalar, vector, or matrix of double type (was %1)">;
13290+
1328813291
def err_bswapg_invalid_bit_width : Error<
1328913292
"_BitInt type %0 (%1 bits) must be a multiple of 16 bits for byte swapping">;
1329013293

clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,60 @@ float3 floor(float3);
12351235
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_floor)
12361236
float4 floor(float4);
12371237

1238+
//===----------------------------------------------------------------------===//
1239+
// fused multiply-add builtins
1240+
//===----------------------------------------------------------------------===//
1241+
1242+
/// \fn double fma(double a, double b, double c)
1243+
/// \brief Returns the double-precision fused multiply-addition of a * b + c.
1244+
/// \param a The first value in the fused multiply-addition.
1245+
/// \param b The second value in the fused multiply-addition.
1246+
/// \param c The third value in the fused multiply-addition.
1247+
1248+
// double scalars and vectors
1249+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1250+
double fma(double, double, double);
1251+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1252+
double2 fma(double2, double2, double2);
1253+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1254+
double3 fma(double3, double3, double3);
1255+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1256+
double4 fma(double4, double4, double4);
1257+
1258+
// double matrices
1259+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1260+
double1x1 fma(double1x1, double1x1, double1x1);
1261+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1262+
double1x2 fma(double1x2, double1x2, double1x2);
1263+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1264+
double1x3 fma(double1x3, double1x3, double1x3);
1265+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1266+
double1x4 fma(double1x4, double1x4, double1x4);
1267+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1268+
double2x1 fma(double2x1, double2x1, double2x1);
1269+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1270+
double2x2 fma(double2x2, double2x2, double2x2);
1271+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1272+
double2x3 fma(double2x3, double2x3, double2x3);
1273+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1274+
double2x4 fma(double2x4, double2x4, double2x4);
1275+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1276+
double3x1 fma(double3x1, double3x1, double3x1);
1277+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1278+
double3x2 fma(double3x2, double3x2, double3x2);
1279+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1280+
double3x3 fma(double3x3, double3x3, double3x3);
1281+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1282+
double3x4 fma(double3x4, double3x4, double3x4);
1283+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1284+
double4x1 fma(double4x1, double4x1, double4x1);
1285+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1286+
double4x2 fma(double4x2, double4x2, double4x2);
1287+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1288+
double4x3 fma(double4x3, double4x3, double4x3);
1289+
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
1290+
double4x4 fma(double4x4, double4x4, double4x4);
1291+
12381292
//===----------------------------------------------------------------------===//
12391293
// frac builtins
12401294
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaChecking.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2178,9 +2178,10 @@ static bool
21782178
checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
21792179
Sema::EltwiseBuiltinArgTyRestriction ArgTyRestr,
21802180
int ArgOrdinal) {
2181-
QualType EltTy = ArgTy;
2182-
if (auto *VecTy = EltTy->getAs<VectorType>())
2183-
EltTy = VecTy->getElementType();
2181+
clang::QualType EltTy =
2182+
ArgTy->isVectorType() ? ArgTy->getAs<VectorType>()->getElementType()
2183+
: ArgTy->isMatrixType() ? ArgTy->getAs<MatrixType>()->getElementType()
2184+
: ArgTy;
21842185

21852186
switch (ArgTyRestr) {
21862187
case Sema::EltwiseBuiltinArgTyRestriction::None:
@@ -2192,6 +2193,7 @@ checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
21922193
break;
21932194
case Sema::EltwiseBuiltinArgTyRestriction::FloatTy:
21942195
if (!EltTy->isRealFloatingType()) {
2196+
// FIXME: make diagnostic's wording correct for matrices
21952197
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
21962198
<< ArgOrdinal << /* scalar or vector */ 5 << /* no int */ 0
21972199
<< /* floating-point */ 1 << ArgTy;

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3149,6 +3149,25 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
31493149
return false;
31503150
}
31513151

3152+
static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc,
3153+
int ArgOrdinal,
3154+
clang::QualType PassedType) {
3155+
clang::QualType BaseType =
3156+
PassedType->isVectorType()
3157+
? PassedType->castAs<clang::VectorType>()->getElementType()
3158+
: PassedType->isMatrixType()
3159+
? PassedType->castAs<clang::MatrixType>()->getElementType()
3160+
: PassedType;
3161+
if (!BaseType->isDoubleType()) {
3162+
// FIXME: adopt standard `err_builtin_invalid_arg_type` instead of using
3163+
// this custom error.
3164+
return S->Diag(Loc, diag::err_builtin_requires_double_type)
3165+
<< ArgOrdinal << PassedType;
3166+
}
3167+
3168+
return false;
3169+
}
3170+
31523171
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
31533172
unsigned ArgIndex) {
31543173
auto *Arg = TheCall->getArg(ArgIndex);
@@ -4120,6 +4139,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
41204139
TheCall->setType(ArgTyA);
41214140
break;
41224141
}
4142+
case Builtin::BI__builtin_elementwise_fma: {
4143+
if (SemaRef.checkArgCount(TheCall, 3) ||
4144+
CheckAllArgsHaveSameType(&SemaRef, TheCall)) {
4145+
return true;
4146+
}
4147+
4148+
if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
4149+
CheckAnyDoubleRepresentation))
4150+
return true;
4151+
4152+
ExprResult A = TheCall->getArg(0);
4153+
QualType ArgTyA = A.get()->getType();
4154+
// return type is the same as input type
4155+
TheCall->setType(ArgTyA);
4156+
break;
4157+
}
41234158
case Builtin::BI__builtin_hlsl_transpose: {
41244159
if (SemaRef.checkArgCount(TheCall, 1))
41254160
return true;
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
2+
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm \
3+
// RUN: -disable-llvm-passes -o - | FileCheck %s
4+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
5+
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm \
6+
// RUN: -disable-llvm-passes -o - | FileCheck %s
7+
8+
// CHECK-LABEL: define {{.*}} double @{{.*}}fma_double{{.*}}(
9+
// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.fma.f64(double
10+
// CHECK: ret double
11+
double fma_double(double a, double b, double c) { return fma(a, b, c); }
12+
13+
// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2{{.*}}(
14+
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double>
15+
// CHECK: ret <2 x double>
16+
double2 fma_double2(double2 a, double2 b, double2 c) { return fma(a, b, c); }
17+
18+
// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3{{.*}}(
19+
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double>
20+
// CHECK: ret <3 x double>
21+
double3 fma_double3(double3 a, double3 b, double3 c) { return fma(a, b, c); }
22+
23+
// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4{{.*}}(
24+
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
25+
// CHECK: ret <4 x double>
26+
double4 fma_double4(double4 a, double4 b, double4 c) { return fma(a, b, c); }
27+
28+
// CHECK-LABEL: define {{.*}} <1 x double> @{{.*}}fma_double1x1{{.*}}(
29+
// CHECK: call reassoc nnan ninf nsz arcp afn <1 x double> @llvm.fma.v1f64(<1 x double>
30+
// CHECK: ret <1 x double>
31+
double1x1 fma_double1x1(double1x1 a, double1x1 b, double1x1 c) {
32+
return fma(a, b, c);
33+
}
34+
35+
// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double1x2{{.*}}(
36+
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double>
37+
// CHECK: ret <2 x double>
38+
double1x2 fma_double1x2(double1x2 a, double1x2 b, double1x2 c) {
39+
return fma(a, b, c);
40+
}
41+
42+
// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double1x3{{.*}}(
43+
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double>
44+
// CHECK: ret <3 x double>
45+
double1x3 fma_double1x3(double1x3 a, double1x3 b, double1x3 c) {
46+
return fma(a, b, c);
47+
}
48+
49+
// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double1x4{{.*}}(
50+
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
51+
// CHECK: ret <4 x double>
52+
double1x4 fma_double1x4(double1x4 a, double1x4 b, double1x4 c) {
53+
return fma(a, b, c);
54+
}
55+
56+
// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2x1{{.*}}(
57+
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double>
58+
// CHECK: ret <2 x double>
59+
double2x1 fma_double2x1(double2x1 a, double2x1 b, double2x1 c) {
60+
return fma(a, b, c);
61+
}
62+
63+
// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double2x2{{.*}}(
64+
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
65+
// CHECK: ret <4 x double>
66+
double2x2 fma_double2x2(double2x2 a, double2x2 b, double2x2 c) {
67+
return fma(a, b, c);
68+
}
69+
70+
// CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}fma_double2x3{{.*}}(
71+
// CHECK: call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.fma.v6f64(<6 x double>
72+
// CHECK: ret <6 x double>
73+
double2x3 fma_double2x3(double2x3 a, double2x3 b, double2x3 c) {
74+
return fma(a, b, c);
75+
}
76+
77+
// CHECK-LABEL: define {{.*}} <8 x double> @{{.*}}fma_double2x4{{.*}}(
78+
// CHECK: call reassoc nnan ninf nsz arcp afn <8 x double> @llvm.fma.v8f64(<8 x double>
79+
// CHECK: ret <8 x double>
80+
double2x4 fma_double2x4(double2x4 a, double2x4 b, double2x4 c) {
81+
return fma(a, b, c);
82+
}
83+
84+
// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3x1{{.*}}(
85+
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double>
86+
// CHECK: ret <3 x double>
87+
double3x1 fma_double3x1(double3x1 a, double3x1 b, double3x1 c) {
88+
return fma(a, b, c);
89+
}
90+
91+
// CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}fma_double3x2{{.*}}(
92+
// CHECK: call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.fma.v6f64(<6 x double>
93+
// CHECK: ret <6 x double>
94+
double3x2 fma_double3x2(double3x2 a, double3x2 b, double3x2 c) {
95+
return fma(a, b, c);
96+
}
97+
98+
// CHECK-LABEL: define {{.*}} <9 x double> @{{.*}}fma_double3x3{{.*}}(
99+
// CHECK: call reassoc nnan ninf nsz arcp afn <9 x double> @llvm.fma.v9f64(<9 x double>
100+
// CHECK: ret <9 x double>
101+
double3x3 fma_double3x3(double3x3 a, double3x3 b, double3x3 c) {
102+
return fma(a, b, c);
103+
}
104+
105+
// CHECK-LABEL: define {{.*}} <12 x double> @{{.*}}fma_double3x4{{.*}}(
106+
// CHECK: call reassoc nnan ninf nsz arcp afn <12 x double> @llvm.fma.v12f64(<12 x double>
107+
// CHECK: ret <12 x double>
108+
double3x4 fma_double3x4(double3x4 a, double3x4 b, double3x4 c) {
109+
return fma(a, b, c);
110+
}
111+
112+
// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4x1{{.*}}(
113+
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
114+
// CHECK: ret <4 x double>
115+
double4x1 fma_double4x1(double4x1 a, double4x1 b, double4x1 c) {
116+
return fma(a, b, c);
117+
}
118+
119+
// CHECK-LABEL: define {{.*}} <8 x double> @{{.*}}fma_double4x2{{.*}}(
120+
// CHECK: call reassoc nnan ninf nsz arcp afn <8 x double> @llvm.fma.v8f64(<8 x double>
121+
// CHECK: ret <8 x double>
122+
double4x2 fma_double4x2(double4x2 a, double4x2 b, double4x2 c) {
123+
return fma(a, b, c);
124+
}
125+
126+
// CHECK-LABEL: define {{.*}} <12 x double> @{{.*}}fma_double4x3{{.*}}(
127+
// CHECK: call reassoc nnan ninf nsz arcp afn <12 x double> @llvm.fma.v12f64(<12 x double>
128+
// CHECK: ret <12 x double>
129+
double4x3 fma_double4x3(double4x3 a, double4x3 b, double4x3 c) {
130+
return fma(a, b, c);
131+
}
132+
133+
// CHECK-LABEL: define {{.*}} <16 x double> @{{.*}}fma_double4x4{{.*}}(
134+
// CHECK: call reassoc nnan ninf nsz arcp afn <16 x double> @llvm.fma.v16f64(<16 x double>
135+
// CHECK: ret <16 x double>
136+
double4x4 fma_double4x4(double4x4 a, double4x4 b, double4x4 c) {
137+
return fma(a, b, c);
138+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \
2+
// RUN: -triple dxil-pc-shadermodel6.6-library %s \
3+
// RUN: -emit-llvm-only -disable-llvm-passes -verify \
4+
// RUN: -verify-ignore-unexpected=note
5+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \
6+
// RUN: -triple spirv-unknown-vulkan-compute %s \
7+
// RUN: -emit-llvm-only -disable-llvm-passes -verify \
8+
// RUN: -verify-ignore-unexpected=note
9+
10+
float bad_float(float a, float b, float c) {
11+
return fma(a, b, c);
12+
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float')}}
13+
}
14+
15+
float2 bad_float2(float2 a, float2 b, float2 c) {
16+
return fma(a, b, c);
17+
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2' (aka 'vector<float, 2>'))}}
18+
}
19+
20+
float2x2 bad_float2x2(float2x2 a, float2x2 b, float2x2 c) {
21+
return fma(a, b, c);
22+
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}}
23+
}
24+
25+
half bad_half(half a, half b, half c) {
26+
return fma(a, b, c);
27+
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half')}}
28+
}
29+
30+
half2 bad_half2(half2 a, half2 b, half2 c) {
31+
return fma(a, b, c);
32+
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2' (aka 'vector<half, 2>'))}}
33+
}
34+
35+
half2x2 bad_half2x2(half2x2 a, half2x2 b, half2x2 c) {
36+
return fma(a, b, c);
37+
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2x2' (aka 'matrix<half, 2, 2>'))}}
38+
}
39+
40+
double mixed_bad_second(double a, float b, double c) {
41+
return fma(a, b, c);
42+
// expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
43+
}
44+
45+
double mixed_bad_third(double a, double b, half c) {
46+
return fma(a, b, c);
47+
// expected-error@-1 {{arguments are of different types ('double' vs 'half')}}
48+
}
49+
50+
double2 mixed_bad_second_vec(double2 a, float2 b, double2 c) {
51+
return fma(a, b, c);
52+
// expected-error@-1 {{arguments are of different types ('vector<double, [...]>' vs 'vector<float, [...]>')}}
53+
}
54+
55+
double2 mixed_bad_third_vec(double2 a, double2 b, float2 c) {
56+
return fma(a, b, c);
57+
// expected-error@-1 {{arguments are of different types ('vector<double, [...]>' vs 'vector<float, [...]>')}}
58+
}
59+
60+
double2x2 mixed_bad_second_mat(double2x2 a, float2x2 b, double2x2 c) {
61+
return fma(a, b, c);
62+
// expected-error@-1 {{arguments are of different types ('matrix<double, [2 * ...]>' vs 'matrix<float, [2 * ...]>')}}
63+
}
64+
65+
double2x2 mixed_bad_third_mat(double2x2 a, double2x2 b, half2x2 c) {
66+
return fma(a, b, c);
67+
// expected-error@-1 {{arguments are of different types ('matrix<double, [2 * ...]>' vs 'matrix<half, [2 * ...]>')}}
68+
}
69+
70+
double shape_mismatch_second(double a, double2 b, double c) {
71+
return fma(a, b, c);
72+
// expected-error@-1 {{call to 'fma' is ambiguous}}
73+
}
74+
75+
double2 shape_mismatch_third(double2 a, double2 b, double c) {
76+
return fma(a, b, c);
77+
// expected-error@-1 {{call to 'fma' is ambiguous}}
78+
}
79+
80+
double2x2 shape_mismatch_scalar_mat(double2x2 a, double b, double2x2 c) {
81+
return fma(a, b, c);
82+
// expected-error@-1 {{call to 'fma' is ambiguous}}
83+
}
84+
85+
double2x2 shape_mismatch_vec_mat(double2x2 a, double2 b, double2x2 c) {
86+
return fma(a, b, c);
87+
// expected-error@-1 {{arguments are of different types ('double2x2' (aka 'matrix<double, 2, 2>') vs 'double2' (aka 'vector<double, 2>'))}}
88+
}
89+
90+
int bad_int(int a, int b, int c) {
91+
return fma(a, b, c);
92+
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'int')}}
93+
}
94+
95+
int2 bad_int2(int2 a, int2 b, int2 c) {
96+
return fma(a, b, c);
97+
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'int2' (aka 'vector<int, 2>'))}}
98+
}
99+
100+
bool bad_bool(bool a, bool b, bool c) {
101+
return fma(a, b, c);
102+
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool')}}
103+
}
104+
105+
bool2 bad_bool2(bool2 a, bool2 b, bool2 c) {
106+
return fma(a, b, c);
107+
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool2' (aka 'vector<bool, 2>'))}}
108+
}
109+
110+
bool2x2 bad_bool2x2(bool2x2 a, bool2x2 b, bool2x2 c) {
111+
return fma(a, b, c);
112+
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool2x2' (aka 'matrix<bool, 2, 2>'))}}
113+
}

0 commit comments

Comments
 (0)