Tips and Tricks for Matplotlib

In [4]:
# Draw 1000 samples with randn and plot them against their index using
# matplotlib's default line style.
# NOTE(review): randn and plot are never imported in this cell -- presumably
# the notebook runs in pylab mode; confirm before a fresh-kernel run.
y = randn(1000)
plot(y)
Out[4]:
[<matplotlib.lines.Line2D at 0xa6df9cc>]
In [5]:
# Plot 1000 randn samples against their index. The format string 'g^'
# selects green upward-pointing triangle markers with no connecting line.
y = randn(1000)
plot(y, 'g^')
Out[5]:
[<matplotlib.lines.Line2D at 0xac592ac>]
In [7]:
# Two independent series of 1000 randn samples drawn on the same axes.
# ',' is the single-pixel marker, so the series appear as a red and a blue
# cloud of pixels.
y1, y2 = randn(1000), randn(1000)
plot(y1, 'r,',
     y2, 'b,')
Out[7]:
[<matplotlib.lines.Line2D at 0xaeb2d4c>,
 <matplotlib.lines.Line2D at 0xaeb732c>]
In [8]:
# Two independent series of 1000 randn samples drawn on the same axes:
# 'r-' is a solid red line, 'b-.' is a blue dash-dot line.
y1, y2 = randn(1000), randn(1000)
plot(y1, 'r-',
     y2, 'b-.')
Out[8]:
[<matplotlib.lines.Line2D at 0xaf864ac>,
 <matplotlib.lines.Line2D at 0xaf86a4c>]
In [10]:
# Histogram of 10000 randn samples split into 100 bins. The call returns
# the per-bin counts, the bin edges and the drawn patches, which the
# notebook echoes as the cell output.
x = randn(10000)
hist(x, bins=100)
Out[10]:
(array([  1,   0,   0,   0,   0,   0,   2,   3,   2,   2,   3,   4,   3,
         6,  11,   6,  17,  13,  19,  21,  26,  16,  33,  41,  44,  38,
        61,  54,  77,  85,  95, 110, 124, 116, 147, 134, 152, 144, 160,
       191, 177, 209, 200, 233, 212, 234, 251, 269, 263, 262, 306, 281,
       283, 255, 253, 265, 268, 253, 258, 248, 243, 238, 238, 202, 159,
       169, 176, 213, 137, 150, 123, 133, 115, 103,  79,  81,  69,  64,
        52,  32,  43,  36,  38,  31,  22,  21,  12,  16,  13,   6,   8,
         6,   7,   3,   6,   4,   4,   3,   2,   2]),
 array([-3.63910502, -3.56971213, -3.50031924, -3.43092635, -3.36153346,
       -3.29214057, -3.22274768, -3.15335479, -3.0839619 , -3.01456901,
       -2.94517613, -2.87578324, -2.80639035, -2.73699746, -2.66760457,
       -2.59821168, -2.52881879, -2.4594259 , -2.39003301, -2.32064012,
       -2.25124723, -2.18185434, -2.11246145, -2.04306857, -1.97367568,
       -1.90428279, -1.8348899 , -1.76549701, -1.69610412, -1.62671123,
       -1.55731834, -1.48792545, -1.41853256, -1.34913967, -1.27974678,
       -1.2103539 , -1.14096101, -1.07156812, -1.00217523, -0.93278234,
       -0.86338945, -0.79399656, -0.72460367, -0.65521078, -0.58581789,
       -0.516425  , -0.44703211, -0.37763922, -0.30824634, -0.23885345,
       -0.16946056, -0.10006767, -0.03067478,  0.03871811,  0.108111  ,
        0.17750389,  0.24689678,  0.31628967,  0.38568256,  0.45507545,
        0.52446833,  0.59386122,  0.66325411,  0.732647  ,  0.80203989,
        0.87143278,  0.94082567,  1.01021856,  1.07961145,  1.14900434,
        1.21839723,  1.28779012,  1.35718301,  1.42657589,  1.49596878,
        1.56536167,  1.63475456,  1.70414745,  1.77354034,  1.84293323,
        1.91232612,  1.98171901,  2.0511119 ,  2.12050479,  2.18989768,
        2.25929056,  2.32868345,  2.39807634,  2.46746923,  2.53686212,
        2.60625501,  2.6756479 ,  2.74504079,  2.81443368,  2.88382657,
        2.95321946,  3.02261235,  3.09200523,  3.16139812,  3.23079101,
        3.3001839 ]),
 <a list of 100 Patch objects>)
