Skip to content

Commit 01ae0fb

Browse files
Update softmax.py
1 parent 788d95b commit 01ae0fb

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

maths/softmax.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
import numpy as np
1414

1515

16-
def softmax(vector, axis: int = -1) -> np.ndarray:
    """
    Implements the softmax function

    Parameters:
        vector (np.array,list,tuple): A numpy array with at least one
        dimension consisting of real values, or a similar (possibly
        nested) list,tuple
        axis (int, optional): Axis along which to compute softmax. Default is -1.

    Returns:
        softmax_vec (np.array): The input converted to a float numpy array
        after applying softmax. It has the same shape as the input and its
        entries sum to 1 along ``axis``.

    Raises:
        ValueError: If the input contains no elements.
        AxisError: If ``axis`` is out of bounds for the input's number of
        dimensions. A scalar (0-d) input has no axes, so it is rejected
        for every ``axis``. AxisError is a subclass of ValueError.

    >>> softmax((0,))
    array([1.])
    """
    # Convert input to numpy array of floats
    vector = np.asarray(vector, dtype=float)

    # Handle empty input
    if vector.size == 0:
        raise ValueError("softmax input must be non-empty")

    # Validate axis
    if not (-vector.ndim <= axis < vector.ndim):
        # NumPy 2.0 removed ``np.AxisError`` from the top-level namespace;
        # it lives in ``np.exceptions`` (available since NumPy 1.25). Fall
        # back to the top-level name on older releases.
        axis_error = getattr(np, "exceptions", np).AxisError
        raise axis_error(
            f"axis {axis} is out of bounds for array of dimension {vector.ndim}"
        )

    # Subtract max for numerical stability: the result is mathematically
    # unchanged, but np.exp no longer overflows for large inputs
    vector_max = np.max(vector, axis=axis, keepdims=True)
    exponent_vector = np.exp(vector - vector_max)

    # Sum of exponentials along the axis (keepdims so the division broadcasts)
    sum_of_exponents = np.sum(exponent_vector, axis=axis, keepdims=True)

    # Divide each exponent by the sum along the axis
    softmax_vector = exponent_vector / sum_of_exponents

    return softmax_vector
5364

5465

5566
if __name__ == "__main__":
    # Demo of softmax on inputs of increasing dimensionality.

    # Single value
    print(softmax((0,)))

    # Vector
    print(softmax([1, 2, 3]))

    # Matrix along last axis
    mat = np.array([[1, 2, 3], [4, 5, 6]])
    print("Softmax along last axis:\n", softmax(mat))

    # Matrix along axis 0
    print("Softmax along axis 0:\n", softmax(mat, axis=0))

0 commit comments

Comments
 (0)