Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit d1b230e

Browse files
committed
Issue python#29282: add fused multiply-add function, math.fma.
1 parent 502efda commit d1b230e

File tree

6 files changed

+338
-1
lines changed

6 files changed

+338
-1
lines changed

Doc/library/math.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,21 @@ Number-theoretic and representation functions
5757
If *x* is not a float, delegates to ``x.__floor__()``, which should return an
5858
:class:`~numbers.Integral` value.
5959

60+
.. function:: fma(x, y, z)
61+
62+
Fused multiply-add operation. Return ``(x * y) + z``, computed as though with
63+
infinite precision and range followed by a single round to the ``float``
64+
format. This operation often provides better accuracy than the direct
65+
expression ``(x * y) + z``.
66+
67+
This function follows the specification of the fusedMultiplyAdd operation
68+
described in the IEEE 754-2008 standard. The standard leaves one case
69+
implementation-defined, namely the result of ``fma(0, inf, nan)``
70+
and ``fma(inf, 0, nan)``. In these cases, ``math.fma`` returns a NaN,
71+
and does not raise any exception.
72+
73+
.. versionadded:: 3.7
74+
6075

6176
.. function:: fmod(x, y)
6277

Doc/whatsnew/3.7.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@ The :const:`~unittest.mock.sentinel` attributes now preserve their identity
100100
when they are :mod:`copied <copy>` or :mod:`pickled <pickle>`.
101101
(Contributed by Serhiy Storchaka in :issue:`20804`.)
102102

103+
math module
104+
-----------
105+
106+
A new function :func:`~math.fma` for fused multiply-add operations has been
107+
added. This function computes ``x * y + z`` with only a single round, and so
108+
avoids any intermediate loss of precision. It wraps the ``fma`` function
109+
provided by C99, and follows the specification of the IEEE 754-2008
110+
"fusedMultiplyAdd" operation for special cases.
111+
103112

104113
Optimizations
105114
=============

Lib/test/test_math.py

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from test.support import run_unittest, verbose, requires_IEEE_754
55
from test import support
66
import unittest
7+
import itertools
78
import math
89
import os
910
import platform
@@ -1410,11 +1411,244 @@ def test_fractions(self):
14101411
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
14111412

14121413