In [11]:
"""
Show how to make date plots in matplotlib using date tick locators and
formatters.  See major_minor_demo1.py for more information on
controlling major and minor ticks

All matplotlib date plotting is done by converting date instances into
days since 0001-01-01 UTC.  The conversion, tick locating and
formatting are done behind the scenes, so this is mostly transparent to
you.  The dates module provides several converter functions, such as
date2num and num2date.

"""
import datetime
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.cbook as cbook

# Tick helpers: major ticks on every year (labelled with the four-digit
# year) and minor ticks on every month.
yearsFmt = mdates.DateFormatter('%Y')
years = mdates.YearLocator()
months = mdates.MonthLocator()

# Load a numpy record array built from yahoo csv data (fields: date, open,
# close, volume, adj_close) that ships in the mpl-data/example directory.
# The date column holds python datetime.date objects in an object array.
datafile = cbook.get_sample_data('goog.npy')
r = np.load(datafile).view(np.recarray)

# Single axes showing the adjusted close price over time.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(r.date, r.adj_close)

# Install the tick locators/formatter on the x axis.
xaxis = ax.xaxis
xaxis.set_major_locator(years)
xaxis.set_major_formatter(yearsFmt)
xaxis.set_minor_locator(months)

# Pad the x range out to whole years: 1 January of the first year through
# 1 January of the year after the last sample.
first_year = r.date.min().year
last_year = r.date.max().year
datemin = datetime.date(first_year, 1, 1)
datemax = datetime.date(last_year + 1, 1, 1)
ax.set_xlim(datemin, datemax)


def price(x):
    """Format a y value as a dollar amount with two decimals."""
    return '$%1.2f' % x


# Formatters used for the coordinate read-out message box.
ax.format_xdata = mdates.DateFormatter('%Y-%m-%d')
ax.format_ydata = price
ax.grid(True)

# Rotate and right-align the x labels, and move the bottom of the axes up
# to make room for them.
fig.autofmt_xdate()

plt.show()
In [20]:
import pandas as pd
%cd 'data'
filename = 'worldcitiespop.txt'
data = pd.read_csv(filename)
plot(data.Longitude, data.Latitude, 'r,')
%cd..
/home/jeff/ipython/data
/home/jeff/ipython

In [22]:
# Taken from http://stackoverflow.com/questions/2369492/generate-a-heatmap-in-matplotlib-using-a-scatter-data-set
# This makes a 50x50 heatmap.
# If you want, say, 512x384, you can put bins=(512, 384) in the call to histogram2d.

import numpy as np
import matplotlib.pyplot as plt

# Generate some test data
x = np.random.randn(8873)
y = np.random.randn(8873)

# heatmap[i, j] counts the samples in x-bin i and y-bin j.
heatmap, xedges, yedges = np.histogram2d(x, y, bins=(50, 50))
# imshow extent order is (left, right, bottom, top).
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

plt.clf()
# histogram2d bins x along the first array axis, while imshow draws the
# first axis vertically starting at the top. Transpose so x runs
# horizontally and put the origin at the bottom so the picture agrees with
# `extent` (the same h.T / origin='lower' convention as the map cell below).
plt.imshow(heatmap.T, extent=extent, origin='lower')
plt.show()
In [23]:
import numpy as np
import matplotlib.pyplot as plt

# Generate some test data
x = np.random.randn(8873)
y = np.random.randn(8873)

# 512 bins along x and 384 bins along y; heatmap[i, j] counts the samples
# in x-bin i and y-bin j, so its shape is (512, 384).
heatmap, xedges, yedges = np.histogram2d(x, y, bins=(512, 384))
# imshow extent order is (left, right, bottom, top).
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

plt.clf()
# histogram2d bins x along the first array axis, while imshow draws the
# first axis vertically starting at the top. Without the transpose the 512
# x-bins would be drawn vertically under an x-labelled horizontal axis.
# Transpose and use a bottom-left origin so the image matches `extent`.
plt.imshow(heatmap.T, extent=extent, origin='lower')
plt.show()
In [1]:
# Sample y = 1 + 2*cos(5*x) at 1000 points over [0, 2*pi].
x = linspace(0, 2 * pi, 1000)
y = 1 + 2 * cos(5 * x)

