Figure 3.2

Joint and conditional probabilities

In [1]:
%pylab inline
Populating the interactive namespace from numpy and matplotlib

In [2]:
%load http://www.astroml.org/_downloads/fig_conditional_probability.py
In [4]:
"""
Joint and Conditional Probabilities
-----------------------------------
Figure 3.2.

An example of a two-dimensional probability distribution. The color-coded
panel shows p(x, y). The two panels to the left and below show marginal
distributions in x and y (see eq. 3.8). The three panels to the right show
the conditional probability distributions p(x|y) (see eq. 3.7) for three
different values of y (as marked in the left panel).
"""
# Author: Jake VanderPlas
# License: BSD
#   The figure produced by this code is published in the textbook
#   "Statistics, Data Mining, and Machine Learning in Astronomy" (2013)
#   For more information, see http://astroML.github.com
#   To report a bug or issue, use the following forum:
#    https://groups.google.com/forum/#!forum/astroml-general
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.ticker import NullFormatter, NullLocator, MultipleLocator

#----------------------------------------------------------------------
# This function adjusts matplotlib settings for a uniform feel in the textbook.
# Note that with usetex=True, fonts are rendered with LaTeX.  This may
# result in an error if LaTeX is not installed on your system.  In that case,
# you can set usetex to False.
from astroML.plotting import setup_text_plots
setup_text_plots(fontsize=8, usetex=True)

def banana_distribution(N=10000):
    """This generates random points in a banana shape"""
    # create a truncated normal distribution
    # use N here (the published code hardcoded 10000, so N was ignored)
    theta = np.random.normal(0, np.pi / 8, N)
    theta[theta >= np.pi / 4] /= 2
    theta[theta <= -np.pi / 4] /= 2
    # define the curve parametrically
    r = np.sqrt(1. / abs(np.cos(theta) ** 2 - np.sin(theta) ** 2))
    r += np.random.normal(0, 0.08, size=N)
    x = r * np.cos(theta + np.pi / 4)
    y = r * np.sin(theta + np.pi / 4)
    return (x, y)


#------------------------------------------------------------
# Generate the data and compute the normalized 2D histogram
np.random.seed(1)
x, y = banana_distribution(10000)

Ngrid = 41
grid = np.linspace(0, 2, Ngrid + 1)

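H, xbins, ybins = np.histogram2d(x, y, grid)
# histogram2d puts x along axis 0; transpose so rows index y, as imshow,
# H.sum(0) -> p(x) and the H[iy] slices below all assume
H = H.T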
H /= np.sum(H)

#------------------------------------------------------------
# plot the result
fig = plt.figure(figsize=(10, 5))

# define axes
ax_Pxy = plt.axes((0.2, 0.34, 0.27, 0.52))
ax_Px = plt.axes((0.2, 0.14, 0.27, 0.2))
ax_Py = plt.axes((0.1, 0.34, 0.1, 0.52))
ax_cb = plt.axes((0.48, 0.34, 0.01, 0.52))
ax_Px_y = [plt.axes((0.65, 0.62, 0.32, 0.23)),
           plt.axes((0.65, 0.38, 0.32, 0.23)),
           plt.axes((0.65, 0.14, 0.32, 0.23))]

# set axis label formatters
ax_Px_y[0].xaxis.set_major_formatter(NullFormatter())
ax_Px_y[1].xaxis.set_major_formatter(NullFormatter())

ax_Pxy.xaxis.set_major_formatter(NullFormatter())
ax_Pxy.yaxis.set_major_formatter(NullFormatter())

ax_Px.yaxis.set_major_formatter(NullFormatter())
ax_Py.xaxis.set_major_formatter(NullFormatter())

# draw the joint probability
plt.axes(ax_Pxy)
H *= 1000
plt.imshow(H, interpolation='nearest', origin='lower', aspect='auto',
           extent=[0, 2, 0, 2], cmap=plt.cm.binary)

cb = plt.colorbar(cax=ax_cb)
cb.set_label('$p(x, y)$')
plt.text(0, 1.02, r'$\times 10^{-3}$',
         transform=ax_cb.transAxes)

# draw p(x) distribution
ax_Px.plot(xbins[1:], H.sum(0), '-k', drawstyle='steps')

# draw p(y) distribution
ax_Py.plot(H.sum(1), ybins[1:], '-k', drawstyle='steps')

# define axis limits
ax_Pxy.set_xlim(0, 2)
ax_Pxy.set_ylim(0, 2)
ax_Px.set_xlim(0, 2)
ax_Py.set_ylim(0, 2)

# label axes
ax_Pxy.set_xlabel('$x$')
ax_Pxy.set_ylabel('$y$')
ax_Px.set_xlabel('$x$')
ax_Px.set_ylabel('$p(x)$')
ax_Px.yaxis.set_label_position('right')
ax_Py.set_ylabel('$y$')
ax_Py.set_xlabel('$p(y)$')
ax_Py.xaxis.set_label_position('top')


# draw marginal probabilities
iy = [3 * Ngrid // 4, Ngrid // 2, Ngrid // 4]  # integer division: these are array indices
colors = 'rgc'
axis = ax_Pxy.axis()
for i in range(3):
    # overplot range on joint probability
    ax_Pxy.plot([0, 2, 2, 0],
                [ybins[iy[i] + 1], ybins[iy[i] + 1],
                 ybins[iy[i]], ybins[iy[i]]], c=colors[i], lw=1)
    Px_y = H[iy[i]] / H[iy[i]].sum()
    ax_Px_y[i].plot(xbins[1:], Px_y, drawstyle='steps', c=colors[i])
    ax_Px_y[i].yaxis.set_major_formatter(NullFormatter())
    ax_Px_y[i].set_ylabel('$p(x | %.1f)$' % ybins[iy[i]])
ax_Pxy.axis(axis)

ax_Px_y[2].set_xlabel('$x$')

ax_Pxy.set_title('Joint Probability')
ax_Px_y[0].set_title('Conditional Probability')

plt.show()

The modified figure

This figure reminds me of Figure 2.2 of MacKay's book Information Theory, Inference, and Learning Algorithms (http://www.inference.phy.cam.ac.uk/itprnn/book.pdf). There is a lot going on in this figure:

  • How to make an arbitrary, random 2D distribution, $p(x,y)$
  • How to make a 2D histogram
  • How to share axes
  • How to put in a color bar
  • Complex subplot formatting
  • Discrete integrals along one dimension of an array
  • Concept of joint and conditional probabilities
  • Notation for joint probability and conditional probability
In []:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.ticker import NullFormatter, NullLocator, MultipleLocator

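First, let's make the banana-shaped distribution that will serve as $p(x,y)$. It starts from a normal distribution in an angle: $$\theta \sim \mathcal{N}(0, \sigma), \qquad \sigma = \pi/8$$ Samples in the wings, $|\theta| \geq 2\sigma = \pi/4$, are divided by 2. Strictly speaking this is not a truncation. No samples are thrown away; the wings are folded inward toward the center.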
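The following is a quick numerical check of what the halving step does. It is a sketch that assumes NumPy's newer Generator API (np.random.default_rng) in place of the legacy np.random.normal calls used in this notebook. The helper name halve_wings is hypothetical.

Every sample with |θ| ≥ π/4 is divided by 2, so the sample count never changes. The only samples still beyond π/4 afterwards are those that started beyond π/2 (4σ, so very rare).

```python
import numpy as np

def halve_wings(theta, cut):
    """Return a copy of theta with every |value| >= cut divided by 2 (hypothetical helper)."""
    out = theta.copy()
    wings = np.abs(out) >= cut
    out[wings] /= 2
    return out

rng = np.random.default_rng(1)
sigma = np.pi / 8
theta_raw = rng.normal(0, sigma, 10000)
theta = halve_wings(theta_raw, cut=2 * sigma)

# No samples are dropped; the wings are folded inward instead.
print(theta.size)
# Samples still beyond the cut are exactly those that started beyond pi/2 (4 sigma).
print(np.sum(np.abs(theta) >= np.pi / 4), np.sum(np.abs(theta_raw) >= np.pi / 2))
```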

In [27]:
theta = np.random.normal(0, np.pi / 8, 10000)
theta1 = np.random.normal(0, np.pi / 8, 10000)  # unmodified normal, for comparison
theta[theta >= np.pi / 4] /= 2
theta[theta <= -np.pi / 4] /= 2

fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(111)
text_kwargs = dict(fontsize=22)

# density=True replaces the old normed=True keyword, which newer matplotlib removed
ax.hist(theta1, 20, histtype='stepfilled', fc='#FF0000', density=True, alpha=0.4)
ax.hist(theta, 20, histtype='stepfilled', fc='#CCCCCC', density=True, alpha=0.8)

ax.set_xlim(-np.pi, np.pi)

ax.set_title('Truncated Normal, half the wings', **text_kwargs)
ax.set_xlabel('$\\theta$', **text_kwargs)
ax.set_ylabel('$p(\\theta)$', **text_kwargs)
Out[27]:
<matplotlib.text.Text at 0x10e13acd0>

Next, make a parametric curve: $$r = \sqrt{\frac{1}{\left|\cos^2\theta - \sin^2\theta\right|}} + \epsilon, \qquad \epsilon \sim \mathcal{N}(0, 0.08)$$ $$x = r\cos(\theta + \pi/4), \qquad y = r\sin(\theta + \pi/4)$$
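The radius formula simplifies with the double-angle identity, and the π/4 rotation turns the noise-free curve into a simple hyperbola. The short derivation below ignores the noise term ε. It also assumes |θ| < π/4, so that cos 2θ > 0; the wing halving makes this true for almost every sample.

```latex
\begin{aligned}
r^2 &= \frac{1}{\left|\cos^2\theta - \sin^2\theta\right|} = \frac{1}{\left|\cos 2\theta\right|}, \\
x\,y &= r^2 \cos(\theta + \pi/4)\,\sin(\theta + \pi/4)
      = \tfrac{1}{2}\, r^2 \sin(2\theta + \pi/2)
      = \tfrac{1}{2}\, r^2 \cos 2\theta, \\
x\,y &= \tfrac{1}{2} \quad \text{whenever } \cos 2\theta > 0, \text{ i.e. } |\theta| < \pi/4 .
\end{aligned}
```

So the banana is the first-quadrant branch of the hyperbola xy = 1/2, smeared radially by ε. θ controls where along the branch a point lands, and r grows without bound as |θ| approaches π/4.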

In [52]:
# define the curve parametrically
r = np.sqrt(1. / abs(np.cos(theta) ** 2 - np.sin(theta) ** 2))
r += np.random.normal(0, 0.08, size=10000)
x = r * np.cos(theta+np.pi/4)
y = r * np.sin(theta+np.pi/4)
np.shape((x, y))
Out[52]:
(2, 10000)

I don't fully appreciate the detailed interplay of the terms in the parametric curve, but the basic gist is there: you have a distribution on $\theta$. You throw that into a function that calculates amplitudes based on $\theta$. You noise up the amplitudes. The positions $(x,y)$ depend on the (slightly noisy) amplitude, $r$, and on trigonometric functions of $\theta$.
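One way to see the interplay is to switch the noise off. Since r² = 1/|cos 2θ| and x·y = ½ r² cos 2θ, every point with |θ| < π/4 lands exactly on the hyperbola xy = 1/2. The noise ε then just smears points radially around that curve. The sketch below checks this numerically, and also checks that flipping the sign of θ swaps x and y. The helper name banana_curve is hypothetical, not part of astroML.

```python
import numpy as np

def banana_curve(theta):
    """Noise-free version of the parametric curve above (hypothetical helper)."""
    r = np.sqrt(1. / np.abs(np.cos(theta) ** 2 - np.sin(theta) ** 2))
    x = r * np.cos(theta + np.pi / 4)
    y = r * np.sin(theta + np.pi / 4)
    return x, y

theta = np.linspace(-0.7, 0.7, 15)   # stays inside (-pi/4, pi/4)
x, y = banana_curve(theta)
print(np.allclose(x * y, 0.5))        # True: the points lie on xy = 1/2

# theta = 0 is the tip of the banana, where x = y = 1/sqrt(2)
x0, y0 = banana_curve(np.array([0.0]))

# Flipping the sign of theta swaps x and y: the banana is symmetric about y = x
xs, ys = banana_curve(-theta)
print(np.allclose(xs, y) and np.allclose(ys, x))
```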

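One puzzle: how does the function include the desired number of samples, $N=10000$? $N$ is an argument of banana_distribution(N=10000), but in the published astroML code both random draws hardcode 10000, so $N$ is silently ignored. Passing N to those two draws makes it a real, user-definable parameter.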
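Here is a minimal sketch of a generator that actually honors N. It follows the same recipe as astroML's banana_distribution, but uses N in both random draws. Two things here are my assumptions and not part of astroML: the function name banana_sample, and the optional rng argument (NumPy's Generator API), which makes results reproducible without touching the global seed.

```python
import numpy as np

def banana_sample(N=10000, rng=None):
    """Draw N points from the banana-shaped distribution (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    # normal in theta, with the wings beyond +/- pi/4 folded inward by halving
    theta = rng.normal(0, np.pi / 8, N)
    theta[theta >= np.pi / 4] /= 2
    theta[theta <= -np.pi / 4] /= 2
    # noisy radius along the parametric curve
    r = np.sqrt(1. / np.abs(np.cos(theta) ** 2 - np.sin(theta) ** 2))
    r += rng.normal(0, 0.08, size=N)
    x = r * np.cos(theta + np.pi / 4)
    y = r * np.sin(theta + np.pi / 4)
    return x, y

x, y = banana_sample(500, rng=np.random.default_rng(1))
print(x.shape, y.shape)   # (500,) (500,)
```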

Recall that all of this is merely to generate synthetic data. In reality the data comes down from heaven and we just deal with it.

Let's show just the scatter plot, and then the "Hess diagram" (the 2D histogram), without any of the other stuff.

In [53]:
Ngrid = 41
grid = np.linspace(0, 2, Ngrid + 1)

H, xbins, ybins = np.histogram2d(x, y, grid)
H /= np.sum(H)

#------------------------------------------------------------
# plot the result
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(121)
#For clarity, only show the first 1000 data points (10% of the data)
plt.scatter(x[0:1000],y[0:1000],s=20, c='#ff0000', alpha=0.1)
plt.xlim((0, 2))
plt.ylim((0, 2))

ax = fig.add_subplot(122)
# histogram2d puts x along axis 0 of H; transpose so x runs horizontally in imshow
plt.imshow(H.T, interpolation='nearest', origin='lower', aspect='auto',
           extent=[0, 2, 0, 2], cmap=plt.cm.binary)
plt.show()
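One detail of np.histogram2d(x, y, bins) is worth checking. The returned H is indexed H[i, j], with i running over x bins and j over y bins. imshow, however, draws the first index (rows) vertically. Passing H straight to imshow(..., origin='lower') therefore puts x on the vertical axis, and H.T gives the conventional orientation. For this banana, which is symmetric under swapping x and y, the two pictures are statistically the same. A tiny check with a single point:

```python
import numpy as np

edges = np.linspace(0, 2, 5)          # 4 bins of width 0.5 on each axis
H, xbins, ybins = np.histogram2d([1.75], [0.25], bins=edges)

# The point has x in the last bin (index 3) and y in the first bin (index 0).
print(np.argwhere(H))                 # [[3 0]] -> H[x_bin, y_bin]
print(np.argwhere(H.T))               # [[0 3]] -> rows index y, as imshow expects
```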

OK, the next step is to figure out the complex figure layout and how to share axes. The key idea for the layout is the rect argument to plt.axes(). Called with no argument, plt.axes() just gives a full subplot filling the figure. But if you give plt.axes() a 4-element tuple (left, bottom, width, height) in figure-fraction coordinates (0 to 1), it places the axes in exactly that rectangle.
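Below is a minimal sketch of the rect idea. It assumes the non-interactive Agg backend so it runs headless, and uses fig.add_axes, which takes the same (left, bottom, width, height) figure fractions. As far as I know, plt.axes hands a 4-tuple on to the current figure's add_axes.

Genuinely shared axes are a separate feature: the sharex/sharey keywords tie the limits of two axes together. The astroML figure gets away without them by giving stacked rectangles the same left and width (or bottom and height) and setting the same limits by hand.

```python
import matplotlib
matplotlib.use("Agg")                 # headless backend (assumption, for running as a script)
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6, 4))

# Main panel, plus a strip directly below it with the same left edge and width.
ax_main = fig.add_axes((0.15, 0.40, 0.70, 0.50))
# sharex ties the x-limits together: changing one changes the other
ax_below = fig.add_axes((0.15, 0.10, 0.70, 0.25), sharex=ax_main)

ax_main.set_xlim(0, 2)
print(ax_below.get_xlim())            # same limits as ax_main, inherited through sharex
print(ax_main.get_position().bounds)  # the rect, in figure-fraction coordinates
```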

In [66]:
fig = plt.figure(figsize=(10, 5))

# define axes
ax_Pxy = plt.axes((0.2, 0.34, 0.27, 0.52))
ax_Px = plt.axes((0.2, 0.14, 0.27, 0.2))
ax_Py = plt.axes((0.1, 0.34, 0.1, 0.52))
ax_cb = plt.axes((0.48, 0.34, 0.01, 0.52))
ax_Px_y = [plt.axes((0.65, 0.62, 0.32, 0.23)),
           plt.axes((0.65, 0.38, 0.32, 0.23)),
           plt.axes((0.65, 0.14, 0.32, 0.23))]

# set axis label formatters
ax_Px_y[0].xaxis.set_major_formatter(NullFormatter())
ax_Px_y[1].xaxis.set_major_formatter(NullFormatter())

ax_Pxy.xaxis.set_major_formatter(NullFormatter())
ax_Pxy.yaxis.set_major_formatter(NullFormatter())

''' (MGS edit) Uncomment to get rid of color bar tick weirdness.'''
#ax_cb.yaxis.set_major_formatter(NullFormatter())
#ax_cb.xaxis.set_major_formatter(NullFormatter())
'''--------------------------------------------------'''

ax_Px.yaxis.set_major_formatter(NullFormatter())
ax_Py.xaxis.set_major_formatter(NullFormatter())

plt.show()
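The sketch below shows the difference between the two ticker helpers imported at the top. NullFormatter keeps the tick marks but returns an empty string for every label. NullLocator removes the tick positions themselves. The astroML figure uses NullFormatter, so stacked panels keep aligned tick marks without repeating the numbers. The Agg backend is again an assumption for headless runs.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter, NullLocator

fig, (ax_a, ax_b) = plt.subplots(1, 2)

ax_a.xaxis.set_major_formatter(NullFormatter())   # ticks stay, labels go blank
ax_b.xaxis.set_major_locator(NullLocator())       # no tick positions at all

fig.canvas.draw()                                 # resolve ticks before inspecting them
print(len(ax_a.get_xticks()), [t.get_text() for t in ax_a.get_xticklabels()])
print(len(ax_b.get_xticks()))                     # 0
```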
In [81]:
# plot the result
fig = plt.figure(figsize=(12, 6))

# define axes
ax_Pxy = plt.axes((0.2, 0.34, 0.27, 0.52))
ax_Px = plt.axes((0.2, 0.14, 0.27, 0.2))
ax_Py = plt.axes((0.1, 0.34, 0.1, 0.52))
ax_cb = plt.axes((0.48, 0.34, 0.01, 0.52))
ax_Px_y = [plt.axes((0.65, 0.62, 0.32, 0.23)),
           plt.axes((0.65, 0.38, 0.32, 0.23)),
           plt.axes((0.65, 0.14, 0.32, 0.23))]

# set axis label formatters
ax_Px_y[0].xaxis.set_major_formatter(NullFormatter())
ax_Px_y[1].xaxis.set_major_formatter(NullFormatter())

ax_Pxy.xaxis.set_major_formatter(NullFormatter())
ax_Pxy.yaxis.set_major_formatter(NullFormatter())

ax_Px.yaxis.set_major_formatter(NullFormatter())
ax_Py.xaxis.set_major_formatter(NullFormatter())

# draw the joint probability
plt.axes(ax_Pxy)
H *= 1000
plt.imshow(H, interpolation='nearest', origin='lower', aspect='auto',
           extent=[0, 2, 0, 2], cmap=plt.cm.binary)

cb = plt.colorbar(cax=ax_cb)
cb.set_label('$p(x, y)$')
plt.text(0, 1.02, r'$\times 10^{-3}$',
         transform=ax_cb.transAxes) 
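''' The transform argument says the coordinates (0, 1.02) are in the
    colorbar axes' own frame (ax_cb.transAxes): x=0 is its left edge and
    y=1.02 is just above its top, so the text sits on top of the colorbar
    no matter what data values it shows.
    -gully'''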

# draw p(x) distribution
ax_Px.plot(xbins[1:], H.sum(0), '-k', drawstyle='steps')

# draw p(y) distribution
ax_Py.plot(H.sum(1), ybins[1:], '-k', drawstyle='steps')

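''' Summing along an axis is as easy as H.sum(axis): the argument is the
    dimension that gets summed away.  np.histogram2d(x, y) puts x along
    axis 0, so strictly H.sum(0) is the marginal in y and H.sum(1) the
    marginal in x.  The plot uses them the other way around; it still looks
    right only because this banana is symmetric under swapping x and y.
    -gully'''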

print('shape H: ', np.shape(H))
print('shape H.sum(0): ', np.shape(H.sum(0)))
print('shape H.sum(1): ', np.shape(H.sum(1)))


# define axis limits
ax_Pxy.set_xlim(0, 2)
ax_Pxy.set_ylim(0, 2)
ax_Px.set_xlim(0, 2)
ax_Py.set_ylim(0, 2)

# label axes
ax_Pxy.set_xlabel('$x$')
ax_Pxy.set_ylabel('$y$')
ax_Px.set_xlabel('$x$')
ax_Px.set_ylabel('$p(x)$')
ax_Px.yaxis.set_label_position('right')
ax_Py.set_ylabel('$y$')
ax_Py.set_xlabel('$p(y)$')
ax_Py.xaxis.set_label_position('top')

plt.show()
shape H:  (41, 41)
shape H.sum(0):  (41,)
shape H.sum(1):  (41,)
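Here is a small check of the marginalization step. It assumes all points fall inside the grid; otherwise histogram2d drops them and the totals differ. Summing a histogram2d(x, y) array over axis 1 (the y bins) reproduces the 1D histogram of x, and summing over axis 0 reproduces the 1D histogram of y. That is the discrete version of p(x) = ∫ p(x, y) dy. The uniform toy data are made up for the check.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0, 2, 1000)
y = rng.uniform(0, 2, 1000)
edges = np.linspace(0, 2, 11)

H, xbins, ybins = np.histogram2d(x, y, bins=edges)   # H[i, j]: x bin i, y bin j

px_counts = H.sum(axis=1)   # sum over y bins -> marginal in x
py_counts = H.sum(axis=0)   # sum over x bins -> marginal in y

print(np.array_equal(px_counts, np.histogram(x, bins=edges)[0]))   # True
print(np.array_equal(py_counts, np.histogram(y, bins=edges)[0]))   # True
```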

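So adding the plot values and the colorbar, and setting the x and y limits, gets rid of the junk-looking colorbar. The colorbar label is still weird, though: it reads $\times 10^{7-3}$. A likely cause is that H *= 1000 rescales H in place, so every re-run of a cell containing it multiplies H by another factor of 1000. Once the values get large, matplotlib adds its own $\times 10^{7}$ scale factor above the colorbar, and it overlaps the hand-placed $\times 10^{-3}$ text. Plotting 1000 * H instead of modifying H avoids this. The right-hand panels also still have labels on their y-axes, which I don't understand yet.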
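Here is the in-place pitfall in miniature. A notebook cell that does H *= 1000 changes the array stored in the kernel, so running it again scales the already-scaled values. Building a scaled copy with 1000 * H leaves H alone and makes the cell safe to re-run. The two functions below stand in for "re-running the cell"; their names are hypothetical.

```python
import numpy as np

H = np.array([[1., 2.], [3., 4.]])

def run_cell_in_place():
    """Mimics a cell containing `H *= 1000`: mutates the shared H on every run."""
    global H
    H *= 1000
    return H.max()

def run_cell_with_copy():
    """Mimics a cell that plots `1000 * H`: H itself is never modified."""
    H_plot = 1000 * H
    return H_plot.max()

H_original = H.copy()
first, second = run_cell_with_copy(), run_cell_with_copy()
print(first == second, np.array_equal(H, H_original))   # True True

a, b = run_cell_in_place(), run_cell_in_place()
print(b / a)                                            # 1000.0: each re-run scales again
```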

In [86]:
# plot the result
fig = plt.figure(figsize=(12, 6))

# define axes
ax_Pxy = plt.axes((0.2, 0.34, 0.27, 0.52))
ax_Px = plt.axes((0.2, 0.14, 0.27, 0.2))
ax_Py = plt.axes((0.1, 0.34, 0.1, 0.52))
ax_cb = plt.axes((0.48, 0.34, 0.01, 0.52))
ax_Px_y = [plt.axes((0.65, 0.62, 0.32, 0.23)),
           plt.axes((0.65, 0.38, 0.32, 0.23)),
           plt.axes((0.65, 0.14, 0.32, 0.23))]

# set axis label formatters
ax_Px_y[0].xaxis.set_major_formatter(NullFormatter())
ax_Px_y[1].xaxis.set_major_formatter(NullFormatter())

ax_Pxy.xaxis.set_major_formatter(NullFormatter())
ax_Pxy.yaxis.set_major_formatter(NullFormatter())

ax_Px.yaxis.set_major_formatter(NullFormatter())
ax_Py.xaxis.set_major_formatter(NullFormatter())

# draw the joint probability
plt.axes(ax_Pxy)
H *= 1000
plt.imshow(H, interpolation='nearest', origin='lower', aspect='auto',
           extent=[0, 2, 0, 2], cmap=plt.cm.binary)

cb = plt.colorbar(cax=ax_cb)
cb.set_label('$p(x, y)$')
plt.text(0, 1.02, r'$\times 10^{-3}$',
         transform=ax_cb.transAxes) 
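''' The transform argument says the coordinates (0, 1.02) are in the
    colorbar axes' own frame (ax_cb.transAxes): x=0 is its left edge and
    y=1.02 is just above its top, so the text sits on top of the colorbar
    no matter what data values it shows.
    -gully'''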

# draw p(x) distribution
ax_Px.plot(xbins[1:], H.sum(0), '-k', drawstyle='steps')

# draw p(y) distribution
ax_Py.plot(H.sum(1), ybins[1:], '-k', drawstyle='steps')

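''' Summing along an axis is as easy as H.sum(axis): the argument is the
    dimension that gets summed away.  np.histogram2d(x, y) puts x along
    axis 0, so strictly H.sum(0) is the marginal in y and H.sum(1) the
    marginal in x.  The plot uses them the other way around; it still looks
    right only because this banana is symmetric under swapping x and y.
    -gully'''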

print('shape H: ', np.shape(H))
print('shape H.sum(0): ', np.shape(H.sum(0)))
print('shape H.sum(1): ', np.shape(H.sum(1)))


# define axis limits
ax_Pxy.set_xlim(0, 2)
ax_Pxy.set_ylim(0, 2)
ax_Px.set_xlim(0, 2)
ax_Py.set_ylim(0, 2)

# label axes
ax_Pxy.set_xlabel('$x$')
ax_Pxy.set_ylabel('$y$')
ax_Px.set_xlabel('$x$')
ax_Px.set_ylabel('$p(x)$')
ax_Px.yaxis.set_label_position('right')
ax_Py.set_ylabel('$y$')
ax_Py.set_xlabel('$p(y)$')
ax_Py.xaxis.set_label_position('top')

# draw marginal probabilities
iy = [3 * Ngrid // 4, Ngrid // 2, Ngrid // 4]  # integer division: these are array indices
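''' which slices of p(x,y) to take: bins at 3/4, 1/2 and 1/4 of the
    grid (indices 30, 20, 10 for Ngrid = 41), i.e. y ~ 1.46, 0.98, 0.49
    from the top panel to the bottom one, which the labels round to
    p(x | y = 1.5), p(x | y = 1.0), p(x | y = 0.5)
    -gully '''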

colors = 'rgc'
#print('type(colors)', type(colors))
''' a string is a sequence of characters, so colors[i] gives the
    one-letter color code 'r', 'g' or 'c' for panel i
    -gully''' 

axis = ax_Pxy.axis()
for i in range(3):
    # overplot range on joint probability
    ax_Pxy.plot([0, 2, 2, 0],
                [ybins[iy[i] + 1], ybins[iy[i] + 1],
                 ybins[iy[i]], ybins[iy[i]]], c=colors[i], lw=1)
    Px_y = H[iy[i]] / H[iy[i]].sum() 
    ''' Px_y is the same as P( x | y = something )
        -gully '''
    ax_Px_y[i].plot(xbins[1:], Px_y, drawstyle='steps', c=colors[i])
    ax_Px_y[i].yaxis.set_major_formatter(NullFormatter())
    ''' a-ha, this is why the intermediate plots showed labels on
        the y-axis for the right-hand-side plot.
        -gully'''
    ax_Px_y[i].set_ylabel('$p(x | %.1f)$' % ybins[iy[i]])
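ax_Pxy.axis(axis)
''' ax.axis() with no arguments returns the current limits
    (xmin, xmax, ymin, ymax), and ax.axis(axis) sets them back: a reset,
    in case drawing the colored boxes autoscaled the joint panel.  Here
    set_xlim/set_ylim above already fixed the limits, so it is a safety net.'''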

ax_Px_y[2].set_xlabel('$x$')

ax_Pxy.set_title('Joint Probability')
ax_Px_y[0].set_title('Conditional Probability')

plt.show()
shape H:  (41, 41)
shape H.sum(0):  (41,)
shape H.sum(1):  (41,)
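Below is a compact sketch of the conditional-probability step, with the histogram2d orientation made explicit. For H from np.histogram2d(x, y), a fixed y bin j is the column H[:, j]. Dividing that column by its sum estimates p(x | y in bin j), the discrete form of p(x|y) = p(x, y) / p(y), since the column sum is proportional to p(y) for that bin.

In the astroML code the slice is the row H[iy]. That matches this column only after transposing H, or approximately for data like the banana that are symmetric under swapping x and y. The helper name conditional_x_given_y is hypothetical, and the data are a made-up correlated toy sample, not the banana.

```python
import numpy as np

def conditional_x_given_y(H, j):
    """Estimate p(x | y in bin j) from H = np.histogram2d(x, y, ...)[0] (hypothetical helper)."""
    column = H[:, j]                 # fixed y bin, all x bins
    total = column.sum()             # proportional to p(y in bin j)
    if total == 0:
        raise ValueError("no samples in this y bin")
    return column / total

rng = np.random.default_rng(2)
y = rng.normal(1.0, 0.3, 20000)
x = y + rng.normal(0, 0.1, 20000)    # x tracks y, so p(x|y) shifts with y
edges = np.linspace(0, 2, 41)
H, xbins, ybins = np.histogram2d(x, y, bins=edges)

centers = 0.5 * (xbins[:-1] + xbins[1:])
for j in (10, 20, 30):               # y bins starting at 0.5, 1.0, 1.5
    p = conditional_x_given_y(H, j)
    print(f"y ~ {ybins[j]:.2f}: sums to {p.sum():.3f}, mean x = {np.sum(p * centers):.2f}")
```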

Recap

Recap of what we learned from dissecting this figure.

  • How to make an arbitrary, random 2D distribution, $p(x,y)$, admittedly with some questions about how the parametric distribution was crafted in the first place
  • How to make a 2D histogram- this is easy with numpy's histogram2d()
  • How to share axes- really this is the same as complex subplot formatting.
  • How to put in a color bar- ditto, except there is a special command, plt.colorbar()
  • Complex subplot formatting- lots of this, main points are summarized in comments.
  • Discrete integrals along one dimension of an array. Easy with .sum(axis): for H from np.histogram2d(x, y), .sum(0) sums over x and .sum(1) sums over y. Similarly for higher dimensions.
  • Concept of joint and conditional probabilities- compare to the MacKay Figure 2.2.
  • Notation for joint probability and conditional probability- I prefer the cumbersome albeit explicit notation of p( x | y = 0.5 ) rather than p( x | 0.5 ). I guess really it should be symmetric, and the arguments should be statements, not variables: p( x = a | y = 0.5 ), where a is a variable in a specified domain.