1414+
class FMATests(unittest.TestCase):
1415+
""" Tests for math.fma. """
1416+
1417+
def test_fma_nan_results(self):
1418+
# Selected representative values.
1419+
values = [
1420+
-math.inf, -1e300, -2.3, -1e-300, -0.0,
1421+
0.0, 1e-300, 2.3, 1e300, math.inf, math.nan
1422+
]
1423+
1424+
# If any input is a NaN, the result should be a NaN, too.
1425+
for a, b in itertools.product(values, repeat=2):
1426+
self.assertIsNaN(math.fma(math.nan, a, b))
1427+
self.assertIsNaN(math.fma(a, math.nan, b))
1428+
self.assertIsNaN(math.fma(a, b, math.nan))
1429+
1430+
def test_fma_infinities(self):
1431+
# Cases involving infinite inputs or results.
1432+
positives = [1e-300, 2.3, 1e300, math.inf]
1433+
finites = [-1e300, -2.3, -1e-300, -0.0, 0.0, 1e-300, 2.3, 1e300]
1434+
non_nans = [-math.inf, -2.3, -0.0, 0.0, 2.3, math.inf]
1435+
1436+
# ValueError due to inf * 0 computation.
1437+
for c in non_nans:
1438+
for infinity in [math.inf, -math.inf]:
1439+
for zero in [0.0, -0.0]:
1440+
with self.assertRaises(ValueError):
1441+
math.fma(infinity, zero, c)
1442+
with self.assertRaises(ValueError):
1443+
math.fma(zero, infinity, c)
1444+
1445+
# ValueError when a*b and c both infinite of opposite signs.
1446+
for b in positives:
1447+
with self.assertRaises(ValueError):
1448+
math.fma(math.inf, b, -math.inf)
1449+
with self.assertRaises(ValueError):
1450+
math.fma(math.inf, -b, math.inf)
1451+
with self.assertRaises(ValueError):
1452+
math.fma(-math.inf, -b, -math.inf)
1453+
with self.assertRaises(ValueError):
1454+
math.fma(-math.inf, b, math.inf)
1455+
with self.assertRaises(ValueError):
1456+
math.fma(b, math.inf, -math.inf)
1457+
with self.assertRaises(ValueError):
1458+
math.fma(-b, math.inf, math.inf)
1459+
with self.assertRaises(ValueError):
1460+
math.fma(-b, -math.inf, -math.inf)
1461+
with self.assertRaises(ValueError):
1462+
math.fma(b, -math.inf, math.inf)
1463+
1464+
# Infinite result when a*b and c both infinite of the same sign.
1465+
for b in positives:
1466+
self.assertEqual(math.fma(math.inf, b, math.inf), math.inf)
1467+
self.assertEqual(math.fma(math.inf, -b, -math.inf), -math.inf)
1468+
self.assertEqual(math.fma(-math.inf, -b, math.inf), math.inf)
1469+
self.assertEqual(math.fma(-math.inf, b, -math.inf), -math.inf)
1470+
self.assertEqual(math.fma(b, math.inf, math.inf), math.inf)
1471+
self.assertEqual(math.fma(-b, math.inf, -math.inf), -math.inf)
1472+
self.assertEqual(math.fma(-b, -math.inf, math.inf), math.inf)
1473+
self.assertEqual(math.fma(b, -math.inf, -math.inf), -math.inf)
1474+
1475+
# Infinite result when a*b finite, c infinite.
1476+
for a, b in itertools.product(finites, finites):
1477+
self.assertEqual(math.fma(a, b, math.inf), math.inf)
1478+
self.assertEqual(math.fma(a, b, -math.inf), -math.inf)
1479+
1480+
# Infinite result when a*b infinite, c finite.
1481+
for b, c in itertools.product(positives, finites):
1482+
self.assertEqual(math.fma(math.inf, b, c), math.inf)
1483+
self.assertEqual(math.fma(-math.inf, b, c), -math.inf)
1484+
self.assertEqual(math.fma(-math.inf, -b, c), math.inf)
1485+
self.assertEqual(math.fma(math.inf, -b, c), -math.inf)
1486+
1487+
self.assertEqual(math.fma(b, math.inf, c), math.inf)
1488+
self.assertEqual(math.fma(b, -math.inf, c), -math.inf)
1489+
self.assertEqual(math.fma(-b, -math.inf, c), math.inf)
1490+
self.assertEqual(math.fma(-b, math.inf, c), -math.inf)
1491+
1492+
def test_fma_zero_result(self):
1493+
nonnegative_finites = [0.0, 1e-300, 2.3, 1e300]
1494+
1495+
# Zero results from exact zero inputs.
1496+
for b in nonnegative_finites:
1497+
self.assertIsPositiveZero(math.fma(0.0, b, 0.0))
1498+
self.assertIsPositiveZero(math.fma(0.0, b, -0.0))
1499+
self.assertIsNegativeZero(math.fma(0.0, -b, -0.0))
1500+
self.assertIsPositiveZero(math.fma(0.0, -b, 0.0))
1501+
self.assertIsPositiveZero(math.fma(-0.0, -b, 0.0))
1502+
self.assertIsPositiveZero(math.fma(-0.0, -b, -0.0))
1503+
self.assertIsNegativeZero(math.fma(-0.0, b, -0.0))
1504+
self.assertIsPositiveZero(math.fma(-0.0, b, 0.0))
1505+
1506+
self.assertIsPositiveZero(math.fma(b, 0.0, 0.0))
1507+
self.assertIsPositiveZero(math.fma(b, 0.0, -0.0))
1508+
self.assertIsNegativeZero(math.fma(-b, 0.0, -0.0))
1509+
self.assertIsPositiveZero(math.fma(-b, 0.0, 0.0))
1510+
self.assertIsPositiveZero(math.fma(-b, -0.0, 0.0))
1511+
self.assertIsPositiveZero(math.fma(-b, -0.0, -0.0))
1512+
self.assertIsNegativeZero(math.fma(b, -0.0, -0.0))
1513+
self.assertIsPositiveZero(math.fma(b, -0.0, 0.0))
1514+
1515+
# Exact zero result from nonzero inputs.
1516+
self.assertIsPositiveZero(math.fma(2.0, 2.0, -4.0))
1517+
self.assertIsPositiveZero(math.fma(2.0, -2.0, 4.0))
1518+
self.assertIsPositiveZero(math.fma(-2.0, -2.0, -4.0))
1519+
self.assertIsPositiveZero(math.fma(-2.0, 2.0, 4.0))
1520+
1521+
# Underflow to zero.
1522+
tiny = 1e-300
1523+
self.assertIsPositiveZero(math.fma(tiny, tiny, 0.0))
1524+
self.assertIsNegativeZero(math.fma(tiny, -tiny, 0.0))
1525+
self.assertIsPositiveZero(math.fma(-tiny, -tiny, 0.0))
1526+
self.assertIsNegativeZero(math.fma(-tiny, tiny, 0.0))
1527+
self.assertIsPositiveZero(math.fma(tiny, tiny, -0.0))
1528+
self.assertIsNegativeZero(math.fma(tiny, -tiny, -0.0))
1529+
self.assertIsPositiveZero(math.fma(-tiny, -tiny, -0.0))
1530+
self.assertIsNegativeZero(math.fma(-tiny, tiny, -0.0))
1531+
1532+
# Corner case where rounding the multiplication would
1533+
# give the wrong result.
1534+
x = float.fromhex('0x1p-500')
1535+
y = float.fromhex('0x1p-550')
1536+
z = float.fromhex('0x1p-1000')
1537+
self.assertIsNegativeZero(math.fma(x-y, x+y, -z))
1538+
self.assertIsPositiveZero(math.fma(y-x, x+y, z))
1539+
self.assertIsNegativeZero(math.fma(y-x, -(x+y), -z))
1540+
self.assertIsPositiveZero(math.fma(x-y, -(x+y), z))
1541+
1542+
def test_fma_overflow(self):
1543+
a = b = float.fromhex('0x1p512')
1544+
c = float.fromhex('0x1p1023')
1545+
# Overflow from multiplication.
1546+
with self.assertRaises(OverflowError):
1547+
math.fma(a, b, 0.0)
1548+
self.assertEqual(math.fma(a, b/2.0, 0.0), c)
1549+
# Overflow from the addition.
1550+
with self.assertRaises(OverflowError):
1551+
math.fma(a, b/2.0, c)
1552+
# No overflow, even though a*b overflows a float.
1553+
self.assertEqual(math.fma(a, b, -c), c)
1554+
1555+
# Extreme case: a * b is exactly at the overflow boundary, so the
1556+
# tiniest offset makes a difference between overflow and a finite
1557+
# result.
1558+
a = float.fromhex('0x1.ffffffc000000p+511')
1559+
b = float.fromhex('0x1.0000002000000p+512')
1560+
c = float.fromhex('0x0.0000000000001p-1022')
1561+
with self.assertRaises(OverflowError):
1562+
math.fma(a, b, 0.0)
1563+
with self.assertRaises(OverflowError):
1564+
math.fma(a, b, c)
1565+
self.assertEqual(math.fma(a, b, -c),
1566+
float.fromhex('0x1.fffffffffffffp+1023'))
1567+
1568+
# Another extreme case: here a*b is about as large as possible subject
1569+
# to math.fma(a, b, c) being finite.
1570+
a = float.fromhex('0x1.ae565943785f9p+512')
1571+
b = float.fromhex('0x1.3094665de9db8p+512')
1572+
c = float.fromhex('0x1.fffffffffffffp+1023')
1573+
self.assertEqual(math.fma(a, b, -c), c)
1574+
1575+
def test_fma_single_round(self):
1576+
a = float.fromhex('0x1p-50')
1577+
self.assertEqual(math.fma(a - 1.0, a + 1.0, 1.0), a*a)
1578+
1579+
def test_random(self):
1580+
# A collection of randomly generated inputs for which the naive FMA
1581+
# (with two rounds) gives a different result from a singly-rounded FMA.
1582+
1583+
# tuples (a, b, c, expected)
1584+
test_values = [
1585+
('0x1.694adde428b44p-1', '0x1.371b0d64caed7p-1',
1586+
'0x1.f347e7b8deab8p-4', '0x1.19f10da56c8adp-1'),
1587+
('0x1.605401ccc6ad6p-2', '0x1.ce3a40bf56640p-2',
1588+
'0x1.96e3bf7bf2e20p-2', '0x1.1af6d8aa83101p-1'),
1589+
('0x1.e5abd653a67d4p-2', '0x1.a2e400209b3e6p-1',
1590+
'0x1.a90051422ce13p-1', '0x1.37d68cc8c0fbbp+0'),
1591+
('0x1.f94e8efd54700p-2', '0x1.123065c812cebp-1',
1592+
'0x1.458f86fb6ccd0p-1', '0x1.ccdcee26a3ff3p-1'),
1593+
('0x1.bd926f1eedc96p-1', '0x1.eee9ca68c5740p-1',
1594+
'0x1.960c703eb3298p-2', '0x1.3cdcfb4fdb007p+0'),
1595+
('0x1.27348350fbccdp-1', '0x1.3b073914a53f1p-1',
1596+
'0x1.e300da5c2b4cbp-1', '0x1.4c51e9a3c4e29p+0'),
1597+
('0x1.2774f00b3497bp-1', '0x1.7038ec336bff0p-2',
1598+
'0x1.2f6f2ccc3576bp-1', '0x1.99ad9f9c2688bp-1'),
1599+
('0x1.51d5a99300e5cp-1', '0x1.5cd74abd445a1p-1',
1600+
'0x1.8880ab0bbe530p-1', '0x1.3756f96b91129p+0'),
1601+
('0x1.73cb965b821b8p-2', '0x1.218fd3d8d5371p-1',
1602+
'0x1.d1ea966a1f758p-2', '0x1.5217b8fd90119p-1'),
1603+
('0x1.4aa98e890b046p-1', '0x1.954d85dff1041p-1',
1604+
'0x1.122b59317ebdfp-1', '0x1.0bf644b340cc5p+0'),
1605+
('0x1.e28f29e44750fp-1', '0x1.4bcc4fdcd18fep-1',
1606+
'0x1.fd47f81298259p-1', '0x1.9b000afbc9995p+0'),
1607+
('0x1.d2e850717fe78p-3', '0x1.1dd7531c303afp-1',
1608+
'0x1.e0869746a2fc2p-2', '0x1.316df6eb26439p-1'),
1609+
('0x1.cf89c75ee6fbap-2', '0x1.b23decdc66825p-1',
1610+
'0x1.3d1fe76ac6168p-1', '0x1.00d8ea4c12abbp+0'),
1611+
('0x1.3265ae6f05572p-2', '0x1.16d7ec285f7a2p-1',
1612+
'0x1.0b8405b3827fbp-1', '0x1.5ef33c118a001p-1'),
1613+
('0x1.c4d1bf55ec1a5p-1', '0x1.bc59618459e12p-2',
1614+
'0x1.ce5b73dc1773dp-1', '0x1.496cf6164f99bp+0'),
1615+
('0x1.d350026ac3946p-1', '0x1.9a234e149a68cp-2',
1616+
'0x1.f5467b1911fd6p-2', '0x1.b5cee3225caa5p-1'),
1617+
]
1618+
for a_hex, b_hex, c_hex, expected_hex in test_values:
1619+
a = float.fromhex(a_hex)
1620+
b = float.fromhex(b_hex)
1621+
c = float.fromhex(c_hex)
1622+
expected = float.fromhex(expected_hex)
1623+
self.assertEqual(math.fma(a, b, c), expected)
1624+
self.assertEqual(math.fma(b, a, c), expected)
1625+
1626+
# Custom assertions.
1627+
def assertIsNaN(self, value):
1628+
self.assertTrue(
1629+
math.isnan(value),
1630+
msg="Expected a NaN, got {!r}".format(value)
1631+
)
1632+
1633+
def assertIsPositiveZero(self, value):
1634+
self.assertTrue(
1635+
value == 0 and math.copysign(1, value) > 0,
1636+
msg="Expected a positive zero, got {!r}".format(value)
1637+
)
1638+
1639+
def assertIsNegativeZero(self, value):
1640+
self.assertTrue(
1641+
value == 0 and math.copysign(1, value) < 0,
1642+
msg="Expected a negative zero, got {!r}".format(value)
1643+
)
1644+
1645+
14131646
def test_main():
14141647
from doctest import DocFileSuite
14151648
suite = unittest.TestSuite()
14161649
suite.addTest(unittest.makeSuite(MathTests))
14171650
suite.addTest(unittest.makeSuite(IsCloseTests))
1651+
suite.addTest(unittest.makeSuite(FMATests))
14181652
suite.addTest(DocFileSuite("ieee754.txt"))
14191653
run_unittest(suite)
14201654

