diff --git a/stdlib/math.codon b/stdlib/math.codon index 61780dd5..9019d7f5 100644 --- a/stdlib/math.codon +++ b/stdlib/math.codon @@ -679,6 +679,28 @@ def fsum(seq): return hi +def prod(iterable, start = 1): + def prod_generator(iterable: Generator[T], start, T: type): + if T is float: + result = float(start) + else: + result = start + + for a in iterable: + result *= a + return result + + def prod_tuple(iterable, start): + if staticlen(iterable) == 0: + return start + else: + return prod(iterable[1:], start=(start * iterable[0])) + + if isinstance(iterable, Tuple): + return prod_tuple(iterable, start) + else: + return prod_generator(iterable, start) + # 32-bit float ops e32 = float32(e) diff --git a/test/stdlib/math_test.codon b/test/stdlib/math_test.codon index c27d8f06..056d978f 100644 --- a/test/stdlib/math_test.codon +++ b/test/stdlib/math_test.codon @@ -585,6 +585,95 @@ def test_fsum(): assert msum(vals) == math.fsum(vals) +@test +def test_prod(): + is_nan = lambda x: math.isnan(x) + prod = math.prod + assert prod(()) == 1 + assert prod((), start=5) == 5 + assert prod(list(range(2,8))) == 5040 + assert prod(iter(list(range(2,8)))) == 5040 + assert prod(range(1, 10), start=10) == 3628800 + + assert prod([1, 2, 3, 4, 5]) == 120 + assert prod([1.0, 2.0, 3.0, 4.0, 5.0]) == 120.0 + assert prod([1, 2, 3, 4.0, 5.0]) == 120.0 + assert prod([1.0, 2.0, 3.0, 4, 5]) == 120.0 + + # Test overflow in fast-path for integers + assert prod([1, 1, 2**32, 1, 1]) == 2**32 + # Test overflow in fast-path for floats + assert prod([1.0, 1.0, 2**32, 1, 1]) == float(2**32) + + # Some odd cases + assert prod([2, 3], start='ab') == 'abababababab' + assert prod([2, 3], start=[1, 2]) == [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2] + assert prod((), start={2: 3}) == {2:3} + + #with self.assertRaises(TypeError): + # prod([10, 20], 1) # start is a keyword-only argument + + assert prod([0, 1, 2, 3]) == 0 + assert prod([1, 0, 2, 3]) == 0 + assert prod([1, 2, 3, 0]) == 0 + + def _naive_prod(iterable, start=1): + for elem in iterable: + start *= elem + return start + + iterable = range(1, 13) + assert prod(iterable) == _naive_prod(iterable) + iterable = range(-12, -1) + assert prod(iterable) == _naive_prod(iterable) + iterable = range(-11, 5) + assert prod(iterable) == 0 + + # Big floats + + iterable = [float(x) for x in range(1, 123)] + assert prod(iterable) == _naive_prod(iterable, 1.0) + iterable = [float(x) for x in range(-123, -1)] + assert prod(iterable) == _naive_prod(iterable, 1.0) + iterable = [float(x) for x in range(-1000, 1000)] + assert is_nan(prod(iterable)) + + # Float tests + + assert is_nan(prod([1, 2, 3, float("nan"), 2, 3])) + assert is_nan(prod([1, 0, float("nan"), 2, 3])) + assert is_nan(prod([1, float("nan"), 0, 3])) + assert is_nan(prod([1, float("inf"), float("nan"),3])) + assert is_nan(prod([1, float("-inf"), float("nan"),3])) + assert is_nan(prod([1, float("nan"), float("inf"),3])) + assert is_nan(prod([1, float("nan"), float("-inf"),3])) + + assert prod([1, 2, 3, float('inf'),-3,4]) == float('-inf') + assert prod([1, 2, 3, float('-inf'),-3,4]) == float('inf') + + assert is_nan(prod([1,2,0,float('inf'), -3, 4])) + assert is_nan(prod([1,2,0,float('-inf'), -3, 4])) + assert is_nan(prod([1, 2, 3, float('inf'), -3, 0, 3])) + assert is_nan(prod([1, 2, 3, float('-inf'), -3, 0, 2])) + + # Type preservation + + assert type(prod([1, 2, 3, 4, 5, 6])) is int + assert type(prod([1, 2.0, 3, 4, 5, 6])) is float + assert type(prod(range(1, 10000))) is int + assert type(prod(range(1, 10000), start=1.0)) is float + #assert type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])) is decimal.Decimal + + # Tuples + + assert prod((42,)) == 42 + assert prod((-3.5, 4.0)) == -14.0 + assert prod((3, 9, 3, 7, 1)) == 567 + assert prod((1, 2.5, 3)) == 7.5 + assert prod((2.5, 1, 3, 1.0, 1)) == 7.5 + assert prod((2, 3), start='ab') == 'abababababab' + + test_isnan() test_isinf() test_isfinite() @@ -629,6 +718,7 @@ test_frexp() test_modf() test_isclose() test_fsum() +test_prod() # 32-bit float ops