You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
3.1 KiB

from sympy.core.basic import Basic
from sympy.printing import pprint
import random
def interactive_traversal(expr):
"""Traverse a tree asking a user which branch to choose. """
RED, BRED = '\033[0;31m', '\033[1;31m'
GREEN, BGREEN = '\033[0;32m', '\033[1;32m'
YELLOW, BYELLOW = '\033[0;33m', '\033[1;33m' # noqa
BLUE, BBLUE = '\033[0;34m', '\033[1;34m' # noqa
MAGENTA, BMAGENTA = '\033[0;35m', '\033[1;35m'# noqa
CYAN, BCYAN = '\033[0;36m', '\033[1;36m' # noqa
END = '\033[0m'
def cprint(*args):
print("".join(map(str, args)) + END)
def _interactive_traversal(expr, stage):
if stage > 0:
print()
cprint("Current expression (stage ", BYELLOW, stage, END, "):")
print(BCYAN)
pprint(expr)
print(END)
if isinstance(expr, Basic):
if expr.is_Add:
args = expr.as_ordered_terms()
elif expr.is_Mul:
args = expr.as_ordered_factors()
else:
args = expr.args
elif hasattr(expr, "__iter__"):
args = list(expr)
else:
return expr
n_args = len(args)
if not n_args:
return expr
for i, arg in enumerate(args):
cprint(GREEN, "[", BGREEN, i, GREEN, "] ", BLUE, type(arg), END)
pprint(arg)
print()
if n_args == 1:
choices = '0'
else:
choices = '0-%d' % (n_args - 1)
try:
choice = input("Your choice [%s,f,l,r,d,?]: " % choices)
except EOFError:
result = expr
print()
else:
if choice == '?':
cprint(RED, "%s - select subexpression with the given index" %
choices)
cprint(RED, "f - select the first subexpression")
cprint(RED, "l - select the last subexpression")
cprint(RED, "r - select a random subexpression")
cprint(RED, "d - done\n")
result = _interactive_traversal(expr, stage)
elif choice in ('d', ''):
result = expr
elif choice == 'f':
result = _interactive_traversal(args[0], stage + 1)
elif choice == 'l':
result = _interactive_traversal(args[-1], stage + 1)
elif choice == 'r':
result = _interactive_traversal(random.choice(args), stage + 1)
else:
try:
choice = int(choice)
except ValueError:
cprint(BRED,
"Choice must be a number in %s range\n" % choices)
result = _interactive_traversal(expr, stage)
else:
if choice < 0 or choice >= n_args:
cprint(BRED, "Choice must be in %s range\n" % choices)
result = _interactive_traversal(expr, stage)
else:
result = _interactive_traversal(args[choice], stage + 1)
return result
return _interactive_traversal(expr, 0)