# Left panel: the curve in Cartesian coordinates.
subplot(121)
plot(x, y)

# Right panel: the same samples with x as the angle and y as the radius.
subplot(122, polar=True)
polar(x, y)
Out[1]:
[<matplotlib.lines.Line2D at 0x9aabecc>]
In [2]:
from contextlib import closing

try:
    import urllib2                      # Python 2
except ImportError:
    import urllib.request as urllib2    # Python 3 exposes urlopen here

from PIL import Image

# Download the PNG and decode it into an array. closing() releases the HTTP
# response even when decoding raises; the original left it open.
with closing(urllib2.urlopen('http://ipython.rossant.net/squirrel.png')) as png:
    im = imread(png)

# Rescale to 0..255, cast to uint8 and round-trip through a PIL image so
# that `im` ends up as a uint8 array (and `img` as the PIL image).
# NOTE(review): the * 255 presumably assumes imread returned floats in
# [0, 1] for this PNG -- confirm for the matplotlib version in use.
img = Image.fromarray((im * 255).astype('uint8'))
im = array(img)
imshow(im)
Out[2]:
<matplotlib.image.AxesImage at 0x9b7a44c>
In [4]:
# NOTE(review): the star import is kept so every name it provides to later
# cells stays available; only kmeans and vq are used here.
from scipy.cluster.vq import *

# Flatten the first channel of the image into a 1-D list of observations.
M = im[:,:,0].ravel()
# Find 4 cluster centres, then assign every pixel to its nearest centre.
centroids, _ = kmeans(M, 4)
qnt, _ = vq(M, centroids)
# Map each pixel to its centre value and restore the image layout. The
# shape comes from the image itself instead of a hard-coded (300, 300), so
# the cell works for any image size.
clustered = centroids[reshape(qnt, im.shape[:2])]
# Four-colour palette, one entry per cluster.
cmap = matplotlib.colors.ListedColormap([(0,0.2,0.3),(0.85,0.1,0.13),(0.44,0.6,0.6),(1.0,0.9,0.65)])
imshow(clustered, cmap=cmap)
Out[4]:
<matplotlib.image.AxesImage at 0xa06966c>
In [6]:
import pandas as pd
%cd 'data'
filename = 'worldcitiespop.txt'
data = pd.read_csv(filename)
plot(data.Longitude, data.Latitude, 'r,')
%cd..
/home/jeff/ipython/data
/home/jeff/ipython

In [7]:
# Longitude/Latitude columns as a 2-column array (column 0 = longitude,
# column 1 = latitude). `.values` is used instead of `.as_matrix()`, which
# was deprecated and later removed from pandas; `.values` returns the same
# array on both old and new pandas versions.
locations = data[['Longitude','Latitude']].values
locations
Out[7]:
array([[  1.4666667,  42.4833333],
       [  1.5      ,  42.4666667],
       [  1.5      ,  42.4666667],
       ..., 
       [ 31.0105556, -17.7588889],
       [ 27.9333333, -20.0333333],
       [ 30.0333333, -20.3333333]])
In [8]:
# Pull out the Population column; the echoed output shows that most rows
# are NaN (no population recorded).
population = data.Population
population
Out[8]:
0       NaN
1       NaN
2       NaN
3       NaN
4       NaN
5       NaN
6     20430
7       NaN
8       NaN
9       NaN
10      NaN
11      NaN
12      NaN
13      NaN
14      NaN
...
3173943      NaN
3173944      NaN
3173945      NaN
3173946      NaN
3173947      NaN
3173948      NaN
3173949      NaN
3173950      NaN
3173951      NaN
3173952      NaN
3173953      NaN
3173954      NaN
3173955      NaN
3173956      NaN
3173957    79876
Name: Population, Length: 3173958
In [10]:
from mpl_toolkits.basemap import Basemap

# 'mill' projection bounded by latitudes -65..85 and longitudes -180..180
# (lower-left / upper-right corners).
m = Basemap(
    projection='mill',
    llcrnrlat=-65,
    urcrnrlat=85,
    llcrnrlon=-180,
    urcrnrlon=180,
)

