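# NOTE: the following lines are repository web-viewer residue captured along
# with the file; they are not part of the module and are commented out so that
# the file parses.
# You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 314 lines
# 10 KiB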

"""
Plotting (requires matplotlib)
"""
from colorsys import hsv_to_rgb, hls_to_rgb
from .libmp import NoConvergence
from .libmp.backend import xrange
class VisualizationMethods(object):
    """Holder class for the plotting functions defined in this module.

    The functions below take the context object as their first argument
    (``ctx``) and are attached to this class as methods at the end of the
    module.
    """

    # Exception types that the plotting routines catch (``except
    # ctx.plot_ignore``) while evaluating the user's function at a sample
    # point; such points are left out of the plot instead of aborting it.
    plot_ignore = (
        ValueError,
        ArithmeticError,
        ZeroDivisionError,
        NoConvergence,
    )
def _plot_segments(ctx, func, grid, singularities):
    """Sample *func* on *grid* and split the samples into drawable runs.

    Returns a list of non-empty segments. A segment is a list of
    ``(x, y)`` tuples (real values) or ``(x, re, im)`` tuples (values with a
    nonzero imaginary part); all entries are plain floats. A new segment is
    started whenever the function switches between real and complex values,
    whenever an interval between consecutive grid points contains one of
    *singularities*, and whenever evaluation raises one of
    ``ctx.plot_ignore`` or yields NaN / a magnitude above 1e300.
    """
    segments = []
    segment = []
    in_complex = False
    for i, xi in enumerate(grid):
        try:
            if i != 0:
                for sing in singularities:
                    if grid[i-1] <= sing and xi >= sing:
                        # Break the curve across a declared singularity; the
                        # exception is handled below like any bad sample.
                        raise ValueError
            v = func(xi)
            if ctx.isnan(v) or abs(v) > 1e300:
                raise ValueError
            if hasattr(v, "imag") and v.imag:
                point = (float(xi), float(v.real), float(v.imag))
                now_complex = True
            else:
                if hasattr(v, "real"):
                    v = v.real
                # Convert to float like the complex branch does, so the
                # plotting backend only ever receives plain floats.
                point = (float(xi), float(v))
                now_complex = False
            if now_complex != in_complex:
                # Real and complex runs are drawn differently, so they
                # must not share a segment.
                if segment:
                    segments.append(segment)
                segment = []
                in_complex = now_complex
            segment.append(point)
        except ctx.plot_ignore:
            if segment:
                segments.append(segment)
            segment = []
    if segment:
        segments.append(segment)
    return segments


def plot(ctx, f, xlim=[-5,5], ylim=None, points=200, file=None, dpi=None,
         singularities=[], axes=None):
    r"""
    Shows a simple 2D plot of a function `f(x)` or list of functions
    `[f_0(x), f_1(x), \ldots, f_n(x)]` over a given interval
    specified by *xlim*. Some examples::

        plot(lambda x: exp(x)*li(x), [1, 4])
        plot([cos, sin], [-4, 4])
        plot([fresnels, fresnelc], [-4, 4])
        plot([sqrt, cbrt], [-4, 4])
        plot(lambda t: zeta(0.5+t*j), [-20, 20])
        plot([floor, ceil, abs, sign], [-5, 5])

    Points where the function raises a numerical exception or
    returns an infinite value are removed from the graph.
    Singularities can also be excluded explicitly
    as follows (useful for removing erroneous vertical lines)::

        plot(cot, ylim=[-5, 5])   # bad
        plot(cot, ylim=[-5, 5], singularities=[-pi, 0, pi])  # good

    For parts where the function assumes complex values, the
    real part is plotted with dashes and the imaginary part
    is plotted with dots.

    Parameters:

    * *f* - a callable, or a tuple/list of callables, of one argument.
    * *xlim* - pair ``[a, b]``; the function is sampled from ``a`` towards
      ``b`` with step ``(b-a)/points`` using ``ctx.arange``.
    * *ylim* - optional pair giving the vertical axis limits.
    * *points* - number of sampling steps across *xlim*.
    * *file* - if given, a new figure is always created (any *axes*
      argument is ignored) and saved to this path with ``pylab.savefig``.
    * *dpi* - passed to ``pylab.savefig`` when *file* is given.
    * *singularities* - abscissas across which the curve is broken.
    * *axes* - existing axes object to draw into; when omitted a new
      figure is created and, unless *file* is given, shown.

    The list defaults of *xlim* and *singularities* are only read, never
    modified.

    .. note :: This function requires matplotlib (pylab).
    """
    if file:
        axes = None
    fig = None
    if not axes:
        import pylab
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    if not isinstance(f, (tuple, list)):
        f = [f]
    a, b = xlim
    colors = ['b', 'r', 'g', 'm', 'k']
    # The sample grid does not depend on the function being plotted, so it
    # is built once instead of once per function.
    grid = ctx.arange(a, b, (b-a)/float(points))
    for n, func in enumerate(f):
        c = colors[n % len(colors)]
        for segment in _plot_segments(ctx, func, grid, singularities):
            xs = [s[0] for s in segment]
            ys = [s[1] for s in segment]
            if len(segment[0]) == 3:
                # Complex run: dashed real part, dotted imaginary part.
                zs = [s[2] for s in segment]
                axes.plot(xs, ys, '--'+c, linewidth=3)
                axes.plot(xs, zs, ':'+c, linewidth=3)
            else:
                axes.plot(xs, ys, c, linewidth=3)
    axes.set_xlim([float(_) for _ in xlim])
    if ylim:
        axes.set_ylim([float(_) for _ in ylim])
    axes.set_xlabel('x')
    axes.set_ylabel('f(x)')
    axes.grid(True)
    if fig:
        # ``fig`` is only set when pylab was imported above.
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()
def default_color_function(ctx, z):
    """
    Map a complex value *z* to an RGB 3-tuple of floats for :func:`cplot`.

    The phase ``ctx.arg(z)`` selects the hue and the magnitude ``abs(z)``
    selects the lightness (0 for ``z = 0``, approaching 1 as ``abs(z)``
    grows); the saturation is fixed at 0.8. Infinite input gives white
    ``(1.0, 1.0, 1.0)`` and NaN input gives grey ``(0.5, 0.5, 0.5)``.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    # Work in plain floats: mixing the context's own pi constant into the
    # arithmetic would make the hue (and hence the returned components)
    # take on the context's numeric type instead of float.
    pi = float(ctx.pi)
    hue = (float(ctx.arg(z)) + pi) / (2*pi)
    # Shift by half a turn and wrap into [0, 1).
    hue = (hue + 0.5) % 1.0
    lightness = 1.0 - float(1/(1.0+abs(z)**0.3))
    return hls_to_rgb(hue, lightness, 0.8)
# Colour stops of the piecewise-linear gradient used by
# phase_color_function.  Each entry is (position, (r, g, b)), where the
# position is compared against arg(z)/pi clamped to [-1, 1].  The final stop
# at 2.0 lies beyond that range, so the lookup (which tests ``stop > w``)
# still finds an upper stop when w is exactly 1.0.
blue_orange_colors = [
    (-1.00, (0.0, 0.0, 0.0)),   # black
    (-0.95, (0.1, 0.2, 0.5)),   # dark blue
    (-0.50, (0.0, 0.5, 1.0)),   # blueish
    (-0.05, (0.4, 0.8, 0.8)),   # cyanish
    (0.00, (1.0, 1.0, 1.0)),    # white
    (0.05, (1.0, 0.9, 0.3)),    # yellowish
    (0.50, (0.9, 0.5, 0.0)),    # orangeish
    (0.95, (0.7, 0.1, 0.0)),    # redish
    (1.00, (0.0, 0.0, 0.0)),    # black
    (2.00, (0.0, 0.0, 0.0)),    # black (upper sentinel)
]
def phase_color_function(ctx, z):
    """
    Map a complex value *z* to an RGB 3-tuple that depends only on its phase.

    ``ctx.arg(z)`` is divided by pi, clamped to ``[-1, 1]`` and looked up in
    the ``blue_orange_colors`` gradient, interpolating linearly between the
    two neighbouring colour stops. Infinite input gives white
    ``(1.0, 1.0, 1.0)`` and NaN input gives grey ``(0.5, 0.5, 0.5)``.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    pi = 3.1415926535898
    phase = float(ctx.arg(z)) / pi
    phase = max(min(phase, 1.0), -1.0)
    # Walk consecutive pairs of stops and interpolate inside the first pair
    # whose upper stop lies strictly above the phase.
    neighbours = zip(blue_orange_colors, blue_orange_colors[1:])
    for (lo, (r_lo, g_lo, b_lo)), (hi, (r_hi, g_hi, b_hi)) in neighbours:
        if hi > phase:
            t = (phase-lo) / (hi-lo)
            return (r_lo+(r_hi-r_lo)*t,
                    g_lo+(g_hi-g_lo)*t,
                    b_lo+(b_hi-b_lo)*t)
def cplot(ctx, f, re=[-5,5], im=[-5,5], points=2000, color=None,
          verbose=False, file=None, dpi=None, axes=None):
    """
    Plots the given complex-valued function *f* over a rectangular part
    of the complex plane specified by the pairs of intervals *re* and *im*.
    For example::

        cplot(lambda z: z, [-2, 2], [-10, 10])
        cplot(exp)
        cplot(zeta, [0, 1], [0, 50])

    By default, the complex argument (phase) is shown as color (hue) and
    the magnitude is shown as brightness. You can also supply a
    custom color function (*color*). This function should take a
    complex number as input and return an RGB 3-tuple containing
    floats in the range 0.0-1.0.

    Alternatively, you can select a builtin color function by passing
    a string as *color*:

      * "default" - default color scheme
      * "phase" - a color scheme that only renders the phase of the function,
        with white for positive reals, black for negative reals, gold in the
        upper half plane, and blue in the lower half plane.

    To obtain a sharp image, the number of points may need to be
    increased to 100,000 or thereabout. Since evaluating the
    function that many times is likely to be slow, the 'verbose'
    option is useful to display progress.

    The grid has roughly *points* pixels in total, split between the two
    axes in proportion to the side lengths of the rectangle. Pixels where
    evaluation raises one of ``ctx.plot_ignore`` are painted grey. If
    *file* is given, a new figure is always created (any *axes* argument is
    ignored) and saved there with resolution *dpi*; otherwise a figure
    created here is shown, and a caller-supplied *axes* is only drawn into.

    .. note :: This function requires matplotlib (pylab).
    """
    if color is None or color == "default":
        color = ctx.default_color_function
    if color == "phase":
        color = ctx.phase_color_function
    import pylab
    if file:
        axes = None
    fig = None
    if not axes:
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    re_min, re_max = re
    im_min, im_max = im
    width = re_max - re_min
    height = im_max - im_min
    # Choose the pixel counts so that cols*rows is about *points* and
    # cols/rows follows the aspect ratio of the rectangle.
    cols = int(ctx.sqrt(points*width/height)+1)
    rows = int(ctx.sqrt(points*height/width)+1)
    re_axis = pylab.linspace(re_min, re_max, cols)
    im_axis = pylab.linspace(im_min, im_max, rows)
    # Note: we have to be careful to get the right rotation.
    # Test with these plots:
    #   cplot(lambda z: z if z.real < 0 else 0)
    #   cplot(lambda z: z if z.imag < 0 else 0)
    image = pylab.zeros((rows, cols, 3))
    for row, im_part in enumerate(im_axis):
        for col, re_part in enumerate(re_axis):
            point = ctx.mpc(re_part, im_part)
            try:
                rgb = color(f(point))
            except ctx.plot_ignore:
                rgb = (0.5, 0.5, 0.5)
            image[row, col] = rgb
        if verbose:
            print(str(row) + ' of ' + str(rows))
    extent = tuple(float(edge) for edge in (re_min, re_max, im_min, im_max))
    axes.imshow(image, extent=extent, origin='lower')
    axes.set_xlabel('Re(z)')
    axes.set_ylabel('Im(z)')
    if fig:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()
def splot(ctx, f, u=[-5,5], v=[-5,5], points=100, keep_aspect=True, \
          wireframe=False, file=None, dpi=None, axes=None):
    """
    Plots the surface defined by `f`.

    If `f` returns a single component, then this plots the surface
    defined by `z = f(x,y)` over the rectangular domain with
    `x = u` and `y = v`.

    If `f` returns three components, then this plots the parametric
    surface `x, y, z = f(u,v)` over the pairs of intervals `u` and `v`.

    For example, to plot a simple function::

        >>> from mpmath import *
        >>> f = lambda x, y: sin(x+y)*cos(y)
        >>> splot(f, [-pi,pi], [-pi,pi])    # doctest: +SKIP

    Plotting a donut::

        >>> r, R = 1, 2.5
        >>> f = lambda u, v: [r*cos(u), (R+r*sin(u))*cos(v), (R+r*sin(u))*sin(v)]
        >>> splot(f, [0, 2*pi], [0, 2*pi])    # doctest: +SKIP

    Parameters:

    * *points* - number of samples per parameter; either a single number
      used for both, or a pair ``(M, N)`` for `u` and `v` respectively.
    * *keep_aspect* - if true, the shorter coordinate ranges are padded
      symmetrically so that all three axes span the same length as the
      longest range of the plotted data.
    * *wireframe* - draw a wireframe instead of a filled surface.
    * *file*, *dpi* - if *file* is given, a new figure is always created
      (any *axes* argument is ignored) and saved there with ``pylab.savefig``.
    * *axes* - existing 3D axes to draw into; when omitted a new figure is
      created and, unless *file* is given, shown.

    .. note :: This function requires matplotlib (pylab) 0.98.5.3 or higher.
    """
    import pylab
    import mpl_toolkits.mplot3d as mplot3d
    if file:
        axes = None
    fig = None
    if not axes:
        fig = pylab.figure()
        axes = mplot3d.axes3d.Axes3D(fig)
        # Older matplotlib releases attach a new Axes3D to its figure
        # automatically; recent ones do not, which would leave the figure
        # empty. Attach it explicitly when it is missing.
        if axes not in fig.axes:
            fig.add_axes(axes)
    ua, ub = u
    va, vb = v
    if not isinstance(points, (list, tuple)):
        points = [points, points]
    M, N = points
    u = pylab.linspace(ua, ub, M)
    v = pylab.linspace(va, vb, N)
    x, y, z = [pylab.zeros((M, N)) for i in xrange(3)]
    # Running [min, max] of each coordinate. Start from an empty range so
    # the bounds describe the data itself; starting from [0, 0] would always
    # stretch them to include the origin and distort the aspect correction.
    inf = float('inf')
    xab, yab, zab = [[inf, -inf] for i in xrange(3)]
    for n in xrange(N):
        for m in xrange(M):
            fdata = f(ctx.convert(u[m]), ctx.convert(v[n]))
            try:
                # Parametric surface: f returned (x, y, z).
                x[m,n], y[m,n], z[m,n] = fdata
            except TypeError:
                # Height field: f returned a single value z = f(x, y).
                x[m,n], y[m,n], z[m,n] = u[m], v[n], fdata
            for c, cab in [(x[m,n], xab), (y[m,n], yab), (z[m,n], zab)]:
                if c < cab[0]:
                    cab[0] = c
                if c > cab[1]:
                    cab[1] = c
    for cab in (xab, yab, zab):
        if cab[0] > cab[1]:
            # No comparable sample was seen for this coordinate (empty grid
            # or only NaNs); fall back to a degenerate range at zero.
            cab[:] = [0, 0]
    if wireframe:
        axes.plot_wireframe(x, y, z, rstride=4, cstride=4)
    else:
        axes.plot_surface(x, y, z, rstride=4, cstride=4)
    axes.set_xlabel('x')
    axes.set_ylabel('y')
    axes.set_zlabel('z')
    if keep_aspect:
        dx, dy, dz = [cab[1] - cab[0] for cab in [xab, yab, zab]]
        maxd = max(dx, dy, dz)
        # Pad every axis that is shorter than the longest one, keeping the
        # data centred within the enlarged range.
        if dx < maxd:
            delta = maxd - dx
            axes.set_xlim3d(xab[0] - delta / 2.0, xab[1] + delta / 2.0)
        if dy < maxd:
            delta = maxd - dy
            axes.set_ylim3d(yab[0] - delta / 2.0, yab[1] + delta / 2.0)
        if dz < maxd:
            delta = maxd - dz
            axes.set_zlim3d(zab[0] - delta / 2.0, zab[1] + delta / 2.0)
    if fig:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()
# Attach the module-level plotting functions to VisualizationMethods under
# their own names, so that they can be called as methods with the instance
# passed as ``ctx``.
for _name, _func in (
    ("plot", plot),
    ("default_color_function", default_color_function),
    ("phase_color_function", phase_color_function),
    ("cplot", cplot),
    ("splot", splot),
):
    setattr(VisualizationMethods, _name, _func)
# Do not leave the loop variables behind in the module namespace.
del _name, _func