Skip to content

Commit 35ad402

Browse files
Update softmax.py
1 parent 02fad70 commit 35ad402

File tree

1 file changed

+21
-27
lines changed

1 file changed

+21
-27
lines changed

maths/softmax.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,44 @@
11
"""
22
This script demonstrates the implementation of the Softmax function.
33
4-
Its a function that takes as input a vector of K real numbers, and normalizes
5-
it into a probability distribution consisting of K probabilities proportional
6-
to the exponentials of the input numbers. After softmax, the elements of the
7-
vector always sum up to 1.
4+
It takes as input a vector of K real numbers and normalizes it into a
5+
probability distribution consisting of K probabilities proportional
6+
to the exponentials of the input numbers. After applying softmax,
7+
the elements of the vector always sum up to 1.
88
9-
Script inspired from its corresponding Wikipedia article
9+
Script inspired by its corresponding Wikipedia article:
1010
https://en.wikipedia.org/wiki/Softmax_function
1111
"""
1212

1313
import numpy as np
14+
from numpy.exceptions import AxisError
1415

1516

16-
def softmax(vector, axis=-1):
17+
def softmax(vector: np.ndarray, axis: int = -1) -> np.ndarray:
1718
"""
18-
Implements the softmax function
19+
Implements the softmax function.
1920
2021
Parameters:
21-
vector (np.array,list,tuple): A numpy array of shape (1,n)
22-
consisting of real values or a similar list,tuple
23-
axis (int, optional): Axis along which to compute softmax. Default is -1.
22+
vector (np.ndarray | list | tuple): A numpy array of shape (1, n)
23+
consisting of real values or a similar list/tuple.
24+
axis (int, optional): Axis along which to compute softmax.
25+
Default is -1.
2426
2527
Returns:
26-
softmax_vec (np.array): The input numpy array after applying
27-
softmax.
28+
np.ndarray: The input numpy array after applying softmax.
2829
29-
The softmax vector adds up to one. We need to ceil to mitigate for
30-
precision
31-
>>> float(np.ceil(np.sum(softmax([1,2,3,4]))))
30+
The softmax vector adds up to one. We need to ceil to mitigate precision.
31+
32+
>>> float(np.ceil(np.sum(softmax([1, 2, 3, 4]))))
3233
1.0
3334
34-
>>> vec = np.array([5,5])
35+
>>> vec = np.array([5, 5])
3536
>>> softmax(vec)
3637
array([0.5, 0.5])
3738
3839
>>> softmax([0])
3940
array([1.])
4041
"""
41-
4242
# Convert input to numpy array of floats
4343
vector = np.asarray(vector, dtype=float)
4444

@@ -47,11 +47,10 @@ def softmax(vector, axis=-1):
4747
raise ValueError("softmax input must be non-empty")
4848

4949
# Validate axis
50-
if not (-vector.ndim <= axis < vector.ndim):
51-
raise np.AxisError(
52-
f"axis {axis} is out of bounds for array of dimension {vector.ndim}"
53-
)
54-
50+
ndim = vector.ndim
51+
if axis >= ndim or axis < -ndim:
52+
error_message = f"axis {axis} is out of bounds for array of dimension {ndim}"
53+
raise AxisError(error_message)
5554
# Subtract max for numerical stability
5655
vector_max = np.max(vector, axis=axis, keepdims=True)
5756
exponent_vector = np.exp(vector - vector_max)
@@ -61,20 +60,15 @@ def softmax(vector, axis=-1):
6160

6261
# Divide each exponent by the sum along the axis
6362
softmax_vector = exponent_vector / sum_of_exponents
64-
6563
return softmax_vector
6664

67-
6865
if __name__ == "__main__":
6966
# Single value
7067
print(softmax((0,)))
71-
7268
# Vector
7369
print(softmax([1, 2, 3]))
74-
7570
# Matrix along last axis
7671
mat = np.array([[1, 2, 3], [4, 5, 6]])
7772
print("Softmax along last axis:\n", softmax(mat))
78-
7973
# Matrix along axis 0
8074
print("Softmax along axis 0:\n", softmax(mat, axis=0))

0 commit comments

Comments
 (0)