Misc/NEWS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,9 @@ Core and Builtins
215215
Library
216216
-------
217217

218+
- Issue #29282: Added new math.fma function, wrapping C99's fma
219+
operation.
220+
218221
- Issue #29197: Removed deprecated function ntpath.splitunc().
219222

220223
- Issue #29210: Removed support of deprecated argument "exclude" in

Modules/clinic/mathmodule.c.h

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,40 @@ PyDoc_STRVAR(math_factorial__doc__,
8080
#define MATH_FACTORIAL_METHODDEF \
8181
{"factorial", (PyCFunction)math_factorial, METH_O, math_factorial__doc__},
8282

83+
PyDoc_STRVAR(math_fma__doc__,
84+
"fma($module, x, y, z, /)\n"
85+
"--\n"
86+
"\n"
87+
"Fused multiply-add operation. Compute (x * y) + z with a single round.");
88+
89+
#define MATH_FMA_METHODDEF \
90+
{"fma", (PyCFunction)math_fma, METH_FASTCALL, math_fma__doc__},
91+
92+
static PyObject *
93+
math_fma_impl(PyObject *module, double x, double y, double z);
94+
95+
static PyObject *
96+
math_fma(PyObject *module, PyObject **args, Py_ssize_t nargs, PyObject *kwnames)
97+
{
98+
PyObject *return_value = NULL;
99+
double x;
100+
double y;
101+
double z;
102+
103+
if (!_PyArg_ParseStack(args, nargs, "ddd:fma",
104+
&x, &y, &z)) {
105+
goto exit;
106+
}
107+
108+
if (!_PyArg_NoStackKeywords("fma", kwnames)) {
109+
goto exit;
110+
}
111+
return_value = math_fma_impl(module, x, y, z);
112+
113+
exit:
114+
return return_value;
115+
}
116+
83117
PyDoc_STRVAR(math_trunc__doc__,
84118
"trunc($module, x, /)\n"
85119
"--\n"
@@ -536,4 +570,4 @@ math_isclose(PyObject *module, PyObject **args, Py_ssize_t nargs, PyObject *kwna
536570
exit:
537571
return return_value;
538572
}
539-
/*[clinic end generated code: output=71806f73a5c4bf0b input=a9049054013a1b77]*/
573+
/*[clinic end generated code: output=f428e1075d00c334 input=a9049054013a1b77]*/

0 commit comments

Comments
 (0)