Add math.prod()

pull/392/head
A. R. Shajii 2023-05-14 12:36:42 -04:00
parent 136a719558
commit 5085dae04d
2 changed files with 112 additions and 0 deletions

View File

@ -679,6 +679,28 @@ def fsum(seq):
return hi
def prod(iterable, start = 1):
def prod_generator(iterable: Generator[T], start, T: type):
if T is float:
result = float(start)
else:
result = start
for a in iterable:
result *= a
return result
def prod_tuple(iterable, start):
if staticlen(iterable) == 0:
return start
else:
return prod(iterable[1:], start=(start * iterable[0]))
if isinstance(iterable, Tuple):
return prod_tuple(iterable, start)
else:
return prod_generator(iterable, start)
# 32-bit float ops
e32 = float32(e)

View File

@ -585,6 +585,95 @@ def test_fsum():
assert msum(vals) == math.fsum(vals)
@test
def test_prod():
is_nan = lambda x: math.isnan(x)
prod = math.prod
assert prod(()) == 1
assert prod((), start=5) == 5
assert prod(list(range(2,8))) == 5040
assert prod(iter(list(range(2,8)))) == 5040
assert prod(range(1, 10), start=10) == 3628800
assert prod([1, 2, 3, 4, 5]) == 120
assert prod([1.0, 2.0, 3.0, 4.0, 5.0]) == 120.0
assert prod([1, 2, 3, 4.0, 5.0]) == 120.0
assert prod([1.0, 2.0, 3.0, 4, 5]) == 120.0
# Test overflow in fast-path for integers
assert prod([1, 1, 2**32, 1, 1]) == 2**32
# Test overflow in fast-path for floats
assert prod([1.0, 1.0, 2**32, 1, 1]) == float(2**32)
# Some odd cases
assert prod([2, 3], start='ab') == 'abababababab'
assert prod([2, 3], start=[1, 2]) == [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2]
assert prod((), start={2: 3}) == {2:3}
#with self.assertRaises(TypeError):
# prod([10, 20], 1) # start is a keyword-only argument
assert prod([0, 1, 2, 3]) == 0
assert prod([1, 0, 2, 3]) == 0
assert prod([1, 2, 3, 0]) == 0
def _naive_prod(iterable, start=1):
for elem in iterable:
start *= elem
return start
iterable = range(1, 13)
assert prod(iterable) == _naive_prod(iterable)
iterable = range(-12, -1)
assert prod(iterable) == _naive_prod(iterable)
iterable = range(-11, 5)
assert prod(iterable) == 0
# Big floats
iterable = [float(x) for x in range(1, 123)]
assert prod(iterable) == _naive_prod(iterable, 1.0)
iterable = [float(x) for x in range(-123, -1)]
assert prod(iterable) == _naive_prod(iterable, 1.0)
iterable = [float(x) for x in range(-1000, 1000)]
assert is_nan(prod(iterable))
# Float tests
assert is_nan(prod([1, 2, 3, float("nan"), 2, 3]))
assert is_nan(prod([1, 0, float("nan"), 2, 3]))
assert is_nan(prod([1, float("nan"), 0, 3]))
assert is_nan(prod([1, float("inf"), float("nan"),3]))
assert is_nan(prod([1, float("-inf"), float("nan"),3]))
assert is_nan(prod([1, float("nan"), float("inf"),3]))
assert is_nan(prod([1, float("nan"), float("-inf"),3]))
assert prod([1, 2, 3, float('inf'),-3,4]) == float('-inf')
assert prod([1, 2, 3, float('-inf'),-3,4]) == float('inf')
assert is_nan(prod([1,2,0,float('inf'), -3, 4]))
assert is_nan(prod([1,2,0,float('-inf'), -3, 4]))
assert is_nan(prod([1, 2, 3, float('inf'), -3, 0, 3]))
assert is_nan(prod([1, 2, 3, float('-inf'), -3, 0, 2]))
# Type preservation
assert type(prod([1, 2, 3, 4, 5, 6])) is int
assert type(prod([1, 2.0, 3, 4, 5, 6])) is float
assert type(prod(range(1, 10000))) is int
assert type(prod(range(1, 10000), start=1.0)) is float
#assert type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])) is decimal.Decimal
# Tuples
assert prod((42,)) == 42
assert prod((-3.5, 4.0)) == -14.0
assert prod((3, 9, 3, 7, 1)) == 567
assert prod((1, 2.5, 3)) == 7.5
assert prod((2.5, 1, 3, 1.0, 1)) == 7.5
assert prod((2, 3), start='ab') == 'abababababab'
test_isnan()
test_isinf()
test_isfinite()
@ -629,6 +718,7 @@ test_frexp()
test_modf()
test_isclose()
test_fsum()
test_prod()
# 32-bit float ops