Saturday, April 1, 2017

Efficient machine learning

Suppose we have a set of data on which we wish to train a machine learning algorithm, and the data keeps growing and growing. We'd like an efficient algorithm that puts more weight on the recent results, and we'd like it to be straightforward to keep up to date, even if we have hundreds of millions of rows of data. What can we do? Well, in this post I will suggest an algorithm that is very efficient at dealing with large data sets. It can learn from old data, but it doesn't need the old data to be stored.
Let's start with the basics. Suppose we have a set of T input vectors \(\underline{x}^t \in \mathbb{R}^n\) with \[1 \le t \le T\] and for each \(\underline{x}^t\) there is a resultant \(y^t \in \mathbb{R} \).
We want to find the optimal vector \(\underline{\theta}\) so that \( \underline{\theta } . \underline{x}^t \) is a good predictor of \( y^t \).
We choose the \(\underline{\theta}\) which minimises the cost function: \[ J( \underline{\theta} ) = \frac{1}{2T} \sum^{T}_{t=1} ( \underline{\theta} . \underline{x}^t -y^t)^2 \] We could expand the inner product \( \underline{\theta} . \underline{x}^t \), which would give us: \[ J( \underline{\theta} ) = \frac{1}{2T} \sum^{T}_{t=1} \big[ \sum_{j=1}^n ( \theta_j x^t_j ) -y^t \big]^2 \] When we have an optimal \( \underline{\theta} \), the partial derivatives of J will be zero: \[ 0 = \frac{\partial J}{ \partial \theta_i} = \frac{1}{T} \sum^{T}_{t=1} x_i^t \big[ \sum_{j=1}^n ( \theta_j x^t_j ) -y^t \big] \] We can change the order of the summation and do some rearranging to obtain: \[ 0 = \sum_{j=1}^n \big[ \frac{1}{T} \sum_{t=1}^T ( x^t_i x^t_j ) \theta_j \big] - \frac{1}{T} \sum_{t=1}^T ( y^t x_i^t) \hspace{24 mm} (eqn 1) \] Now, if we define the matrix \(\underline{A}\): \[ A_{ij} = \frac{1}{T} \sum_{t=1}^T ( x^t_i x^t_j ) \hspace{24 mm} (eqn 2) \] and we define vector \(\underline{B}\) \[ B_i = \frac{1}{T} \sum_{t=1}^T ( y^t x_i^t) \hspace{24 mm} (eqn 3) \] Plugging those into equation 1, we find: \[ 0 = \sum_{j=1}^n \big[ A_{ij} \theta_j \big] - B_i \] If we can invert the matrix \(\underline{A}\) then we can obtain \( \theta \) \[ \underline{\theta} = \underline{A}^{-1} \underline{B} \] If we look back at equation 2, we see that each vector \( \underline{x}^t\) effectively has the same weighting.
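To make the recipe concrete, here is a minimal Java sketch that builds \(\underline{A}\) and \(\underline{B}\) from equations 2 and 3 and then solves \(\underline{A} \, \underline{\theta} = \underline{B}\). The class name NormalEquations, the method names and the singularity threshold are my own choices for illustration, and Gaussian elimination is just one of several ways to do the solve.

```java
import java.util.Arrays;

final class NormalEquations {

    /** Builds A (eqn 2) and B (eqn 3) from the rows x[t], y[t] and returns theta = A^{-1} B. */
    static double[] fit(double[][] x, double[] y) {
        int T = x.length, n = x[0].length;
        double[][] a = new double[n][n];
        double[] b = new double[n];
        for (int t = 0; t < T; t++) {
            for (int i = 0; i < n; i++) {
                b[i] += y[t] * x[t][i] / T;
                for (int j = 0; j < n; j++) a[i][j] += x[t][i] * x[t][j] / T;
            }
        }
        return solve(a, b);
    }

    /** Solves a * theta = b by Gaussian elimination with partial pivoting (overwrites a and b). */
    static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        for (int col = 0; col < n; col++) {
            int pivot = col;
            for (int r = col + 1; r < n; r++)
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) pivot = r;
            if (Math.abs(a[pivot][col]) < 1e-12)   // assumed threshold, not tuned
                throw new ArithmeticException("A is singular (or nearly so)");
            double[] rowTmp = a[col]; a[col] = a[pivot]; a[pivot] = rowTmp;
            double valTmp = b[col]; b[col] = b[pivot]; b[pivot] = valTmp;
            for (int r = col + 1; r < n; r++) {
                double f = a[r][col] / a[col][col];
                b[r] -= f * b[col];
                for (int c = col; c < n; c++) a[r][c] -= f * a[col][c];
            }
        }
        double[] theta = new double[n];
        for (int i = n - 1; i >= 0; i--) {          // back substitution
            double s = b[i];
            for (int j = i + 1; j < n; j++) s -= a[i][j] * theta[j];
            theta[i] = s / a[i][i];
        }
        return theta;
    }

    public static void main(String[] args) {
        // Rows generated exactly from y = 2*x1 + 3*x2, so theta should come out close to (2, 3).
        double[][] x = { {1, 0}, {0, 1}, {1, 1}, {2, 1} };
        double[] y = { 2, 3, 5, 7 };
        System.out.println(Arrays.toString(fit(x, y)));
    }
}
```

Note that the cost of the solve depends only on n, the number of features, and not on T, the number of rows.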
We could rewrite equations 2 and 3 as: \[ A_{ij} = \sum_{t=1}^T ( x^t_i x^t_j \omega_t) \hspace{24 mm} (eqn 4) \] and \[ B_i = \sum_{t=1}^T ( y^t x^t_i \omega_t) \hspace{24 mm} (eqn 5) \] where \[ \omega_t = \frac{1}{T} \hspace{24 mm} \forall t: 1 \le t \le T \] However, suppose we wanted more recent values of \( \underline{x}^t\) to have higher weighting. We could fix the ratio of consecutive weights: \[ \frac{\omega_{t+1}}{\omega_t} = 1 + \lambda \] with \( \lambda > 0\)
So we would have: \[ \omega_t = \frac{\omega_T} { (1+\lambda)^{T-t}} \] Suppose we want to find how many rows back the weight drops to half the weight of the most recently added entry. Then we would seek \(\tau\) such that \[ \frac{\omega_T}{\omega_{T - \tau}} = 2 \] which implies: \[ (1 + \lambda)^{\tau} = 2 \] so: \[ \lambda = 2^{1/ \tau } -1 \hspace{24 mm} (eqn 6) \] So, if we want the most recent data to have double the weighting of a point 1,000 rows back, then we would set \(\tau = 1000 \) and use equation 6 to determine \( \lambda \).
The parameter \( \lambda \) determines how much bigger the weights of the recent data are, when compared with those of older data. When choosing it, it may be helpful to first choose \( \tau\) (which is something like a half-life) and then use equation 6 to evaluate \( \lambda \).
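As a quick check of equation 6, here is a small Java sketch. The class and method names are hypothetical, chosen just for this illustration.

```java
final class HalfLife {

    /** Equation 6: lambda = 2^(1/tau) - 1. */
    static double lambdaFromHalfLife(double tau) {
        return Math.pow(2.0, 1.0 / tau) - 1.0;
    }

    /** The ratio omega_T / omega_(T-k): how much more the newest row weighs than the row k back. */
    static double weightRatio(double lambda, int k) {
        return Math.pow(1.0 + lambda, k);
    }

    public static void main(String[] args) {
        double lambda = lambdaFromHalfLife(1000);
        System.out.println("lambda                = " + lambda);
        System.out.println("ratio after 1000 rows = " + weightRatio(lambda, 1000));
        System.out.println("ratio after 2000 rows = " + weightRatio(lambda, 2000));
    }
}
```

For \(\tau = 1000\) this gives a \( \lambda \) of roughly 0.00069, and the ratios come out at about 2 and 4, as a half-life should.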

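If we wish to preserve the condition that the sum of the weights is 1, then we can do the geometric sum: \[ 1 = \sum_{t=1}^T \omega_t = \omega_T \sum_{k=0}^{T-1} (1+\lambda)^{-k} = \omega_T \frac{(1+\lambda)^T - 1}{\lambda (1+\lambda)^{T-1}} \] and so we find that: \[ \omega_T = \frac{\lambda (1+\lambda)^{T-1}}{(1+\lambda)^T-1} \hspace{24 mm} (eqn 7) \] As a check, this gives \( \omega_1 = 1 \) when there is only one row, and for large T it tends to \( \lambda / (1+\lambda) \).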
For a given set of \(\underline{x}^t \) and \(y^t\) with \(1 \le t \le T \)
we can evaluate \(\underline{A} \) and \( \underline{B}\).
Since they were generated with T rows of data, we could label them \(\underline{A}^T \) and \( \underline{B}^T\).
Now suppose we have already evaluated \(\underline{A}^{(T-1)} \) and \( \underline{B}^{(T-1)}\)
with T-1 rows of data and we want to introduce one more,
then we find: \[ A^T_{ij} = \omega_T x^T_i x^T_j + ( 1 - \omega_T ) A^{(T-1)}_{ij} \hspace{24 mm} (eqn 8) \] and \[ B^T_i = \omega_T y^T x^T_i + ( 1 - \omega_T ) B^{(T-1)}_i \hspace{24 mm} (eqn 9) \] where \( \omega_T \) has been defined in equation 7.
So when we have evaluated \( \underline{A} \) and \( \underline{B} \) for a given set of data and then later want to include the contribution from a new \( \underline{x}^T \) and \( y^T \), we can amend the existing \( \underline{A} \) and \( \underline{B} \) using equations 8 and 9, without the need to retrieve all the past \( \underline{x}^t \) and \( y^t \).
As a result, when data ( \( \underline{x}^T \) and \( y^T \) ) comes in, we can use it to update \( \underline{A} \) and \( \underline{B} \) and then we can discard the \( \underline{x}^T \) and \( y^T \). They have made their contribution to \( \underline{A} \) and \( \underline{B} \) and we can continually update \( \underline{A} \) and \( \underline{B} \) without the need to look back at old values of \( \underline{x}^t \) and \( y^t \). When T is very large, say in the hundreds of millions, using this algorithm to continually update the machine learning results is rather efficient, since we don't need to keep retrieving the old data. We just keep the \( \underline{A} \) and \( \underline{B} \) up to date.
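The update step can be sketched in a few lines of Java. This is only an illustration under my own assumptions: the class name OnlineWeightedLeastSquares and its methods are hypothetical, and the weight of the newest row is obtained from a running geometric sum (the normalisation that makes all the weights add up to 1) rather than from a closed formula, which avoids raising \( 1+\lambda \) to a very large power.

```java
final class OnlineWeightedLeastSquares {

    private final double lambda;   // growth rate of the weights, lambda > 0
    private final double[][] a;    // running matrix A
    private final double[] b;      // running vector B
    private double norm = 0.0;     // running sum of (1+lambda)^(-k) for k = 0 .. T-1

    OnlineWeightedLeastSquares(int n, double lambda) {
        this.lambda = lambda;
        this.a = new double[n][n];
        this.b = new double[n];
    }

    /** Folds one new row (x, y) into A and B using equations 8 and 9; the row can then be discarded. */
    void update(double[] x, double y) {
        norm = 1.0 + norm / (1.0 + lambda);
        double w = 1.0 / norm;     // weight of the newest row, so that all the weights sum to 1
        for (int i = 0; i < b.length; i++) {
            b[i] = w * y * x[i] + (1.0 - w) * b[i];
            for (int j = 0; j < b.length; j++)
                a[i][j] = w * x[i] * x[j] + (1.0 - w) * a[i][j];
        }
    }

    double[][] getA() { return a; }
    double[]   getB() { return b; }

    public static void main(String[] args) {
        // One feature, lambda = 1, so each row weighs twice as much as the one before it.
        OnlineWeightedLeastSquares ols = new OnlineWeightedLeastSquares(1, 1.0);
        ols.update(new double[] {1}, 1);
        ols.update(new double[] {2}, 1);
        ols.update(new double[] {3}, 1);
        // Weights are 1/7, 2/7, 4/7, so A = 45/7 and B = 17/7.
        System.out.println("A = " + ols.getA()[0][0] + ", B = " + ols.getB()[0]);
    }
}
```

Each update costs a fixed amount of work that depends on n but not on T, and the only state carried forward is \( \underline{A} \), \( \underline{B} \) and one running sum. To obtain \( \underline{\theta} \) at any point, solve \( \underline{A} \, \underline{\theta} = \underline{B} \) with whichever linear solver you prefer.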

Monday, February 27, 2017

Backward Propagation Algorithm in a Neural Network

This post was inspired by the Coursera course on Machine Learning, in particular by the lesson in week 5. In that lesson the backward propagation algorithm is presented, but the students are asked to take the formulae on faith; a derivation is not given. To rectify that, I present here a derivation of the main equations.

Suppose we have an input vector \( \underline{x} \) and we wish to make a prediction for the resultant vector \(\underline{y}\). We start by defining our activation level zero to be: \[ \underline{a}^0 = \underline{x} \] We define \[ z^{k+1}_i = \sum_{j}\theta^{k}_{ij} a^k_j \hspace{24 mm} (eqn 1) \] so \( \theta^k \) is the matrix of weights that takes activation level k to level k+1, and we use the sigmoid function S(z) to evaluate: \[ a^{k+1}_i= S(z^{k+1}_i) \hspace{24 mm} (eqn 2) \] When we have \(\underline{a}^0\), we can use equations 1 and 2 to evaluate \(\underline{a}^1\).
We then use equations 1 and 2 again to evaluate \(\underline{a}^2\) and so on up to \(\underline{a}^L\).

Our forecast for \(\underline{y}\) is activation level L: \(\underline{a}^L\)
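The forward pass can be sketched in Java as follows. The class name ForwardPass and the array layout are my own assumptions: theta[k] is taken to be the weight matrix that carries activation level k to level k+1, and, as in the equations above, there are no bias terms.

```java
final class ForwardPass {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Returns all activation levels a[0..L], where a[0] = x and theta[k] maps level k to level k+1. */
    static double[][] forward(double[][][] theta, double[] x) {
        int L = theta.length;
        double[][] a = new double[L + 1][];
        a[0] = x.clone();
        for (int k = 0; k < L; k++) {
            a[k + 1] = new double[theta[k].length];
            for (int i = 0; i < theta[k].length; i++) {
                double z = 0.0;                                   // eqn 1: weighted sum of level k
                for (int j = 0; j < a[k].length; j++) z += theta[k][i][j] * a[k][j];
                a[k + 1][i] = sigmoid(z);                         // eqn 2
            }
        }
        return a;
    }

    public static void main(String[] args) {
        // Two inputs, two hidden units, one output. With all weights zero every z is 0,
        // so every activation after level 0 is S(0) = 0.5.
        double[][][] theta = { { {0, 0}, {0, 0} }, { {0, 0} } };
        double[][] a = forward(theta, new double[] {1, 2});
        System.out.println("forecast = " + a[a.length - 1][0]);
    }
}
```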

Suppose we are given a training set of T input vectors \(\underline{x}^t\)
and a corresponding set of resultant vectors \(\underline{y}^t\) where \(1 \leq t \leq T \).
For each input vector \(\underline{x}^t\) we will have a corresponding activation level L: \(\underline{a}^{Lt}\)
We can measure the discrepancy between \(\underline{y}^t\) and \(\underline{a}^{Lt}\). We'll call this D. We could have: \[ D(\underline{y}^t , \underline{a}^{Lt} )= \parallel \underline{y}^t - \underline{a}^{Lt} \parallel^2 \] For the rest of this posting we won't refer again to the right hand side of that equation. So if the discrepancy function D were defined differently, then it wouldn't make any difference to the formulae below.

We sum up all the discrepancies and call the result the cost function: \[ J = \frac{1}{2T} \sum_t D(\underline{y}^t , \underline{a}^{Lt} ) \] We are interested in finding the \(\theta\)s which minimise this cost function and so give us accurate predictions.
One approach would be to find the partial derivatives \[ \frac{\partial J}{\partial \theta^k_{ij}} \] for all i,j,k and then use a gradient descent method to minimise J. \[ \frac{\partial J}{\partial \theta^k_{ij}} = \frac{1}{2T} \sum_t \frac{\partial D(\underline{y}^t , \underline{a}^{Lt})}{\partial \theta^k_{ij}} \hspace{24 mm} (eqn 3) \] We are now going to focus on the term: \[ \frac{\partial D(\underline{y}^t , \underline{a}^{Lt})}{\partial \theta^k_{ij}} \] For simplicity we are going to drop the t superscript. So we will write: \[ \frac{\partial D(\underline{y} , \underline{a}^L)}{\partial \theta^k_{ij}} \] where the t is implied. Though in the end, when we want to evaluate the partial derivative of J, we will need to do the sum over t.

We define \[ \delta^k_i = \frac{\partial D}{\partial z^k_i} \hspace{24 mm} (eqn 4) \] and when we apply the chain rule to the right hand side we get \[ \delta^k_i = \sum_j \frac{\partial D}{\partial z^{k+1}_j} \frac {\partial z^{k+1}_j}{\partial z^k_i} \] we substitute in \( \delta^{k+1}_j\) and we find: \[ \delta^k_i = \sum_j \delta^{k+1}_j \frac {\partial z^{k+1}_j}{\partial z^k_i} \hspace{24 mm} (eqn 5) \] Now we want to evaluate the last term \[ \frac {\partial z^{k+1}_j}{\partial z^k_i} \] to do that we first combine equations 1 and 2 to obtain: \[ z^{k+1}_j = \sum_l \theta^k_{jl} S(z^k_l) \] and so \[ \frac {\partial z^{k+1}_j}{\partial z^k_i} = \sum_l \theta^k_{jl} S'(z^k_l) I_{li} \hspace{24 mm} (eqn 6) \] where \[ S'(z) = \frac {d S(z) }{dz} \] and \[ I_{li} = \{ 1 \hspace{4 mm} \text{when} \hspace{4 mm} l = i, \hspace{4 mm} \text{otherwise} \hspace{4 mm} 0 \} \] So, returning to equation 6, we find only one term in the sum survives, i.e. when l = i.
Hence eqn 6 becomes: \[ \frac {\partial z^{k+1}_j}{\partial z^k_i} = \theta^k_{ji} S'(z^k_i) \hspace{24 mm} (eqn 7) \] When we substitute that into equation 5 we find \[ \delta^k_i = \sum_j \delta^{k+1}_j \theta^k_{ji} S'(z^k_i) \hspace{24 mm} (eqn 8) \] It follows from equation 1 that: \[ \frac{\partial z^{k+1}_i}{\partial \theta^k_{jl}} = a^k_l I_{ij} \hspace{24 mm} (eqn 9) \] Using the chain rule we find: \[ \frac{\partial D}{\partial \theta^k_{ij}} = \sum_l \frac{\partial D}{\partial z^{k+1}_l} \frac{\partial z^{k+1}_l}{\partial \theta^k_{ij}} = \sum_l \delta^{k+1}_l a^k_j I_{li} \] where we have substituted in equations 4 and 9. Only one element of the sum survives, i.e. when l = i, and we obtain: \[ \frac{\partial D}{\partial \theta^k_{ij}} = \delta^{k+1}_i a^k_j \hspace{24 mm} (eqn 10) \] So we can go forward (increasing k's) using equations 1 and 2 to evaluate \( a^k_i\)
and then, starting from \( \delta^L_i = \frac{\partial D}{\partial a^L_i} S'(z^L_i) \) at the output level, use equation 8 to go backwards (decreasing k's) to evaluate \( \delta^k_j \).
We will then be able to work out all the partial derivatives of D with respect to \( \theta\).
Remember there is an implied superscript t in equation 10. And to work out the partial derivatives of J we will need to do the sum over all t's as shown in equation 3.
After that, finding the optimal \( \theta \) 's using gradient descent will be as easy as sliding down a hill.
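Since the whole point of this post is not to take the formulae on faith, here is a Java sketch that computes the gradient for a single training example by back propagation and compares it with a finite-difference estimate. It rests on several assumptions of mine: the class name BackPropSketch, the layout in which theta[k] takes level k to level k+1, the specific choice \( D = \parallel \underline{y} - \underline{a}^L \parallel^2 \) for the discrepancy, and the sigmoid identity \( S'(z) = S(z)(1 - S(z)) \). In the code, the derivative of D with respect to the weight linking unit j of level k to unit i of level k+1 is the delta of unit i at level k+1 times the activation of unit j at level k.

```java
final class BackPropSketch {

    static double sigmoid(double z) { return 1.0 / (1.0 + Math.exp(-z)); }

    /** Forward pass: a[0] = x, a[k+1] = S(theta[k] * a[k]). */
    static double[][] forward(double[][][] theta, double[] x) {
        double[][] a = new double[theta.length + 1][];
        a[0] = x.clone();
        for (int k = 0; k < theta.length; k++) {
            a[k + 1] = new double[theta[k].length];
            for (int i = 0; i < theta[k].length; i++) {
                double z = 0.0;
                for (int j = 0; j < a[k].length; j++) z += theta[k][i][j] * a[k][j];
                a[k + 1][i] = sigmoid(z);
            }
        }
        return a;
    }

    /** Assumed discrepancy: D = || y - a^L ||^2. */
    static double discrepancy(double[][][] theta, double[] x, double[] y) {
        double[] out = forward(theta, x)[theta.length];
        double d = 0.0;
        for (int i = 0; i < y.length; i++) d += (y[i] - out[i]) * (y[i] - out[i]);
        return d;
    }

    /** Returns grad[k][i][j] = dD / d theta[k][i][j] for one training example. */
    static double[][][] gradient(double[][][] theta, double[] x, double[] y) {
        int L = theta.length;
        double[][] a = forward(theta, x);
        double[][][] grad = new double[L][][];
        // Deltas at the output level: dD/da * S'(z), with S'(z) = a (1 - a) for the sigmoid.
        double[] delta = new double[a[L].length];
        for (int i = 0; i < delta.length; i++)
            delta[i] = 2.0 * (a[L][i] - y[i]) * a[L][i] * (1.0 - a[L][i]);
        for (int k = L - 1; k >= 0; k--) {
            // At this point delta holds the deltas of level k+1.
            grad[k] = new double[theta[k].length][a[k].length];
            for (int i = 0; i < theta[k].length; i++)
                for (int j = 0; j < a[k].length; j++)
                    grad[k][i][j] = delta[i] * a[k][j];
            if (k == 0) break;
            // Step back one level, as in equation 8.
            double[] prev = new double[a[k].length];
            for (int i = 0; i < prev.length; i++) {
                double s = 0.0;
                for (int j = 0; j < delta.length; j++) s += delta[j] * theta[k][j][i];
                prev[i] = s * a[k][i] * (1.0 - a[k][i]);
            }
            delta = prev;
        }
        return grad;
    }

    /** Largest absolute difference between the back propagated and central finite-difference gradients. */
    static double maxGradientError(double[][][] theta, double[] x, double[] y) {
        double[][][] grad = gradient(theta, x, y);
        double eps = 1e-6, worst = 0.0;
        for (int k = 0; k < theta.length; k++)
            for (int i = 0; i < theta[k].length; i++)
                for (int j = 0; j < theta[k][i].length; j++) {
                    double saved = theta[k][i][j];
                    theta[k][i][j] = saved + eps;
                    double up = discrepancy(theta, x, y);
                    theta[k][i][j] = saved - eps;
                    double down = discrepancy(theta, x, y);
                    theta[k][i][j] = saved;
                    worst = Math.max(worst, Math.abs((up - down) / (2 * eps) - grad[k][i][j]));
                }
        return worst;
    }

    public static void main(String[] args) {
        double[][][] theta = { { {0.1, -0.2}, {0.4, 0.3} }, { {0.5, -0.6} } };
        System.out.println("max gradient error: "
                + maxGradientError(theta, new double[] {1.0, 0.5}, new double[] {1.0}));
    }
}
```

The reported error should be tiny compared with the gradients themselves. To get the gradient of J, this per-example gradient still has to be summed over t and scaled as in equation 3.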

Wednesday, January 25, 2017

Solitaire of sorts

Suppose you had a well shuffled standard deck of 52 cards. From the top of the pack you turned over each card one at a time. As you turned over the first card, you called out 'Ace', then 2 for the next, followed by 3, 4, 5, 6, 7, 8, 9, 10, Jack, Queen, King. For the 14th card you called out 'Ace' again, followed by 2, 3, 4 and so on till you reached the end.
What would be the probability that none of the cards that were turned over had the value (rank) that you called out?

Before working it out, let's consider a related, but simpler question. Suppose you had 3 cards containing values 1,2,3. After you shuffled them, what would be the probability that none of the 3 were in the same position as they were at the start? In mathematical speak you'd ask what is the probability of a derangement.
In this case we can look at all 3! ( i.e. 6 ) permutations. And we see that 2 out of 6 are derangements. So the probability of a derangement is 1/3.
You might be tempted to say that for each card, the probability that it is not in its original position is 2/3 and so for the 3 cards the total probability of a derangement is that to the power of 3, i.e. 8/27.
However the flaw with that method is that the probabilities are not independent.
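The 3 card case is small enough to list by hand, but a few lines of Java do the same job and also work for slightly bigger n. This is just a sketch with names of my own choosing, and the cards are labelled from 0 rather than from 1.

```java
import java.util.Arrays;

final class Derangements {

    /** Visits every permutation of 0..n-1 and returns how many leave no element in its original position. */
    static long countDerangements(int n, boolean print) {
        return visit(new int[n], new boolean[n], 0, print);
    }

    private static long visit(int[] perm, boolean[] used, int pos, boolean print) {
        if (pos == perm.length) {
            boolean deranged = true;
            for (int i = 0; i < perm.length; i++)
                if (perm[i] == i) deranged = false;     // card i is still in position i
            if (print)
                System.out.println(Arrays.toString(perm) + (deranged ? "  <- derangement" : ""));
            return deranged ? 1 : 0;
        }
        long total = 0;
        for (int v = 0; v < perm.length; v++) {
            if (used[v]) continue;
            used[v] = true;
            perm[pos] = v;
            total += visit(perm, used, pos + 1, print);
            used[v] = false;
        }
        return total;
    }

    public static void main(String[] args) {
        long d = countDerangements(3, true);
        System.out.println(d + " derangements out of 6 permutations");
        System.out.println("Naive 'independent' estimate (2/3)^3 = " + Math.pow(2.0 / 3.0, 3));
    }
}
```

The listing marks 2 of the 6 permutations as derangements, whereas the naive estimate of 8/27 is a little under 0.3.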

Returning now to that deck of 52 with 4 suits and 13 ranks, there is of course more than one way to work it out. One method would be to write a Monte Carlo simulation. I've done that and the code is below.
Here is a sample output after 100 million simulations:

Found 1623551 survivors out of 1.0E8
So the survival probability is estimated to be: 0.01623551
with a standard deviation of: 1.2638005465673727E-5
Time elapsed: 77.742 seconds


So the probability of surviving to the end without getting any cards right is approximately 1.62%

OK, OK, MC is useful, but not very satisfying. You might think that you could just write a program to work out all 52! permutations and then work out the answer. But alas 52! is a very big number. So if you want an answer before you die, then it might be a good idea to try a different approach.
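Brute force is hopeless for 52 cards, but it is still handy for tiny decks, because it gives exact values to test any cleverer method against. Here is a sketch with class and method names of my own. It assumes card c has rank c mod ranks and that the rank called out at position p is p mod ranks, and it abandons an ordering as soon as a card matches its call.

```java
final class SmallDeckBruteForce {

    /** Counts the orderings of the deck in which no card's rank matches the rank called at its position. */
    static long countSurvivors(boolean[] used, int pos, int ranks) {
        if (pos == used.length) return 1;
        long total = 0;
        for (int card = 0; card < used.length; card++) {
            if (used[card] || card % ranks == pos % ranks) continue;  // already dealt, or a match
            used[card] = true;
            total += countSurvivors(used, pos + 1, ranks);
            used[card] = false;
        }
        return total;
    }

    /** Exact survival probability for a deck of suits * ranks cards; only practical for tiny decks. */
    static double survivalProb(int suits, int ranks) {
        int n = suits * ranks;
        double factorial = 1.0;
        for (int k = 2; k <= n; k++) factorial *= k;
        return countSurvivors(new boolean[n], 0, ranks) / factorial;
    }

    public static void main(String[] args) {
        System.out.println("Prob(1,4) = " + survivalProb(1, 4));
        System.out.println("Prob(2,2) = " + survivalProb(2, 2));
        System.out.println("Prob(2,3) = " + survivalProb(2, 3));
    }
}
```

With 1 suit and 4 ranks this is just the derangement problem for 4 cards, giving 9/24 = 0.375, and with 2 suits and 2 ranks it gives 4/24, or about 0.166667. Both agree with the small cases noted in the comments at the end of the Calculator class further down.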

Well there is another way... In fact I'm sure there are many other ways. I present here a method that I used. Suppose we have slots numbered 1 to 13 and for simplicity we'll number the cards 1 to 13. At the start we have 13 different ranks, each of which has 4 cards, and there are 4 slots available for each of the ranks.
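We could represent that as a table of ranks, in which each entry counts the ranks that have the given number of slots (row) and cards (column) remaining:

          Cards
Slots    0   1   2   3   4
  0      0   0   0   0   0
  1      0   0   0   0   0
  2      0   0   0   0   0
  3      0   0   0   0   0
  4      0   0   0   0  13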

Now suppose we pick one card, then we'll have 12 card ranks with 4 cards remaining
and we'll have 1 card rank that has 3 cards remaining, but still 4 slots.
So in our table of ranks we'll represent that as:

          Cards
Slots    0   1   2   3   4
  0      0   0   0   0   0
  1      0   0   0   0   0
  2      0   0   0   0   0
  3      0   0   0   0   0
  4      0   0   0   1  12

But we'll have to put the card in one of the available slots. Of the 52 slots 48 are allowed.
When we do that we'll have
 1 rank with 4 slots and 3 cards remaining
 1 rank with 3 slots and 4 cards remaining
11 ranks with 4 slots and 4 cards remaining

In our table we represent that as:

          Cards
Slots    0   1   2   3   4
  0      0   0   0   0   0
  1      0   0   0   0   0
  2      0   0   0   0   0
  3      0   0   0   0   1
  4      0   0   0   1  11

Using such a table with the probabilities of the transitions, we can write some code that recursively solves the problem.

This is what the code looks like:

import java.io.PrintWriter;

// check if the back slash char is causing problems

public class Calculator {
       
    public static void main ( String[] args){
       
        long startTime = System.currentTimeMillis();

        System.out.println("Starting...");
       
        Calculator calc = new Calculator();

        calc.calc(); // calc() is the main calculator, calcDebug() is for debugging
       
        long   endTime     = System.currentTimeMillis();
        double elapsedSec  = ((double) endTime - (double) startTime) * 0.001d;
       
        System.out.println("\nElapsed time: " + Double.toString(elapsedSec) + " sec.");       
        System.out.println("Finished.");
    }
    ///////////////////////////////////////////////////////////////////
    public void calc(){
        System.out.println("Doing calc.");

        int suits = 4;
        int ranks = 9; // Total number of cards will be: suits * ranks (use 13 for the standard 52 card deck)
       
        State s = new State(suits, ranks);
        double survivalProb = s.getSurvivalProb();
       
        String key = s.getKey();
        String resStr = "\nFor key: " + key + " found prob to be: "  + String.format("%1.14f", survivalProb);
        System.out.println(resStr);
       
        String pathToFile = "C:/temp/res_" + key + ".txt";
        writeToFile(pathToFile, resStr );

    }
    ///////////////////////////////////////////////////////////////
    public void calcDebug(){
        System.out.println("Doing calc2.");
        //           "0000000000111111111122222"
        //           "0123456789012345678901234"
        String key = "1001010000000000100000000";

        State s = new State(key, 0);
        double survivalProb = s.getSurvivalProb();
       
        String resStr = "For key: " + key + " found prob to be: "  + String.format("%1.14f", survivalProb);
        System.out.println(resStr);
       
        String pathToFile = "C:/temp/res_" + key + ".txt";
        writeToFile(pathToFile, resStr );
    }
    ///////////////////////////////////////////////////////////////
    public void writeToFile( String pathToFile, String str){
        try{
            PrintWriter writer = new PrintWriter(pathToFile, "UTF-8");
            writer.println(str);
            writer.close();
        } catch(Exception e){
            System.out.println("Caught exception: " + e.getMessage());
        }
    }
    /*  Prob(1,4)  = 0.375
     *  Prob(2,2)  = 0.166667
     *  Prob(4,4)  = 0.011869
     *  Prob(4,8)  = 0.014967   ( takes 11.6 secs )
     *  Prob(4,13) = 0.016232
     */
   
}

///////////////////////////////////////////////////////////////////////////////


import java.util.Random;
import java.util.TreeMap;

// prob survival = sum ( numCardsLikeThis * numSafeLocations / TotalslotLocations
//                         * ProbSurvival( next))

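// The key is a string (suits+1)*(suits+1) chars long (25 for 4 suits).
// The char at index slot*(suits+1) + card holds the number of ranks that
// currently have 'slot' slots and 'card' cards remaining.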

public class State {
   private static final boolean            m_debug               = false; 
   private static final long               m_seed                = 1;
   private static final int                m_stepsBetweenCaching = 0; //

   private static int                      m_numSuits;
   private static TreeMap<String, Double>  m_map;
   private static Random                   m_rnd;

   private        String                   m_key;  
   private        double                   m_survivalProb;  
   private        int                      m_totalSlots;
   private        int                      m_depth;
  

   ///////////////////////////////////////////////////////
   State ( int suits, int ranks){
      
       m_depth    = 0;
       m_numSuits = suits;
       String      zeros = new String(new char[(suits+1) * (suits+1)]).replace("\0", "0");
       String      key   = adjustKey(suits, suits, zeros, ranks);
       initialize( key );    
      
       if ( m_debug) printKey  ( key );
   }
   ///////////////////////////////////////////////////////
   State( String key, int depth){
       m_depth = depth;
       initialize(key);
   }
   ///////////////////////////////////////////
   private void initialize(String key){
      
       if ( m_rnd == null)
            m_rnd =  new Random(m_seed); // formerly had: ThreadLocalRandom.current();

      
       if ( m_map == null)
            m_map =  new TreeMap<String, Double>();
      
       if( m_numSuits == 0){
           double sqrt = Math.sqrt((double) key.length());
           m_numSuits  = (int) Math.round(sqrt - 0.5f) -1;
           if ( m_debug)  {
              System.out.println(   "Have key of length: " + Integer.toString(key.length())
                                     + " and num suits: "     + Integer.toString(m_numSuits)   );
           }
       }  
      
       m_key          = key;
      
       m_survivalProb = -1;   // This indicates that it has not been calculated.      
       m_totalSlots   = calcTotalSlots();   
      
       if ( m_debug)  printKey(key);
   }
   ///////////////////////////////////////////////////////
   private int calcTotalSlots(){
       int sum = 0;
       for ( int slot = 0 ; slot <= m_numSuits; slot++){
           sum += getNumSlots(slot);
       }
       if ( m_debug)  System.out.println("Key: " + m_key + ", total number of slots found is: " + Integer.toString(sum));
       return sum;
   }
   ///////////////////////////////////////////////////////
   String getKey() { return m_key; }
   ///////////////////////////////////////////////////////
   double calcSurvivalProb(){
      
       if ( m_totalSlots == 0)
           return 1.0; // if we reach the end and there are no slots left, then we have survived.
      
       double prob = 0;
      
       for     ( int slot = 0; slot <= m_numSuits; slot++){
           for ( int card = 1; card <= m_numSuits; card++){
              
                  double probForCard = calcProbForCard(slot, card);
                  prob += probForCard;
                 
                  if ( m_depth < 3) {
                      System.out.println(  "Depth: "       + Integer.toString(m_depth)
                                         + ", slot: "      + Integer.toString(slot)
                                         + ", card: "      + Integer.toString(card)
                                         + ", num slots: " + Integer.toString(m_totalSlots)
                                         + ", prob: "      + Double.toString (probForCard)   );
                  }
           }
       }
      
       if ( m_debug) {
          printKey(m_key);
          System.out.println("Found prob to be: " + String.format("%1.14f",prob) + "\n");
       }      
       return prob;
   }
   ///////////////////////////////////////////////////////
   public double calcProbForCard(int slot, int card){
      
       if ( card == 0)
           return 0;

       int count = getCount(slot, card);
      
       if ( count == 0)
           return 0;
      
       int numCards = count * card;
      
       double probOfCardChoice = (double) numCards / (double) m_totalSlots;
      
       double probOfSlot = 0.0;
      
       for     ( int destSlot = 1; destSlot <= m_numSuits; destSlot++){
           for ( int destCard = 0; destCard <= m_numSuits; destCard++){
              
                  int destCount = getCount(destSlot, destCard);
                  if ( (slot == destSlot) && (card == destCard))
                      destCount--;  // a card cannot go into its own slot.
                 
                  int availableSlots = destSlot * destCount;

                  if ( availableSlots > 0){
                      double probOfContinuedSurvival = getProbOfContinuedSurvival(slot, card, destSlot, destCard);
                     
                      if ( probOfContinuedSurvival > 0) {          
                          probOfSlot += probOfContinuedSurvival * probOfCardChoice *(double) availableSlots /(double) m_totalSlots;
                          if ( probOfSlot > 1.0000001) { // We don't have >= 1.0 to allow a small rounding error
                              printKey(m_key);
                              System.out.println("ERROR: Prob is too big: " + Double.toString(probOfSlot));
                          }
                      }
                  }
           }
       }
       if ( m_debug) System.out.println("Found prob of slot to be: " + Double.toString(probOfSlot));
       return probOfSlot;
   }
   ///////////////////////////////////////////////////////
   double getProbOfContinuedSurvival(int sourceSlot, int sourceCard, int destSlot, int destCard){
      
          String keyForChosenCard   = adjustKeyForTakenCard (sourceSlot, sourceCard, m_key);
          String keyForContinuation = adjustKeyForCardInSlot(destSlot,   destCard,   keyForChosenCard);   
         
          // check if state obj is in map
          Double contProbDouble = m_map.get(keyForContinuation);
          double contProb;
         
          if ( contProbDouble == null ){
              State contState = new State(keyForContinuation, m_depth + 1);
              contProb        = contState.getSurvivalProb();
             
              if (    ( m_stepsBetweenCaching == 0 )
                   || (m_rnd.nextInt(m_stepsBetweenCaching) == 0 )) { // we'll only insert one in n key,value pairs to the map.
                 contProbDouble  = Double.valueOf( contProb);
                 m_map.put(keyForContinuation, contProbDouble); // we store it for later use
             
                   if ( m_map.size() % 10000 == 0)
                    System.out.println(   "Map elm: "+ Integer.toString(m_map.size())
                                       + ", key: "  + keyForContinuation
                                      + ", prob: " + Double.toString(contProbDouble));
             
                if ( m_debug)
                    System.out.println(  "Added new key to map: " + keyForContinuation
                                          + " that is item: " + Integer.toString(m_map.size()) );
              }
          } else {
              contProb = contProbDouble.doubleValue();
          }

          if ( m_debug) {
             System.out.println("Key: " + keyForContinuation + ", found cont prob: " + Double.toString(contProb));
             printKey(keyForContinuation);
          }
          return contProb;
   }
   ///////////////////////////////////////////////////////
   double getSurvivalProb() {
      
       if ( m_survivalProb == -1) // i.e. not yet set
            m_survivalProb = calcSurvivalProb();
             
       return m_survivalProb;
   }
   //////////////////////////////////////////////////////
   int getCount(int slot, int card){
             
       return ( getCount(slot, card, m_key));            
   }
   //////////////////////////////////////////////////////
   static int getCount(int slot, int card, String key){
      
       int    index = getIndex(slot, card);
       // System.out.println("Key: " + key + ", index: " + Integer.toString(index));
       String c     = key.substring(index, index + 1);
      
       return charToInt(c);
       // return ( Integer.parseInt(c, 16));            
   }
   //////////////////////////////////////////////////////
   public int getNumSlots(int slot){
      
       if ( slot == 0)
           return 0;
      
       int sum = 0;

       for ( int card = 0; card <= m_numSuits; card++)
           sum += slot * getCount(slot, card);
      
       return sum;      
   }
   //////////////////////////////////////////////////////
   int getTotalSlots(){
       return m_totalSlots;
   }
   //////////////////////////////////////////////////////
   public static void printKey(String key){
   
       if ( key == null){
           System.out.println("Key is null.");
           return;
       }

       System.out.println("key: " + key);

       String header = "   Card ";
      
       for( int j = 0; j <= m_numSuits; j++){
           header += Integer.toString(j) + " ";
       }
      
       System.out.println(header);
      
       for ( int slot = 0; slot <= m_numSuits; slot++){
           String str = "Slot " + Integer.toString(slot) + ": ";
          
           for ( int card = 0; card <= m_numSuits; card++){
               int index = getIndex(slot, card);
               str += key.substring(index, index+1) + " ";
           }
           System.out.println(str);
       }   
   }
   //////////////////////////////////////////////////////
   public static int getIndex(int slot, int card){
       int index = slot * (m_numSuits + 1) + card;
       /*
       System.out.println(    "For slot: "   + Integer.toString(slot)
                           + " and card: "   + Integer.toString(card)
                           + " have index: " + Integer.toString(index));
       */
       return (index);
   }
   //////////////////////////////////////////////////////
   public static String adjustKeyForTakenCard(int slot, int card, String key){
      
       if( card == 0)
           return null; // no cards to give

       String decrementedCurrent = adjustKey(slot, card   , key,                -1);
       return                      adjustKey(slot, card -1, decrementedCurrent, +1);
   }
   //////////////////////////////////////////////////////
   public static String adjustKeyForCardInSlot(int slot, int card, String key){
       if ( slot == 0)
           return null; // have no slot
      
       String adjKey  = adjustKey(slot    , card, key   , -1);
       adjKey         = adjustKey(slot - 1, card, adjKey, +1);          
             
       for( int i = m_numSuits; i > 1; i--){
          
           int countWithZeroCards = getCount(i,0, adjKey );
           if( countWithZeroCards > 0){
               adjKey = adjustKey(i,0, adjKey, -countWithZeroCards);
               adjKey = adjustKey(1,0, adjKey, countWithZeroCards * i);
           }

           int countWithZeroSuits = getCount(0,i, adjKey );
           if( countWithZeroSuits > 0){
               adjKey = adjustKey(0, i, adjKey, -countWithZeroSuits);
               adjKey = adjustKey(0, 1, adjKey,  countWithZeroSuits * i);
           }
       }
      
       int zeroZeroCount         = getCount(0,0, adjKey);
      
       return adjustKey( 0,0, adjKey, - zeroZeroCount);
   }
   //////////////////////////////////////////////////////
   public static String adjustKey(int slot, int card, String key, int adj){
      
       if ( key == null)
           return null;
      
       int count = getCount(slot, card, key);
      
       if ( 0 > count + adj){
           System.out.println("Cannot adjust below zero for slot" );
           return null;
       /*      
       } else if ( count + adj > 15){
           System.out.println("Cannot adjust above 1 char hex limit.");
           return null;       
        */  
       } else {
           int    index   = getIndex(slot, card);
          
           String start   = key.substring(0, index );
           String end     = key.substring(index + 1);
          
           String current = intToChar(count+ adj);
           return ( start + current + end);          
       }          
   }
   //////////////////////////////////////////////////////
   public static String intToChar(int i ){
       char c = (char)(i+48);
      
       return "" + c;
   }
   //////////////////////////////////////////////////////
   public static int charToInt(String s){
       char c = s.charAt(0);
       int  i = (int)c - 48;
      
       if ( m_debug) System.out.println("Char: " + s + " is deemed to be: " + Integer.toString(i));
      
       return (int) i;
   }
   /////////////////////////////////////////////////////
}
 

///////////////////////////////////////////////////////////////////////////////       





And here's the Java code for the Monte Carlo:

import java.util.Random;

public class RanksAndSuitsSurvival {

    private final long   m_simsPerBatch = 1_000_000L;
    private final int    m_numBatches   = 100;

    private final int    m_numSuits     = 4;   // should be  4
    private final int    m_numRanks     = 13;  // should be 13
   
    private       int    m_numCards;
    private       int[]  m_hand;
   
    private       Random m_rnd;
    private final int    m_rndSeed      = 1;
    ///////////////////////////////////////////////////////////////////////////////
    public static void main ( String[] args){
               
        System.out.println("Started...");
       
        RanksAndSuitsSurvival  rass = new RanksAndSuitsSurvival();
        rass.calc();
       
        System.out.println("\nFinished");   
    }
    ///////////////////////////////////////////////////////////////////////////////
    public RanksAndSuitsSurvival(){
       
        m_rnd      = new Random(m_rndSeed);
       
        m_numCards = m_numSuits * m_numRanks;
        m_hand     = new int[m_numCards];
       
        for ( int i = 0; i < m_numCards; i++)
            m_hand[i] = i % m_numRanks;
    }
    ///////////////////////////////////////////////////////////////////////////////
    public void calc(){
       
        long startTime = System.currentTimeMillis();
               
        long survivors = 0;
       
        // The only reason for the nested 'for' loops rather than a single 'for' loop
        // is because we wanted to print out the progress intermittently and we didn't
        // want to introduce another 'if' inside the main 'for' loop for reasons of speed.
        for ( int batch = 1; batch <= m_numBatches; batch++ ) {
       
            for ( long counter = 0; counter < m_simsPerBatch; counter ++){
       
                if (survived())  
                    survivors++;
            }
           
            System.out.println("Have now completed batch: " + Integer.toString(batch) + ", after "
                               + Double.toString((double)( System.currentTimeMillis() - startTime) * 0.001) + " seconds");
        }
   
        double numSims = (double) m_numBatches * (double) m_simsPerBatch;
        double prob    = (double) survivors              / numSims;
        double stdDev  = Math.sqrt( (1.0 - prob ) * prob / numSims);

        long   endTime = System.currentTimeMillis();
       
        printoutResults(survivors, numSims, prob, stdDev, endTime - startTime);
    }
    /////////////////////////////////////////////////////////////////////////////////
    public void printoutResults(long survivors, double numSims, double prob, double stdDev, long elapsedTimeMS){
   
        System.out.println("\nFound " + Long.toString(survivors) + " survivors out of " + Double.toString(numSims) );
        System.out.println("So the survival probability is estimated to be: "           + Double.toString(prob     ) );
        System.out.println("with a standard deviation of:                   "           + Double.toString(stdDev   ) );
           
        System.out.println("Time elapsed: " + Double.toString((double)( elapsedTimeMS) * 0.001d) + " seconds");
       
    }
    ///////////////////////////////////////////////////////////////////////////////
    public boolean survived(){
        
         shuffleHand();     // will shuffle m_hand
         return checkHand();
    }
    ///////////////////////////////////////////////////////////////////////////////   
    // Implementing Fisher–Yates shuffle
    public void shuffleHand(){         
        for (int i = m_hand.length - 1; i > 0; i--)  {
          int index     = m_rnd.nextInt(i + 1);
          // Simple swap
          int a         = m_hand[index];
          m_hand[index] = m_hand[i];
          m_hand[i]     = a;
        }
    }
    ////////////////////////////////////////////////////////////////////
    // Returns true when the hand 'survived'.
    public boolean checkHand(){      
                   
           for (int i = 0; i < m_hand.length; i++){
              
               if ( ((m_hand[i] - i) % m_numRanks) == 0 )
                      return false; // did not survive
           }
              
           return true;  // survived              
    }     
    ///////////////////////////////////////////////////////////////////////////////   
}