You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

56 lines
1.8 KiB

from sympy.concrete.summations import Sum
from sympy.functions.elementary.exponential import log
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.utilities.lambdify import lambdify
from sympy.abc import x, i, a, b
from sympy.codegen.numpy_nodes import logaddexp
from sympy.printing.numpy import CuPyPrinter, _cupy_known_constants, _cupy_known_functions
from sympy.testing.pytest import skip
from sympy.external import import_module
cp = import_module('cupy')
def test_cupy_print():
prntr = CuPyPrinter()
assert prntr.doprint(logaddexp(a, b)) == 'cupy.logaddexp(a, b)'
assert prntr.doprint(sqrt(x)) == 'cupy.sqrt(x)'
assert prntr.doprint(log(x)) == 'cupy.log(x)'
assert prntr.doprint("acos(x)") == 'cupy.arccos(x)'
assert prntr.doprint("exp(x)") == 'cupy.exp(x)'
assert prntr.doprint("Abs(x)") == 'abs(x)'
def test_not_cupy_print():
prntr = CuPyPrinter()
assert "Not supported" in prntr.doprint("abcd(x)")
def test_cupy_sum():
if not cp:
skip("CuPy not installed")
s = Sum(x ** i, (i, a, b))
f = lambdify((a, b, x), s, 'cupy')
a_, b_ = 0, 10
x_ = cp.linspace(-1, +1, 10)
assert cp.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
s = Sum(i * x, (i, a, b))
f = lambdify((a, b, x), s, 'numpy')
a_, b_ = 0, 10
x_ = cp.linspace(-1, +1, 10)
assert cp.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
def test_cupy_known_funcs_consts():
assert _cupy_known_constants['NaN'] == 'cupy.nan'
assert _cupy_known_constants['EulerGamma'] == 'cupy.euler_gamma'
assert _cupy_known_functions['acos'] == 'cupy.arccos'
assert _cupy_known_functions['log'] == 'cupy.log'
def test_cupy_print_methods():
prntr = CuPyPrinter()
assert hasattr(prntr, '_print_acos')
assert hasattr(prntr, '_print_log')