You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

217 lines
7.7 KiB

5 months ago
from sympy.external.importtools import version_tuple
from collections.abc import Iterable
from sympy.core.mul import Mul
from sympy.core.singleton import S
from sympy.codegen.cfunctions import Sqrt
from sympy.external import import_module
from sympy.printing.precedence import PRECEDENCE
from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter
import sympy
# Optional handle to the TensorFlow module. TensorflowPrinter.__init__ only
# reads its ``__version__`` when this value is truthy, so it is presumably
# None when TensorFlow is not installed (behaviour of ``import_module`` --
# confirm against sympy.external).
tensorflow = import_module('tensorflow')
class TensorflowPrinter(ArrayPrinter, AbstractPythonCodePrinter):
    """
    Tensorflow printer which handles vectorized piecewise functions,
    logical operators, max/min, and relational operators.

    The optional ``tensorflow_version`` setting selects between older and
    newer TensorFlow API names (``tensorflow.select`` vs ``tensorflow.where``,
    ``tensorflow.matrix_transpose`` vs ``tensorflow.linalg.matrix_transpose``).
    When it is not given, the version of the importable ``tensorflow``
    module is used, if there is one.
    """
    printmethod = "_tensorflowcode"

    # SymPy class -> fully qualified TensorFlow callable emitted for it by
    # _print_Function (and the aliases bound to it below).
    mapping = {
        sympy.Abs: "tensorflow.math.abs",
        sympy.sign: "tensorflow.math.sign",

        # XXX May raise error for ints.
        sympy.ceiling: "tensorflow.math.ceil",
        sympy.floor: "tensorflow.math.floor",
        sympy.log: "tensorflow.math.log",
        sympy.exp: "tensorflow.math.exp",
        Sqrt: "tensorflow.math.sqrt",
        sympy.cos: "tensorflow.math.cos",
        sympy.acos: "tensorflow.math.acos",
        sympy.sin: "tensorflow.math.sin",
        sympy.asin: "tensorflow.math.asin",
        sympy.tan: "tensorflow.math.tan",
        sympy.atan: "tensorflow.math.atan",
        sympy.atan2: "tensorflow.math.atan2",
        # XXX Also may give NaN for complex results.
        sympy.cosh: "tensorflow.math.cosh",
        sympy.acosh: "tensorflow.math.acosh",
        sympy.sinh: "tensorflow.math.sinh",
        sympy.asinh: "tensorflow.math.asinh",
        sympy.tanh: "tensorflow.math.tanh",
        sympy.atanh: "tensorflow.math.atanh",

        sympy.re: "tensorflow.math.real",
        sympy.im: "tensorflow.math.imag",
        sympy.arg: "tensorflow.math.angle",

        # XXX May raise error for ints and complexes
        sympy.erf: "tensorflow.math.erf",
        sympy.loggamma: "tensorflow.math.lgamma",

        sympy.Eq: "tensorflow.math.equal",
        sympy.Ne: "tensorflow.math.not_equal",
        sympy.StrictGreaterThan: "tensorflow.math.greater",
        sympy.StrictLessThan: "tensorflow.math.less",
        sympy.LessThan: "tensorflow.math.less_equal",
        sympy.GreaterThan: "tensorflow.math.greater_equal",

        sympy.And: "tensorflow.math.logical_and",
        sympy.Or: "tensorflow.math.logical_or",
        sympy.Not: "tensorflow.math.logical_not",
        sympy.Max: "tensorflow.math.maximum",
        sympy.Min: "tensorflow.math.minimum",

        # Matrices
        sympy.MatAdd: "tensorflow.math.add",
        sympy.HadamardProduct: "tensorflow.math.multiply",
        sympy.Trace: "tensorflow.linalg.trace",

        # XXX May raise error for integer matrices.
        sympy.Determinant : "tensorflow.linalg.det",
    }

    _default_settings = dict(
        AbstractPythonCodePrinter._default_settings,
        tensorflow_version=None
    )

    def __init__(self, settings=None):
        super().__init__(settings)

        # An explicit ``tensorflow_version`` setting wins; otherwise fall
        # back to the installed module's version when it could be imported.
        version = self._settings['tensorflow_version']
        if version is None and tensorflow:
            version = tensorflow.__version__
        self.tensorflow_version = version

    def _print_Function(self, expr):
        """Print ``expr`` as a call to the TensorFlow function mapped to its type.

        Single-argument expressions become ``op(arg)``.  Expressions with
        several arguments are handed to ``_expand_fold_binary_op``
        (presumably yielding nested binary calls such as ``op(op(a, b), c)``).
        Types absent from ``mapping`` fall back to ``_print_Basic``.
        """
        op = self.mapping.get(type(expr), None)
        if op is None:
            return super()._print_Basic(expr)
        children = [self._print(arg) for arg in expr.args]
        if len(children) == 1:
            return "%s(%s)" % (
                self._module_format(op),
                children[0]
            )
        else:
            return self._expand_fold_binary_op(op, children)

    _print_Expr = _print_Function
    _print_Application = _print_Function
    _print_MatrixExpr = _print_Function
    # TODO: a better class structure would avoid this mess:
    _print_Relational = _print_Function
    _print_Not = _print_Function
    _print_And = _print_Function
    _print_Or = _print_Function
    _print_HadamardProduct = _print_Function
    _print_Trace = _print_Function
    _print_Determinant = _print_Function

    def _print_Inverse(self, expr):
        """Print a matrix inverse as ``tensorflow.linalg.inv(arg)``."""
        op = self._module_format('tensorflow.linalg.inv')
        return "{}({})".format(op, self._print(expr.arg))

    def _print_Transpose(self, expr):
        """Print a transpose using the API name for the targeted version.

        ``tensorflow.matrix_transpose`` is used for versions older than
        1.14, and ``tensorflow.linalg.matrix_transpose`` is used otherwise
        (including when the version is unknown).
        """
        version = self.tensorflow_version
        if version and version_tuple(version) < version_tuple('1.14'):
            op = self._module_format('tensorflow.matrix_transpose')
        else:
            op = self._module_format('tensorflow.linalg.matrix_transpose')
        return "{}({})".format(op, self._print(expr.arg))

    def _print_Derivative(self, expr):
        """Print a derivative as nested ``tensorflow.gradients`` calls.

        There is one call per differentiation variable, and element ``[0]``
        of each call's result is taken.

        Raises:
            NotImplementedError: if any entry of ``expr.variables`` is
                iterable (derivation by several variables at once).
        """
        variables = expr.variables
        if any(isinstance(i, Iterable) for i in variables):
            raise NotImplementedError("derivation by multiple variables is not supported")

        def unfold(expr, args):
            # Innermost call differentiates by the first variable; each outer
            # call differentiates the previous result by the next variable.
            if not args:
                return self._print(expr)
            return "%s(%s, %s)[0]" % (
                self._module_format("tensorflow.gradients"),
                unfold(expr, args[:-1]),
                self._print(args[-1]),
            )
        return unfold(expr.expr, variables)

    def _print_Piecewise(self, expr):
        """Print a Piecewise as nested ``where``/``select`` calls.

        ``tensorflow.select`` is used for versions older than 1.0, and
        ``tensorflow.where`` otherwise.  The first ``(expr, cond)`` pair is
        chosen where ``cond`` holds.  Elsewhere, the remaining pieces are
        printed recursively, and the value is ``0`` when no piece is left.
        """
        version = self.tensorflow_version
        if version and version_tuple(version) < version_tuple('1.0'):
            tensorflow_piecewise = "tensorflow.select"
        else:
            tensorflow_piecewise = "tensorflow.where"

        from sympy.functions.elementary.piecewise import Piecewise
        e, cond = expr.args[0].args
        if len(expr.args) == 1:
            return '{}({}, {}, {})'.format(
                self._module_format(tensorflow_piecewise),
                self._print(cond),
                self._print(e),
                0)

        return '{}({}, {}, {})'.format(
            self._module_format(tensorflow_piecewise),
            self._print(cond),
            self._print(e),
            self._print(Piecewise(*expr.args[1:])))

    def _print_Pow(self, expr):
        """Print ``base**exp``, using ``sqrt`` for an exponent of one half."""
        # XXX May raise error for
        # int**float or int**complex or float**complex
        base, exp = expr.args
        if exp == S.Half:
            return "{}({})".format(
                self._module_format("tensorflow.math.sqrt"), self._print(base))
        return "{}({}, {})".format(
            self._module_format("tensorflow.math.pow"),
            self._print(base), self._print(exp))

    def _print_MatrixBase(self, expr):
        """Print an explicit matrix as a nested-list TensorFlow tensor.

        The literal is wrapped in ``tensorflow.Variable`` when the matrix has
        free symbols, and in ``tensorflow.constant`` otherwise.
        """
        tensorflow_f = "tensorflow.Variable" if expr.free_symbols else "tensorflow.constant"
        rows = (
            "[" + ", ".join(self._print(entry) for entry in row) + "]"
            for row in expr.tolist()
        )
        data = "[" + ", ".join(rows) + "]"
        return "%s(%s)" % (
            self._module_format(tensorflow_f),
            data,
        )

    def _print_MatMul(self, expr):
        """Print a matrix product as folded ``tensorflow.linalg.matmul`` calls.

        Any scalar (non-matrix) factors are multiplied in front with ``*``.
        """
        from sympy.matrices.expressions import MatrixExpr
        # Partition by type in one pass: a scalar factor can never be a
        # MatrixExpr, so this matches the previous membership-based split
        # without its quadratic ``not in`` scan.
        mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)]
        args = [arg for arg in expr.args if not isinstance(arg, MatrixExpr)]
        if args:
            return "%s*%s" % (
                self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]),
                self._expand_fold_binary_op(
                    "tensorflow.linalg.matmul", mat_args)
            )
        else:
            return self._expand_fold_binary_op(
                "tensorflow.linalg.matmul", mat_args)

    def _print_MatPow(self, expr):
        """Print an integer matrix power as repeated ``tensorflow.linalg.matmul``.

        A negative integer exponent ``-k`` is printed as
        ``tensorflow.linalg.inv`` applied to the ``k``-th power.

        Raises:
            NotImplementedError: if the exponent is zero or not an integer.
                Previously such exponents produced an empty fold, which
                recursed without bound, or an obscure ``TypeError``.
        """
        exp = expr.exp
        if not (exp.is_Integer and exp != 0):
            raise NotImplementedError(
                "tensorflow printer only supports non-zero integer matrix "
                "powers, got exponent %s" % self._print(exp))
        power = self._expand_fold_binary_op(
            "tensorflow.linalg.matmul", [expr.base]*abs(int(exp)))
        if exp.is_negative:
            return "{}({})".format(
                self._module_format('tensorflow.linalg.inv'), power)
        return power

    def _print_CodeBlock(self, expr):
        """Print each statement of a CodeBlock on its own line."""
        # TODO: is this necessary?
        return "\n".join(self._print(subexpr) for subexpr in expr.args)

    # Module and function names, presumably consumed by the inherited
    # ArrayPrinter machinery (its implementation is not visible here).
    _module = "tensorflow"
    _einsum = "linalg.einsum"
    _add = "math.add"
    _transpose = "transpose"
    _ones = "ones"
    _zeros = "zeros"
def tensorflow_code(expr, **settings):
    """Return TensorFlow source code for the SymPy expression ``expr``.

    Any keyword arguments are forwarded as the settings dict of a fresh
    ``TensorflowPrinter`` (for example ``tensorflow_version``).
    """
    return TensorflowPrinter(settings).doprint(expr)