Skip to content

Commit eda158f

Browse files
committed
Revive the threshold for Metal FP32 mish
1 parent 0ceb90f commit eda158f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

cpp/neuralnet/metalbackend.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,17 @@ extension MPSGraph {
8080
assert(tensor.dataType == .float32)
8181

8282
let one = 1.0
83+
let threshold = 10.39
84+
let thresholdTensor = constant(threshold, dataType: tensor.dataType)
85+
let minimumTensor = minimum(tensor, thresholdTensor, name: nil)
86+
let expTensor = exponent(with: minimumTensor, name: nil)
8387
let oneTensor = constant(one, dataType: tensor.dataType)
84-
let expTensor = exponent(with: tensor, name: nil)
8588
let addTensor = addition(expTensor, oneTensor, name: nil)
8689
let logTensor = logarithm(with: addTensor, name: nil)
87-
let tanhTensor = tanh(with: logTensor, name: nil)
90+
let lessTensor = lessThan(tensor, thresholdTensor, name: nil)
91+
let selectTensor = select(
92+
predicate: lessTensor, trueTensor: logTensor, falseTensor: tensor, name: nil)
93+
let tanhTensor = tanh(with: selectTensor, name: nil)
8894
let mulTensor = multiplication(tensor, tanhTensor, name: nil)
8995

9096
return mulTensor

0 commit comments

Comments
 (0)