Thursday, October 8, 2015

A Simple Waterfall Plot in Python


I was reviewing my notes from a course I took a year or so ago on using Fourier methods for signal analysis and all sorts of fun stuff. The course was taught in MATLAB, and one particular kind of plot was produced with a single call to a function called waterfall(). I remember that Octave didn't seem to have this function at all, and Python with matplotlib didn't make things easy either. I don't think I managed to make it work in any language I had access to at the time, so I decided to go back and give it another shot.

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PolyCollection
from matplotlib.colors import colorConverter
from mpl_toolkits.mplot3d import Axes3D

So, after the imports, on to setting up the demo from class. This was supposed to be a simple simulation of a plane flying a somewhat unlikely course and getting hit with a radar blip.
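The demo code builds its wavenumber vector k by hand. For n samples on a periodic domain of length T, the usual FFT ordering is k = (2π/T)·[0, 1, …, n/2 − 1, −n/2, …, −1]. As a minimal sketch, assuming the same T = 60 and n = 512 as the demo, here is that hand-built vector checked against NumPy's np.fft.fftfreq. fftfreq returns frequencies in cycles per unit length, so it needs the 2π factor and the sample spacing d = T/n.

```python
import numpy as np

T = 60.0   # domain length, assumed equal to the demo's T
n = 512    # number of samples (even)

# Hand-built ordering: 0, 1, ..., n/2 - 1, then -n/2, ..., -1, scaled by 2*pi/T
idx = np.arange(n)
k_manual = (2.0 * np.pi / T) * np.where(idx < n // 2, idx, idx - n)

# Library version: fftfreq(n, d) gives i / (n * d); with d = T/n that is i / T
k_fft = 2.0 * np.pi * np.fft.fftfreq(n, d=T / n)

# fftshift reorders into ascending order: -n/2, ..., n/2 - 1 (times 2*pi/T)
ks = np.fft.fftshift(k_fft)

print(np.allclose(k_manual, k_fft))
```

The two vectors agree. fftshift is only needed when you want the frequencies in ascending order, for example on a plot axis.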

T = 60.
n = 512
t = np.linspace(-T/2., T/2., n+1)
t = t[0:n]
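# There's a function to set up the frequencies (np.fft.fftfreq), but doing it by hand seems
# to help me think things through. For a domain of length T, the wavenumbers are 2*pi/T times
# 0, 1, ..., n/2 - 1, -n/2, ..., -1.
k = np.array([(2. * np.pi / T) * i if i < n/2 else (2. * np.pi / T) * (i - n)
  for i in range(n)])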

ks = np.fft.fftshift(k)
slc = np.arange(0, 10, 0.5)
# I haven't quite figured out how to use the meshgrid function in numpy
T, S = np.meshgrid(t, slc)
K, S = np.meshgrid(k, slc)

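# Now, we have a plane flying back and forth in a sine wave and getting painted by a radar pulse,
# modeled as a hyperbolic secant (1/cosh) centered at 10*sin(s) on each slice. The exp(1j*0*T)
# factor is a carrier with zero frequency, so U is real-valued here.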
U = 1./np.cosh(T - 10. * np.sin(S)) * np.exp(1j * 0. * T)

def waterfall(X, Y, Z, nslices):

    # Function to generate formats for facecolors
    cc = lambda arg: colorConverter.to_rgba(arg, alpha=0.3)
    # This is just wrong. There must be some way to use the meshgrid or why bother.
    # One polygon per slice: the curve (X[i], Z[i]) drawn in that slice's x-z plane.
    verts = [list(zip(X[i], Z[i])) for i in range(nslices)]

    xmin = np.floor(np.min(X))
    xmax = np.ceil(np.max(X))
    ymin = np.floor(np.min(Y))
    ymax = np.ceil(np.max(Y))
    zmin = np.floor(np.min(Z))
    zmax = np.ceil(np.max(Z))

    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure in recent matplotlib;
    # add_subplot with projection='3d' works across versions.
    ax = fig.add_subplot(111, projection='3d')

    poly = PolyCollection(verts, facecolors=[cc('g')])
    # Place slice i at depth Y[i, 0] instead of relying on the global slc.
    ax.add_collection3d(poly, zs=Y[:nslices, 0], zdir='y')
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.set_zlim(zmin, zmax)
    plt.show()

waterfall(T, S, U.real, len(slc))

Now, that looks something like what I remember from class.
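The comment in waterfall() wonders whether the meshgrid arrays can build the polygons directly. They can: np.stack pairs X and Z into an array of shape (nslices, npts, 2), and PolyCollection accepts that as a sequence of polygons. This is a sketch under my own assumptions, not the course's or MATLAB's waterfall(). The helper names slice_vertices and waterfall_mesh are hypothetical, and each polygon is closed down to a baseline (0 by default) so the filled faces don't cut across the curve when a slice doesn't start and end at zero.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PolyCollection
from matplotlib.colors import to_rgba


def slice_vertices(X, Z, baseline=0.0):
    """Hypothetical helper: one closed polygon per row of the meshgrid arrays.

    Each polygon traces (X[i], Z[i]) and adds a point at height `baseline`
    at both ends, so the filled face drops to the floor.
    """
    curves = np.stack((X, Z), axis=-1)                 # (nslices, npts, 2)
    nslices = X.shape[0]
    floor = np.full(nslices, baseline)
    left = np.column_stack((X[:, 0], floor))           # (nslices, 2)
    right = np.column_stack((X[:, -1], floor))
    return np.concatenate((left[:, None, :], curves, right[:, None, :]), axis=1)


def waterfall_mesh(X, Y, Z, baseline=0.0, color="g", alpha=0.3):
    """Hypothetical variant of waterfall(): slices are the rows of X, Y, Z."""
    verts = slice_vertices(X, Z, baseline)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    poly = PolyCollection(verts, facecolors=[to_rgba(color, alpha)])
    # Row i of the meshgrid sits at depth Y[i, 0] along the y axis
    ax.add_collection3d(poly, zs=Y[:, 0], zdir="y")
    ax.set_xlim(X.min(), X.max())
    ax.set_ylim(Y.min(), Y.max())
    ax.set_zlim(min(baseline, Z.min()), Z.max())
    return fig, ax
```

With the demo's arrays, calling `fig, ax = waterfall_mesh(T, S, U.real)` and then `plt.show()` should draw essentially the same picture, since the sech pulses are already close to zero at the ends of each slice. Returning the figure and axes, rather than calling plt.show() inside, leaves room to adjust the view angle or labels afterwards.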