Introduction
As I’ve explained previously, I’m gradually coming around to the idea of using Java for the development of MCMC codes, and I’m starting to build up a collection of simple examples for getting started. One of the advantages of Java is that it includes a standard cross-platform GUI library. This might not seem like the most important requirement for MCMC, but it can be very handy in several contexts, particularly for monitoring convergence. One obvious context is image analysis, where it can be useful to watch image reconstructions as the sampler is running. In this post I’ll show three small, simple Java classes which together provide an application for running a Gibbs sampler on a (non-stationary, unconditioned) Gaussian Markov random field (GMRF).
The model is essentially that the conditional distribution of each pixel, given the rest of the image, depends only on its nearest neighbours on a rectangular lattice: four for an interior pixel, fewer at the edges and corners. Here that conditional distribution is Gaussian, with mean equal to the sample mean of the neighbouring pixels and a fixed (unit) variance. On its own this isn’t especially useful, but it is a key component of many image analysis applications.
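Using $x_{i,j}$ for the value of pixel $(i,j)$ and $\partial_{i,j}$ for its set of lattice neighbours (notation chosen here for illustration), the conditional distribution that each Gibbs update samples from can be written as

$$
x_{i,j} \mid x_{-(i,j)} \sim \mathrm{N}\!\left(\frac{1}{|\partial_{i,j}|}\sum_{(k,l)\in\partial_{i,j}} x_{k,l},\; 1\right),
$$

where $|\partial_{i,j}|$ is 4 for an interior pixel, 3 on an edge and 2 at a corner. For example, an interior pixel whose neighbours currently take the values 1, 2, 4 and 5 is redrawn from $\mathrm{N}(3,1)$.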
A simple Java implementation
We will start with the class MrfApp containing the main method for the application:
MrfApp.java
```java
import java.io.*;

class MrfApp
{
    public static void main(String[] arg) throws IOException
    {
        Mrf mrf;
        System.out.println("started program");
        mrf=new Mrf(800,600);        // 800 x 600 lattice, initialised to zero
        System.out.println("created mrf object");
        mrf.update(1000);            // 1000 full Gibbs sweeps
        System.out.println("done updates");
        mrf.saveImage("mrf.png");    // save the final state as a PNG
        System.out.println("finished program");
        mrf.frame.dispose();         // close the window...
        System.exit(0);              // ...and shut down the Swing threads
    }
}
```
Hopefully this code is largely self-explanatory, but it relies on a class called Mrf which contains all of the logic associated with the GMRF.
Mrf.java
```java
import java.io.*;
import java.util.*;
import java.awt.image.*;
import javax.swing.*;
import javax.imageio.ImageIO;

class Mrf
{
    int n,m;               // lattice width (n) and height (m)
    double[][] cells;      // current state of the GMRF
    Random rng;
    BufferedImage bi;      // 8-bit greyscale view of the field
    WritableRaster wr;
    JFrame frame;
    ImagePanel ip;

    Mrf(int n_arg,int m_arg)
    {
        n=n_arg;
        m=m_arg;
        cells=new double[n][m];
        rng=new Random();
        bi=new BufferedImage(n,m,BufferedImage.TYPE_BYTE_GRAY);
        wr=bi.getRaster();
        frame=new JFrame("MRF");
        ip=new ImagePanel(bi);
        frame.add(ip);
        frame.pack();      // fit the window to the n x m image panel
        frame.setVisible(true);
    }

    public void saveImage(String filename) throws IOException
    {
        ImageIO.write(bi,"PNG",new File(filename));
    }

    // Rescale the field linearly onto 0..255 and refresh the window
    public void updateImage()
    {
        double mx=-1e+100;
        double mn=1e+100;
        for (int i=0;i<n;i++) {
            for (int j=0;j<m;j++) {
                if (cells[i][j]>mx) { mx=cells[i][j]; }
                if (cells[i][j]<mn) { mn=cells[i][j]; }
            }
        }
        for (int i=0;i<n;i++) {
            for (int j=0;j<m;j++) {
                int level=(int) (255*(cells[i][j]-mn)/(mx-mn));
                wr.setSample(i,j,0,level);
            }
        }
        frame.repaint();
    }

    public void update(int num)
    {
        for (int i=0;i<num;i++) {
            updateOnce();
        }
    }

    // One systematic-scan Gibbs sweep: each cell ~ N(mean of its neighbours, 1)
    private void updateOnce()
    {
        double mean;
        for (int i=0;i<n;i++) {
            for (int j=0;j<m;j++) {
                if (i==0) {
                    if (j==0) {
                        mean=0.5*(cells[0][1]+cells[1][0]);
                    } else if (j==m-1) {
                        mean=0.5*(cells[0][j-1]+cells[1][j]);
                    } else {
                        mean=(cells[0][j-1]+cells[0][j+1]+cells[1][j])/3.0;
                    }
                } else if (i==n-1) {
                    if (j==0) {
                        mean=0.5*(cells[i][1]+cells[i-1][0]);
                    } else if (j==m-1) {
                        mean=0.5*(cells[i][j-1]+cells[i-1][j]);
                    } else {
                        mean=(cells[i][j-1]+cells[i][j+1]+cells[i-1][j])/3.0;
                    }
                } else if (j==0) {
                    mean=(cells[i-1][0]+cells[i+1][0]+cells[i][1])/3.0;
                } else if (j==m-1) {
                    mean=(cells[i-1][j]+cells[i+1][j]+cells[i][j-1])/3.0;
                } else {
                    mean=0.25*(cells[i][j-1]+cells[i][j+1]+cells[i+1][j]+cells[i-1][j]);
                }
                cells[i][j]=mean+rng.nextGaussian();
            }
        }
        updateImage();
    }
}
```
This class contains a few simple methods for creating and updating the GMRF, and for maintaining and updating a graphical view of the GMRF as the sampler is running. The Gibbs sampler update itself is encoded in the final method, updateOnce, and most of its code deals with edge and corner cases (in the literal rather than metaphorical sense!). Because the cells are overwritten in place, each pixel is drawn conditional on the latest values of its neighbours. updateOnce is called repeatedly by the method update for the required number of iterations, and at the end of each sweep it calls updateImage, which updates the image associated with the GMRF. The GMRF itself is stored in a 2-dimensional array of doubles, but an image pixel typically consists of a grayscale value represented by an unsigned byte, that is, an integer from 0 to 255. So updateImage scans through the GMRF to find the maximum and minimum values and then maps the GMRF values linearly onto the 0 to 255 scale. The image itself is set up by the constructor method, Mrf. This class relies on an additional class called ImagePanel, which is a simple GUI panel for displaying images:
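The case analysis in updateOnce can also be written by averaging over whichever neighbours fall inside the lattice. The following is a minimal, GUI-free sketch of that idea (GmrfSketch and its methods are hypothetical names, not part of the application): it performs the same in-place sweep with unit-variance Gaussian noise, and the same linear mapping onto 0 to 255, with an added guard for a constant field.

```java
import java.util.Random;

/** Hypothetical GUI-free sketch of the sampler and the grey-level mapping. */
class GmrfSketch {

    /** Mean of the in-bounds four-neighbours of (i,j); assumes n, m >= 2. */
    static double neighbourMean(double[][] x, int i, int j) {
        int n = x.length, m = x[0].length;
        double sum = 0.0;
        int count = 0;
        if (i > 0)     { sum += x[i - 1][j]; count++; }
        if (i < n - 1) { sum += x[i + 1][j]; count++; }
        if (j > 0)     { sum += x[i][j - 1]; count++; }
        if (j < m - 1) { sum += x[i][j + 1]; count++; }
        return sum / count;   // 2 at a corner, 3 on an edge, 4 in the interior
    }

    /** One in-place systematic-scan Gibbs sweep: x[i][j] ~ N(neighbour mean, 1). */
    static void sweep(double[][] x, Random rng) {
        for (int i = 0; i < x.length; i++)
            for (int j = 0; j < x[0].length; j++)
                x[i][j] = neighbourMean(x, i, j) + rng.nextGaussian();
    }

    /** Map values linearly onto 0..255; a constant field maps to all zeros. */
    static int[][] toGreyLevels(double[][] x) {
        double mn = Double.POSITIVE_INFINITY, mx = Double.NEGATIVE_INFINITY;
        for (double[] row : x)
            for (double v : row) { mn = Math.min(mn, v); mx = Math.max(mx, v); }
        double range = mx - mn;
        int[][] levels = new int[x.length][x[0].length];
        for (int i = 0; i < x.length; i++)
            for (int j = 0; j < x[0].length; j++)
                levels[i][j] = range > 0 ? (int) (255 * (x[i][j] - mn) / range) : 0;
        return levels;
    }

    public static void main(String[] args) {
        double[][] x = new double[80][60];
        Random rng = new Random(42);
        for (int it = 0; it < 100; it++) sweep(x, rng);
        int[][] g = toGreyLevels(x);
        System.out.println("grey level at (0,0): " + g[0][0]);
    }
}
```

Counting the in-bounds neighbours gives exactly the same means as the explicit cases in updateOnce, so the choice between the two is mainly about which is easier to read and check.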
ImagePanel.java
```java
import java.awt.*;
import java.awt.image.*;
import javax.swing.*;

// A minimal Swing panel that displays a BufferedImage at its natural size
class ImagePanel extends JPanel
{
    protected BufferedImage image;

    public ImagePanel(BufferedImage image)
    {
        this.image=image;
        Dimension dim=new Dimension(image.getWidth(),image.getHeight());
        setPreferredSize(dim);
        setMinimumSize(dim);
        revalidate();
        repaint();
    }

    @Override
    public void paintComponent(Graphics g)
    {
        super.paintComponent(g);   // let Swing paint the background first
        g.drawImage(image,0,0,this);
    }
}
```
This completes the application, which can be compiled and run from the command line with
```
javac *.java
java MrfApp
```
This should compile the code and run the application, which will show a GMRF updating for 1000 iterations. When the 1000 iterations are complete, the application writes the final image to a file and then quits.
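Because the Mrf constructor opens a JFrame, the application needs a display; on a machine without one, Swing throws a HeadlessException. The image-writing step itself does not depend on the window, though. A minimal sketch, assuming the field has already been mapped to 0 to 255 grey levels (HeadlessPng is a hypothetical helper, not part of the application), uses the same BufferedImage, WritableRaster and ImageIO calls as Mrf without any Swing code:

```java
import java.awt.image.BufferedImage;
import java.awt.image.WritableRaster;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;

/** Hypothetical helper: write grey levels to a PNG without opening a window. */
class HeadlessPng {

    /** levels[i][j] is the 0..255 level at x = i, y = j (n = width, m = height, as in Mrf). */
    static BufferedImage toImage(int[][] levels) {
        int n = levels.length, m = levels[0].length;
        BufferedImage bi = new BufferedImage(n, m, BufferedImage.TYPE_BYTE_GRAY);
        WritableRaster wr = bi.getRaster();
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                wr.setSample(i, j, 0, levels[i][j]);
        return bi;
    }

    static void save(int[][] levels, File file) throws IOException {
        ImageIO.write(toImage(levels), "PNG", file);
    }

    public static void main(String[] args) throws IOException {
        int[][] levels = new int[256][64];
        for (int i = 0; i < 256; i++)
            for (int j = 0; j < 64; j++)
                levels[i][j] = i;                 // horizontal grey ramp
        File out = new File("ramp.png");
        save(levels, out);
        System.out.println("wrote " + out.getAbsolutePath());
    }
}
```

Since PNG is lossless, reading the file back with ImageIO.read recovers the same grey levels.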
Using Parallel COLT
The above classes are very convenient, as they should work with any standard Java installation. However, in more complex scenarios, it is likely that a math library such as Parallel COLT will be required. In this case it will make sense to make use of features in the COLT library, such as random number generators and 2d matrix objects. We can adapt the above application by replacing the MrfApp and Mrf classes with the following versions (the ImagePanel class remains unchanged):
MrfApp.java
```java
import java.io.*;
import cern.jet.random.tdouble.engine.*;

class MrfApp
{
    public static void main(String[] arg) throws IOException
    {
        Mrf mrf;
        int seed=1234;               // fixed seed, so runs are reproducible
        System.out.println("started program");
        DoubleRandomEngine rngEngine=new DoubleMersenneTwister(seed);
        mrf=new Mrf(800,600,rngEngine);
        System.out.println("created mrf object");
        mrf.update(1000);            // 1000 full Gibbs sweeps
        System.out.println("done updates");
        mrf.saveImage("mrf.png");
        System.out.println("finished program");
        mrf.frame.dispose();
        System.exit(0);
    }
}
```
Mrf.java
```java
import java.io.*;
import java.awt.image.*;
import javax.swing.*;
import javax.imageio.ImageIO;
import cern.jet.random.tdouble.*;
import cern.jet.random.tdouble.engine.*;
import cern.colt.matrix.tdouble.impl.*;

class Mrf
{
    int n,m;                     // lattice width (n) and height (m)
    DenseDoubleMatrix2D cells;   // current state of the GMRF
    DoubleRandomEngine rng;      // uniform engine supplied by MrfApp
    Normal rngN;                 // N(0,1) generator driven by rng
    BufferedImage bi;
    WritableRaster wr;
    JFrame frame;
    ImagePanel ip;

    Mrf(int n_arg,int m_arg,DoubleRandomEngine rng)
    {
        n=n_arg;
        m=m_arg;
        cells=new DenseDoubleMatrix2D(n,m);
        this.rng=rng;
        rngN=new Normal(0.0,1.0,rng);
        bi=new BufferedImage(n,m,BufferedImage.TYPE_BYTE_GRAY);
        wr=bi.getRaster();
        frame=new JFrame("MRF");
        ip=new ImagePanel(bi);
        frame.add(ip);
        frame.pack();            // fit the window to the n x m image panel
        frame.setVisible(true);
    }

    public void saveImage(String filename) throws IOException
    {
        ImageIO.write(bi,"PNG",new File(filename));
    }

    // Rescale the field linearly onto 0..255 and refresh the window
    public void updateImage()
    {
        double mx=-1e+100;
        double mn=1e+100;
        for (int i=0;i<n;i++) {
            for (int j=0;j<m;j++) {
                double v=cells.getQuick(i,j);
                if (v>mx) { mx=v; }
                if (v<mn) { mn=v; }
            }
        }
        for (int i=0;i<n;i++) {
            for (int j=0;j<m;j++) {
                int level=(int) (255*(cells.getQuick(i,j)-mn)/(mx-mn));
                wr.setSample(i,j,0,level);
            }
        }
        frame.repaint();
    }

    public void update(int num)
    {
        for (int i=0;i<num;i++) {
            updateOnce();
        }
    }

    // One systematic-scan Gibbs sweep; getQuick/setQuick skip Colt's own index checks
    private void updateOnce()
    {
        double mean;
        for (int i=0;i<n;i++) {
            for (int j=0;j<m;j++) {
                if (i==0) {
                    if (j==0) {
                        mean=0.5*(cells.getQuick(0,1)+cells.getQuick(1,0));
                    } else if (j==m-1) {
                        mean=0.5*(cells.getQuick(0,j-1)+cells.getQuick(1,j));
                    } else {
                        mean=(cells.getQuick(0,j-1)+cells.getQuick(0,j+1)+cells.getQuick(1,j))/3.0;
                    }
                } else if (i==n-1) {
                    if (j==0) {
                        mean=0.5*(cells.getQuick(i,1)+cells.getQuick(i-1,0));
                    } else if (j==m-1) {
                        mean=0.5*(cells.getQuick(i,j-1)+cells.getQuick(i-1,j));
                    } else {
                        mean=(cells.getQuick(i,j-1)+cells.getQuick(i,j+1)+cells.getQuick(i-1,j))/3.0;
                    }
                } else if (j==0) {
                    mean=(cells.getQuick(i-1,0)+cells.getQuick(i+1,0)+cells.getQuick(i,1))/3.0;
                } else if (j==m-1) {
                    mean=(cells.getQuick(i-1,j)+cells.getQuick(i+1,j)+cells.getQuick(i,j-1))/3.0;
                } else {
                    mean=0.25*(cells.getQuick(i,j-1)+cells.getQuick(i,j+1)
                              +cells.getQuick(i+1,j)+cells.getQuick(i-1,j));
                }
                cells.setQuick(i,j,mean+rngN.nextDouble());
            }
        }
        updateImage();
    }
}
```
Again, the code should be reasonably self-explanatory, and it will compile and run in the same way provided that Parallel COLT is installed and on your classpath. This version runs approximately twice as fast as the previous version on all of the machines I’ve tried it on.
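One practical difference between the two versions is that the COLT MrfApp seeds its DoubleMersenneTwister with a fixed value, so repeated runs should produce the same final image, whereas new Random() in the first version is seeded differently on each run. If reproducible runs are wanted without COLT, java.util.Random also accepts a seed. The sketch below (SeededRun is a hypothetical class, not part of the application above) repeats the sampler on a plain array and checks that equal seeds give bit-identical fields:

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch: a seeded run of the unit-variance GMRF Gibbs sampler. */
class SeededRun {

    static double[][] run(long seed, int n, int m, int sweeps) {
        Random rng = new Random(seed);
        double[][] x = new double[n][m];       // start from all zeros, as in Mrf
        for (int s = 0; s < sweeps; s++) {
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    // average over the in-bounds four-neighbours (assumes n, m >= 2)
                    double sum = 0.0;
                    int count = 0;
                    if (i > 0)     { sum += x[i - 1][j]; count++; }
                    if (i < n - 1) { sum += x[i + 1][j]; count++; }
                    if (j > 0)     { sum += x[i][j - 1]; count++; }
                    if (j < m - 1) { sum += x[i][j + 1]; count++; }
                    x[i][j] = sum / count + rng.nextGaussian();
                }
            }
        }
        return x;
    }

    public static void main(String[] args) {
        double[][] a = run(1234L, 40, 30, 50);
        double[][] b = run(1234L, 40, 30, 50);
        double[][] c = run(4321L, 40, 30, 50);
        System.out.println("same seed, same field: " + Arrays.deepEquals(a, b));
        System.out.println("different seed, same field: " + Arrays.deepEquals(a, c));
    }
}
```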
Reference
I have found the following book very useful for understanding how to work with images in Java:
Hunt, K.A. (2010) The Art of Image Processing with Java, A K Peters/CRC Press.
