
Commit d794576

Add activations.mish.
1 parent c72f479

File tree: 5 files changed, +37 −1 lines


src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 8 additions & 0 deletions

@@ -14,6 +14,8 @@ You may obtain a copy of the License at
    limitations under the License.
 ******************************************************************************/
 
+using Tensorflow.Operations;
+
 namespace Tensorflow
 {
     public partial class tensorflow
@@ -50,6 +52,12 @@ public Tensor square(Tensor x, string name = null)
         public Tensor sum(Tensor x, Axis? axis = null, string name = null)
             => math_ops.reduce_sum(x, axis: axis, name: name);
 
+        public Tensor softplus(Tensor features, string name = null)
+            => nn_ops.softplus(features, name: name);
+
+        public Tensor tanh(Tensor x, string name = null)
+            => math_ops.tanh(x, name: name);
+
         /// <summary>
         /// Finds values and indices of the `k` largest entries for the last dimension.
         /// </summary>
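Both wrappers simply forward to the existing op implementations, so they compose like any other tf.math call. A minimal usage sketch (the tensor values are illustrative; the usual static Tensorflow.Binding import is assumed):

    using static Tensorflow.Binding;

    var x = tf.constant(new[] { -1.0f, 0.0f, 1.0f });
    // softplus(x) = ln(1 + exp(x)), a smooth approximation of relu
    var sp = tf.math.softplus(x);   // ≈ { 0.3133, 0.6931, 1.3133 }
    var th = tf.math.tanh(sp);      // ≈ { 0.3034, 0.6000, 0.8651 }

Composing exactly these two calls is what the Keras-side Mish activation below relies on.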

src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 10 additions & 0 deletions

@@ -120,6 +120,16 @@ public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Te
             };
         }
 
+        [RegisterGradient("Softplus")]
+        public static Tensor[] _SoftplusGrad(Operation op, Tensor[] grads)
+        {
+            var grad = grads[0];
+            var x = op.inputs[0];
+
+            var softplus = grad * math_ops.sigmoid(x);
+            return new Tensor[] { softplus };
+        }
+
         [RegisterGradient("SquaredDifference")]
         public static Tensor[] _SquaredDifferenceGrad(Operation op, Tensor[] grads)
         {
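The gradient follows directly from the definition: softplus(x) = ln(1 + eˣ), so d/dx softplus(x) = eˣ / (1 + eˣ) = sigmoid(x), and the backward pass is just the incoming gradient scaled by the sigmoid of the op's input. A quick eager-mode check (a sketch, assuming the repo's GradientTape API; the input value is illustrative):

    using static Tensorflow.Binding;

    var x = tf.constant(2.0f);
    using var tape = tf.GradientTape();
    tape.watch(x);
    var y = tf.math.softplus(x);
    var dy = tape.gradient(y, x);
    print(dy);  // expect ≈ 0.8808, i.e. sigmoid(2)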

src/TensorFlowNET.Core/Operations/nn_ops.cs

Lines changed: 3 additions & 0 deletions

@@ -132,6 +132,9 @@ public static Tensor softmax(Tensor logits, int axis = -1, string name = null)
             return _softmax(logits, gen_nn_ops.softmax, axis, name);
         }
 
+        public static Tensor softplus(Tensor features, string name = null)
+            => tf.Context.ExecuteOp("Softplus", name, new ExecuteOpArgs(features));
+
         public static Tensor l2_loss(Tensor t, string name = null)
             => tf.Context.ExecuteOp("L2Loss", name, new ExecuteOpArgs(t));

src/TensorFlowNET.Keras/Activations.cs

Lines changed: 5 additions & 1 deletion

@@ -20,12 +20,14 @@ public class Activations
             => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features));
         private static Activation _tanh = (features, name)
             => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features));
+        private static Activation _mish = (features, name)
+            => features * tf.math.tanh(tf.math.softplus(features));
 
         /// <summary>
         /// Register the name-activation mapping in this static class.
         /// </summary>
         /// <param name="name"></param>
-        /// <param name="Activation"></param>
+        /// <param name="activation"></param>
         private static void RegisterActivation(string name, Activation activation)
         {
             _nameActivationMap[name] = activation;
@@ -42,6 +44,7 @@ static Activations()
             RegisterActivation("sigmoid", _sigmoid);
             RegisterActivation("softmax", _softmax);
             RegisterActivation("tanh", _tanh);
+            RegisterActivation("mish", _mish);
         }
 
         public Activation Linear => _linear;
@@ -54,6 +57,7 @@ static Activations()
 
         public Activation Tanh => _tanh;
 
+        public Activation Mish => _mish;
 
         public static Activation GetActivationByName(string name)
         {
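The new _mish delegate implements Mish (Misra, 2019), mish(x) = x · tanh(softplus(x)) = x · tanh(ln(1 + eˣ)), by composing the tf.math wrappers added above rather than binding a dedicated kernel; unlike the op-backed activations, it ignores the name argument. Once registered, it resolves by name like any other activation (a sketch; the input value is illustrative):

    // mish(1) = tanh(ln(1 + e)) ≈ 0.8651
    var mish = Activations.GetActivationByName("mish");
    var y = mish(tf.constant(1.0f));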

test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs

Lines changed: 11 additions & 0 deletions

@@ -94,5 +94,16 @@ public void Swish()
             NDArray expected = new NDArray(new float[] { -0.14227762f, -0.23840584f, -0.26894143f, 0f, 0.7310586f, 1.761594f });
             Assert.AreEqual(expected, output.numpy());
         }
+
+        /// <summary>
+        /// https://www.tensorflow.org/addons/api_docs/python/tfa/activations/mish
+        /// </summary>
+        [TestMethod]
+        public void Mish()
+        {
+            var x = tf.constant(new[] { 1.0, 0.0, 1.0 }, dtype: tf.float32);
+            var output = keras.activations.Mish(x);
+            Assert.AreEqual(new[] { 0.86509836f, 0f, 0.86509836f }, output.numpy());
+        }
     }
 }
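The expected values check out by hand: softplus(1) = ln(1 + e) ≈ 1.31326 and tanh(1.31326) ≈ 0.865098, so mish(1) ≈ 0.86509836f in float32, while mish(0) = 0 · tanh(ln 2) = 0.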
