"""
VisualizationMethods

Plotting (requires matplotlib)

"""

from colorsys import hsv_to_rgb, hls_to_rgb

from .libmp import NoConvergence

from .libmp.backend import xrange

class VisualizationMethods(object):
    """
    Holder class for the plotting routines defined in this module.

    The module-level functions ``plot``, ``default_color_function``,
    ``phase_color_function``, ``cplot`` and ``splot`` are assigned to this
    class as attributes at the bottom of the file, so their first
    argument (``ctx``) is the instance they are called on.
    """

    # Exception types that plot() and cplot() catch while evaluating the
    # user-supplied function; a sample that raises one of these is dropped
    # (plot) or painted grey (cplot) instead of aborting the whole plot.
    plot_ignore = (
        ValueError,
        ArithmeticError,
        ZeroDivisionError,
        NoConvergence,
    )

def _sample_segments(ctx, func, xs, singularities):
    """
    Evaluates *func* at every abscissa in *xs* and groups the samples
    into runs that can each be drawn as one unbroken curve.

    Returns a list of non-empty lists. A run of real samples holds
    ``(x, y)`` tuples; a run of samples with nonzero imaginary part holds
    ``(x, re, im)`` tuples. A run ends when

    * the function raises one of ``ctx.plot_ignore``,
    * the value is NaN or larger than 1e300 in magnitude,
    * a listed singularity lies between the previous abscissa and the
      current one (the current abscissa is then skipped), or
    * the values switch between real and complex.
    """
    segments = []
    segment = []
    in_complex = False
    for pos, xval in enumerate(xs):
        try:
            if pos:
                prev = xs[pos-1]
                for sing in singularities:
                    if prev <= sing and xval >= sing:
                        # Handled below like any other bad sample
                        raise ValueError
            v = func(xval)
            if ctx.isnan(v) or abs(v) > 1e300:
                raise ValueError
            if hasattr(v, "imag") and v.imag:
                sample = (float(xval), float(v.real), float(v.imag))
                now_complex = True
            else:
                if hasattr(v, "real"):
                    v = v.real
                sample = (float(xval), v)
                now_complex = False
            if now_complex != in_complex:
                # Real and complex runs are drawn with different line
                # styles, so they must not share a segment
                in_complex = now_complex
                if segment:
                    segments.append(segment)
                segment = []
            segment.append(sample)
        except ctx.plot_ignore:
            if segment:
                segments.append(segment)
            segment = []
    if segment:
        segments.append(segment)
    return segments

def plot(ctx, f, xlim=[-5,5], ylim=None, points=200, file=None, dpi=None,
    singularities=[], axes=None):
    r"""
    Shows a simple 2D plot of a function `f(x)` or list of functions
    `[f_0(x), f_1(x), \ldots, f_n(x)]` over a given interval
    specified by *xlim*. Some examples::

        plot(lambda x: exp(x)*li(x), [1, 4])
        plot([cos, sin], [-4, 4])
        plot([fresnels, fresnelc], [-4, 4])
        plot([sqrt, cbrt], [-4, 4])
        plot(lambda t: zeta(0.5+t*j), [-20, 20])
        plot([floor, ceil, abs, sign], [-5, 5])

    Points where the function raises a numerical exception or
    returns an infinite value are removed from the graph.
    Singularities can also be excluded explicitly
    as follows (useful for removing erroneous vertical lines)::

        plot(cot, ylim=[-5, 5])  # bad
        plot(cot, ylim=[-5, 5], singularities=[-pi, 0, pi])  # good

    For parts where the function assumes complex values, the
    real part is plotted with dashes and the imaginary part
    is plotted with dots.

    Other arguments:

    * *ylim* - if given, the range of the y axis.
    * *points* - the interval is sampled with step ``(b-a)/points``
      where ``a, b = xlim``.
    * *file* - if given, a new figure is always created (any *axes*
      argument is ignored) and saved to this file with resolution *dpi*.
    * *axes* - existing axes to draw on. When supplied (and *file* is
      not), nothing is shown or saved; otherwise a new figure is
      created and displayed.

    The default list arguments are only read, never modified.

    .. note :: This function requires matplotlib (pylab).
    """
    if file:
        axes = None
    fig = None
    if not axes:
        import pylab
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    if not isinstance(f, (tuple, list)):
        f = [f]
    a, b = xlim
    colors = ['b', 'r', 'g', 'm', 'k']
    # The sample points are the same for every function: compute them once
    xs = ctx.arange(a, b, (b-a)/float(points))
    for n, func in enumerate(f):
        c = colors[n % len(colors)]
        for segment in _sample_segments(ctx, func, xs, singularities):
            if not segment:
                continue
            px = [s[0] for s in segment]
            py = [s[1] for s in segment]
            if len(segment[0]) == 3:
                # Complex-valued run: dashed real part, dotted imaginary part
                pz = [s[2] for s in segment]
                axes.plot(px, py, '--'+c, linewidth=3)
                axes.plot(px, pz, ':'+c, linewidth=3)
            else:
                axes.plot(px, py, c, linewidth=3)
    axes.set_xlim([float(_) for _ in xlim])
    if ylim:
        axes.set_ylim([float(_) for _ in ylim])
    axes.set_xlabel('x')
    axes.set_ylabel('f(x)')
    axes.grid(True)
    # pylab is only bound when the figure was created above
    if fig is not None:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()

def default_color_function(ctx, z):
    """
    Maps the complex number *z* to an RGB 3-tuple: the argument of *z*
    selects the hue and the magnitude selects the lightness.

    Infinite values give white ``(1.0, 1.0, 1.0)`` and NaN gives grey
    ``(0.5, 0.5, 0.5)``. Otherwise the result of
    ``colorsys.hls_to_rgb`` with saturation 0.8 is returned.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    # Rescale the argument from [-pi, pi] to [0, 1], then rotate by half
    # a turn so that an argument of 0 (positive reals) gets hue 0
    a = (float(ctx.arg(z)) + ctx.pi) / (2*ctx.pi)
    a = (a + 0.5) % 1.0
    # Lightness is 0 at z = 0 and approaches 1 as abs(z) grows
    b = 1.0 - float(1/(1.0+abs(z)**0.3))
    return hls_to_rgb(a, b, 0.8)

# Colour stops used by phase_color_function: each entry is
# (position, (r, g, b)) with positions in increasing order. The position
# is the argument divided by pi, so -1.0 and 1.0 both correspond to the
# negative real axis and 0.0 to the positive real axis. The final stop
# at 2.0 lies beyond the clamped range, so the search always finds an
# upper neighbour.
blue_orange_colors = [
    (-1.0, (0.0, 0.0, 0.0)),
    (-0.95, (0.1, 0.2, 0.5)),   # dark blue
    (-0.5, (0.0, 0.5, 1.0)),    # blueish
    (-0.05, (0.4, 0.8, 0.8)),   # cyanish
    (0.0, (1.0, 1.0, 1.0)),
    (0.05, (1.0, 0.9, 0.3)),    # yellowish
    (0.5, (0.9, 0.5, 0.0)),     # orangeish
    (0.95, (0.7, 0.1, 0.0)),    # redish
    (1.0, (0.0, 0.0, 0.0)),
    (2.0, (0.0, 0.0, 0.0)),
]

def phase_color_function(ctx, z):
    """
    Maps the complex number *z* to an RGB 3-tuple that depends only on
    its argument, by linear interpolation between the stops listed in
    ``blue_orange_colors``.

    Infinite values give white ``(1.0, 1.0, 1.0)`` and NaN gives grey
    ``(0.5, 0.5, 0.5)``.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    approx_pi = 3.1415926535898
    # Normalised phase, clamped to the range covered by the colour table
    phase = float(ctx.arg(z)) / approx_pi
    phase = min(phase, 1.0)
    phase = max(phase, -1.0)
    # Walk over neighbouring pairs of stops and interpolate inside the
    # first pair whose upper position exceeds the phase
    for lower, upper in zip(blue_orange_colors, blue_orange_colors[1:]):
        if upper[0] > phase:
            pos_lo, rgb_lo = lower
            pos_hi, rgb_hi = upper
            t = (phase-pos_lo) / (pos_hi-pos_lo)
            return tuple(lo+(hi-lo)*t for lo, hi in zip(rgb_lo, rgb_hi))

def cplot(ctx, f, re=[-5,5], im=[-5,5], points=2000, color=None,
    verbose=False, file=None, dpi=None, axes=None):
    """
    Draws a colour image of the complex-valued function *f* on the
    rectangle of the complex plane given by the real interval *re* and
    the imaginary interval *im*. For example::

        cplot(lambda z: z, [-2, 2], [-10, 10])
        cplot(exp)
        cplot(zeta, [0, 1], [0, 50])

    Each value of *f* is turned into a colour by *color*, which is
    either a function taking a complex number and returning an RGB
    3-tuple of floats in the range 0.0-1.0, or one of the strings

      * "default" - default color scheme (also used when *color* is
        ``None``): the phase is shown as hue and the magnitude as
        brightness
      * "phase" - a color scheme that only renders the phase of the function,
        with white for positive reals, black for negative reals, gold in the
        upper half plane, and blue in the lower half plane.

    Points where the evaluation raises one of ``ctx.plot_ignore`` are
    painted grey.

    *points* is the approximate total number of samples. A sharp image
    may need 100,000 or thereabout; since that many evaluations are
    likely to be slow, *verbose* can be set to print a progress line
    after every row.

    If *file* is given a new figure is created and saved there with
    resolution *dpi*; otherwise the image is drawn on *axes* if
    supplied, or on a new figure that is then displayed.

    .. note :: This function requires matplotlib (pylab).
    """
    if color is None or color == "default":
        color = ctx.default_color_function
    elif color == "phase":
        color = ctx.phase_color_function
    import pylab
    # Saving to a file always uses a fresh figure, whatever axes were passed
    if file or not axes:
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    else:
        fig = None
    re_lo, re_hi = re
    im_lo, im_hi = im
    width = re_hi - re_lo
    height = im_hi - im_lo
    # Pick the grid so that cols*rows is about `points` and the cells
    # are roughly square
    cols = int(ctx.sqrt(points*width/height)+1)
    rows = int(ctx.sqrt(points*height/width)+1)
    re_grid = pylab.linspace(re_lo, re_hi, cols)
    im_grid = pylab.linspace(im_lo, im_hi, rows)
    # Note: we have to be careful to get the right rotation.
    # Test with these plots:
    #  cplot(lambda z: z if z.real < 0 else 0)
    #  cplot(lambda z: z if z.imag < 0 else 0)
    image = pylab.zeros((rows, cols, 3))
    for row, im_part in enumerate(im_grid):
        for col, re_part in enumerate(re_grid):
            z = ctx.mpc(re_part, im_part)
            try:
                rgb = color(f(z))
            except ctx.plot_ignore:
                rgb = (0.5, 0.5, 0.5)
            image[row,col] = rgb
        if verbose:
            print(str(row) + ' of ' + str(rows))
    extent = tuple(float(t) for t in (re_lo, re_hi, im_lo, im_hi))
    axes.imshow(image, extent=extent, origin='lower')
    axes.set_xlabel('Re(z)')
    axes.set_ylabel('Im(z)')
    if fig:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()

def splot(ctx, f, u=[-5,5], v=[-5,5], points=100, keep_aspect=True,
          wireframe=False, file=None, dpi=None, axes=None):
    """
    Plots the surface defined by `f`.

    If `f` returns a single component, then this plots the surface
    defined by `z = f(x,y)` over the rectangular domain with
    `x = u` and `y = v`.

    If `f` returns three components, then this plots the parametric
    surface `x, y, z = f(u,v)` over the pairs of intervals `u` and `v`.

    For example, to plot a simple function::

        >>> from mpmath import *
        >>> f = lambda x, y: sin(x+y)*cos(y)
        >>> splot(f, [-pi,pi], [-pi,pi])    # doctest: +SKIP

    Plotting a donut::

        >>> r, R = 1, 2.5
        >>> f = lambda u, v: [r*cos(u), (R+r*sin(u))*cos(v), (R+r*sin(u))*sin(v)]
        >>> splot(f, [0, 2*pi], [0, 2*pi])    # doctest: +SKIP

    Other arguments:

    * *points* - number of samples along each parameter, either a single
      number used for both or a pair ``[M, N]`` for `u` and `v`.
    * *keep_aspect* - if true, the axes that span less than the widest
      one are padded symmetrically so that all three cover the same
      length.
    * *wireframe* - draw a wireframe instead of a filled surface.
    * *file* - if given, a new figure is always created (any *axes*
      argument is ignored) and saved to this file with resolution *dpi*.
    * *axes* - existing 3D axes to draw on. When supplied (and *file* is
      not), nothing is shown or saved; otherwise a new figure is
      created and displayed.

    .. note :: This function requires matplotlib (pylab) with the
       mplot3d toolkit.
    """
    import pylab
    # Imported for its side effect: makes the '3d' projection available
    import mpl_toolkits.mplot3d as mplot3d
    if file:
        axes = None
    fig = None
    if not axes:
        fig = pylab.figure()
        # Newer matplotlib releases no longer attach a directly
        # constructed Axes3D(fig) to the figure, which leaves the figure
        # empty; going through add_subplot always registers the axes.
        axes = fig.add_subplot(111, projection='3d')
    ua, ub = u
    va, vb = v
    if not isinstance(points, (list, tuple)):
        points = [points, points]
    M, N = points
    u = pylab.linspace(ua, ub, M)
    v = pylab.linspace(va, vb, N)
    x, y, z = [pylab.zeros((M, N)) for i in range(3)]
    # Running [min, max] of each coordinate. Starting from +/-inf (rather
    # than 0) keeps the origin out of the bounding box unless the
    # surface actually reaches it.
    inf = float('inf')
    xab, yab, zab = [[inf, -inf] for i in range(3)]
    for n, vn in enumerate(v):
        for m, um in enumerate(u):
            fdata = f(ctx.convert(um), ctx.convert(vn))
            try:
                x[m,n], y[m,n], z[m,n] = fdata
            except TypeError:
                # A single value: treat it as the height over (u, v)
                x[m,n], y[m,n], z[m,n] = um, vn, fdata
            for c, cab in [(x[m,n], xab), (y[m,n], yab), (z[m,n], zab)]:
                if c < cab[0]:
                    cab[0] = c
                if c > cab[1]:
                    cab[1] = c
    for cab in (xab, yab, zab):
        if cab[0] > cab[1]:
            # No comparable sample was seen (empty grid or only NaN)
            cab[0] = cab[1] = 0
    if wireframe:
        axes.plot_wireframe(x, y, z, rstride=4, cstride=4)
    else:
        axes.plot_surface(x, y, z, rstride=4, cstride=4)
    axes.set_xlabel('x')
    axes.set_ylabel('y')
    axes.set_zlabel('z')
    if keep_aspect:
        dx, dy, dz = [cab[1] - cab[0] for cab in [xab, yab, zab]]
        maxd = max(dx, dy, dz)
        # Widen every shorter axis around its own data range
        if dx < maxd:
            delta = maxd - dx
            axes.set_xlim3d(xab[0] - delta / 2.0, xab[1] + delta / 2.0)
        if dy < maxd:
            delta = maxd - dy
            axes.set_ylim3d(yab[0] - delta / 2.0, yab[1] + delta / 2.0)
        if dz < maxd:
            delta = maxd - dz
            axes.set_zlim3d(zab[0] - delta / 2.0, zab[1] + delta / 2.0)
    if fig:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()

# Attach the module-level plotting functions to VisualizationMethods
# under their own names, so that they can be called as methods.
for _func in (plot, default_color_function, phase_color_function,
              cplot, splot):
    setattr(VisualizationMethods, _func.__name__, _func)
del _func

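# NOTE(review): the lines below are not part of the module; they look like a
# web-page footer captured together with the source (translated from Chinese).
# Confirm and remove.
#   Discussion on this site represents only the personal views of its users
#   and not the position of the site.
#   Feedback and suggestions are welcome; please contact us by e-mail.
#   2024 (c) Cheshi (车视网)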