# Calling the map object converts (longitude, latitude) arrays into map
# projection coordinates.
lon = locations[:, 0]
lat = locations[:, 1]
x, y = m(lon, lat)
(x, y)
/usr/lib/pymodules/python2.7/mpl_toolkits/__init__.py:2: UserWarning: Module dap was already imported from None, but /usr/lib/python2.7/dist-packages is being added to sys.path
  __import__('pkg_resources').declare_namespace(__name__)

Out[10]:
(array([ 20178163.19056664,  20181869.68266964,  20181869.68266964, ...,
        23463292.20268273,  23121120.85588821,  23354630.09188604]),
 array([ 13518771.05872486,  13516536.17283419,  13516536.17283419, ...,
         6495391.61961934,   6233407.11699923,   6198678.69292821]))
In [16]:
# Project the lower-left (lon=-180, lat=-65) and upper-right (lon=180,
# lat=85) corners -- the same bounds the Basemap was created with -- into
# map coordinates, and echo the upper-right corner.
x0, y0 = m(-180, -65)
x1, y1 = m(180, 85)
x1, y1
Out[16]:
(40030154.742485225, 21534769.36545709)
In [17]:
# Per-row weight: the recorded population, with 1000 substituted wherever
# the population is missing (NaN). Work on a copy so `population` itself is
# left untouched.
weights = population.copy()
missing = isnan(weights)
weights[missing] = 1000

# 500 edges per axis spanning the projected map corners.
x_edges = linspace(x0, x1, 500)
y_edges = linspace(y0, y1, 500)

# Population-weighted 2-D histogram of the projected positions; the
# returned bin edges are discarded.
h, _, _ = histogram2d(x, y, weights=weights, bins=(x_edges, y_edges))
h
Out[17]:
array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]])
In [21]:
# gaussian_filter is taken from scipy.ndimage directly; the
# scipy.ndimage.filters namespace is deprecated in recent SciPy releases.
import scipy.ndimage

# Replace empty bins by 1 so that log() yields 0 instead of -inf. This is
# done on a copy so the histogram `h` shown by the earlier cell is not
# silently altered.
h_filled = h.copy()
h_filled[h_filled == 0] = 1

# Transpose so the histogram's first (x) axis runs horizontally in the
# image, take the log of the counts and smooth with sigma = 1.
z = scipy.ndimage.gaussian_filter(log(h_filled.T), 1)

m.drawcoastlines()
# origin='lower' together with the projected corner extent lines the image
# up with the coastlines drawn by the same map object.
m.imshow(z, origin='lower', extent=[x0,x1,y0,y1], cmap=get_cmap('Reds'))
Out[21]:
<matplotlib.image.AxesImage at 0x1743b34c>

The %load magic lets you load code from URLs or local files:

In [1]:
# Fetch the integral_demo.py example script from the URL below with %load.
# NOTE(review): the following cell appears to hold the loaded script;
# confirm the URL is still reachable before relying on a re-run.
%load http://matplotlib.sourceforge.net/mpl_examples/pylab_examples/integral_demo.py
In [2]:
#!/usr/bin/env python

# implement the example graphs/integral from pyx
from pylab import *
from matplotlib.patches import Polygon


def func(x):
    """Cubic integrand (x-3)(x-5)(x-7) + 85 evaluated at x."""
    return (x-3)*(x-5)*(x-7)+85


# Limits of the integral area.
a, b = 2, 9

ax = subplot(111)

# The curve itself, sampled every 0.01 from 0 up to (not including) 10.
x = arange(0, 10, 0.01)
y = func(x)
plot(x, y, linewidth=1)

# Shaded region: start on the x axis at a, follow the curve from a to b,
# then drop back to the x axis at b.
ix = arange(a, b, 0.01)
iy = func(ix)
verts = [(a, 0)]
verts.extend(zip(ix, iy))
verts.append((b, 0))
poly = Polygon(verts, facecolor='0.8', edgecolor='k')
ax.add_patch(poly)

# Integral label centred horizontally between a and b.
text(0.5 * (a + b), 30, r"$\int_a^b f(x)\mathrm{d}x$",
     horizontalalignment='center', fontsize=20)

# Fix the view, label the axes via figure text, and show only the ticks
# for a and b.
axis([0,10, 0, 180])
figtext(0.9, 0.05, 'x')
figtext(0.1, 0.9, 'y')
ax.set_xticks((a,b))
ax.set_xticklabels(('a','b'))
ax.set_yticks([])
show()
In []: