Commit e2b5a46: Add mish activation

1 parent 272e20f
3 files changed (+28, -2 lines)

lib/axon.ex

Lines changed: 1 addition & 1 deletion
@@ -615,7 +615,7 @@ defmodule Axon do
   end
 
   @activation_layers [:celu, :elu, :exp, :gelu, :hard_sigmoid, :hard_silu, :hard_tanh] ++
-                       [:leaky_relu, :linear, :log_sigmoid, :relu, :relu6] ++
+                       [:leaky_relu, :linear, :log_sigmoid, :mish, :relu, :relu6] ++
                        [:sigmoid, :silu, :selu, :softmax, :softplus, :softsign, :tanh]
 
   @doc """
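
Because `Axon` defines a layer constructor for each atom in `@activation_layers`, adding `:mish` here should also expose it as a layer, e.g. via `Axon.mish/2` or `Axon.activation(model, :mish)`. A minimal usage sketch (the shapes and layer sizes are illustrative, and the shape-based `Axon.input/1` call assumes the API as of this commit):

    model =
      Axon.input({nil, 784})
      |> Axon.dense(128)
      |> Axon.mish()
      |> Axon.dense(10, activation: :softmax)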

lib/axon/activations.ex

Lines changed: 26 additions & 0 deletions
@@ -377,6 +377,32 @@ defmodule Axon.Activations do
   """
   defn log_sigmoid(x), do: -softplus(-x)
 
+  @doc ~S"""
+  Mish activation.
+
+  $$f(x_i) = x_i \cdot \tanh(\log(1 + e^{x_i}))$$
+
+  ## Examples
+
+      iex> Axon.Activations.mish(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], type: {:f, 32}, names: [:data]))
+      #Nx.Tensor<
+        f32[data: 7]
+        [-0.14564745128154755, -0.2525014877319336, -0.30340147018432617, 0.0, 0.8650984168052673, 1.9439589977264404, 2.98653507232666]
+      >
+
+      iex> Axon.Activations.mish(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
+      #Nx.Tensor<
+        bf16[batch: 2][data: 3]
+        [
+          [-0.30078125, -0.25, -0.1435546875],
+          [0.86328125, 1.9375, 2.96875]
+        ]
+      >
+  """
+  defn mish(x) do
+    x * tanh(softplus(x))
+  end
+
   @doc ~S"""
   Rectified linear unit activation.
 
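
The body lines up with the docstring formula because `softplus(x) = log(1 + e^x)`, the same helper that `log_sigmoid/1` above is built on, so `x * tanh(softplus(x))` is exactly the stated formula. A quick sanity check with plain `Nx` calls (a sketch; the printed values are rounded):

    x = Nx.tensor([-1.0, 0.0, 1.0])
    Nx.multiply(x, Nx.tanh(Nx.log(Nx.add(Nx.exp(x), 1))))
    # => approximately [-0.3034, 0.0, 0.8651], matching the f32 doctest above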

lib/axon/compiler.ex

Lines changed: 1 addition & 1 deletion
@@ -217,7 +217,7 @@ defmodule Axon.Compiler do
   ## Activation Layers
 
   @activation_layers [:celu, :elu, :exp, :gelu, :hard_sigmoid, :hard_silu, :hard_tanh] ++
-                       [:leaky_relu, :linear, :log_sigmoid, :relu, :relu6] ++
+                       [:leaky_relu, :linear, :log_sigmoid, :mish, :relu, :relu6] ++
                        [:sigmoid, :silu, :selu, :softmax, :softplus, :softsign, :tanh]
 
   defp recur_predict_fun(%Axon{op: op, parent: parent}, cache, param_map, input_map)
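
On the compiler side, this list gates dispatch: `recur_predict_fun/4` matches any node whose `op` is one of these atoms and applies the corresponding `Axon.Activations` function to the parent's output, which is why adding `:mish` to the list is the whole change. A hedged sketch of such a clause, reconstructed from the visible function head rather than taken from this commit:

    defp recur_predict_fun(%Axon{op: op, parent: parent}, cache, param_map, input_map)
         when op in @activation_layers do
      {parent_fun, cache} = recur_predict_fun(parent, cache, param_map, input_map)

      fun = fn params, inputs ->
        # Dispatch to the activation of the same name, e.g. Axon.Activations.mish/1
        apply(Axon.Activations, op, [parent_fun.(params, inputs)])
      end

      {fun, cache}
    end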
