// --*- C++ -*------x---------------------------------------------------------
#ifndef __BIRTHDAY_PROB_ALL__
#define __BIRTHDAY_PROB_ALL__

#include <Vec.h>
#include <Random.h>
#include <debug.h>
#include <MultiValArray.h>

#define UNDEFINED_PROB -1.0

class BirthdayProbAll {

 public:
  
  typedef int int_type;
  typedef Vec<double>::size_type size_type;
  typedef MultiValArray::index_array index_array; 

  enum { OUTCOME_MAX_DEFAULT = 40 , USE_MATRICES_MODE = 1, DIM=6 };

 private:

  mutable MultiValArray matrices;
  int_type countDim;
  int_type countMax;
  int_type outcomeMax;
  bool exactMode;

 public:

  BirthdayProbAll(int_type _outcomeMax, const index_array& dimensions) : matrices(dimensions), outcomeMax(_outcomeMax),exactMode(true) {
    ASSERT(dimensions.size() > 1);
    countDim = dimensions.size()-1; // number of counts that are kept track of
    countMax = dimensions[0]; // maximum number of counts (0 based)
    matrices.fill(UNDEFINED_PROB );
  }

  /*  BirthdayProbAll(int_type _outcomeMax, int_type _countDim, int_type _countMax) : countDim(_countDim), countMax(_countMax),
										  outcomeMax(_outcomeMax) {
    ASSERT(matrices.getDimensionCount() == DIM);
    ASSERT(static_cast<int_type>(matrices.getDimension(0)) == outcomeMax);
    matrices.fill(UNDEFINED_PROB );
  }
  */

  virtual ~BirthdayProbAll() { }

  /** For k given possibilities (like 365 days), n trials (like n people),
   * what is the probability to observe only m OR LESS different outcomes (like m different birthdays)
   * @param epsilon how accurate
   */
  static double computeP(int_type k, int_type n, int_type m, int_type iterMax = 10000, 
			 int_type pseudoCount1 = 1, int_type pseudoCount2 = 1) {
    ASSERT(k > 0);
    ASSERT(n > 0);
    ASSERT(m <= k); // number successes must be smaller or equal than number of outcomes (for now)
    Random& rnd = Random::getInstance();
    int_type successes = 0; // pseudocount
    Vec<int> outcomes(k,0);
    for (int_type i = 0; i < iterMax; ++i) {
      for (int_type j = 0; j < k; ++j) { // loop over n trials
	outcomes[j] = 0;
      }
      int_type nonZeros = 0;
      for (int_type j = 0; j < n; ++j) { // loop over n trials
	int index = rnd.getRand(k);
	if (outcomes[index] == 0) {
	  ++nonZeros;
	} else {
	  ASSERT(j > 0);
	}
	outcomes[index] = outcomes[index] + 1;
      }
      ASSERT(nonZeros > 0);
      if (nonZeros <= m) {
	++successes;
      }
    }
    return static_cast<double>(successes + pseudoCount1) / (iterMax + pseudoCount2); // pseudocount: apriori: P = 1; nonZeros + 2 would mean P = 0.5 a priori
  }

  /** For k given possibilities (like 365 days), n trials (like n people),
   * what is the probability to observe only m1 different outcomes 1 times
   * m2 different outcomes 2 OR MORE times (like m different birthdays)
   * @param epsilon how accurate
   */
  virtual double exactProb(int_type k, int_type n, int_type m1, int_type m2, int_type m3, int_type m4) const {
    // cout << "Called exactProb with " << k << " " << n << " " << m1 << " " << m2 << " " << m3 << " " << m4 << endl; 
    PRECOND (k >= 1);
    PRECOND(n >= 1);
    PRECOND (m1 >= 0);
    PRECOND (m2 >= 0);
    PRECOND (m3 >= 0);
    PRECOND (m4 >= 0);
    //    PRECOND( n <= k); // for now restricted to less trials then outcomes
    //    if (k > static_cast<int_type>(matrices.getDimension(0))) {
    //     cout << "# Warning: number of outcomes ( " << k << " ) exceeds internal data structure size of " << matrices.getDimension(0) << " . Cannot compute precise p-value." << endl; 
    //    return 1.0;
    // }
    double result = UNDEFINED_PROB;
    MultiValArray::index_array indices(DIM);
    indices[0] = k; // NOT ANYMORE -1 because internal matrices are 0-based, external loops are 1-based
    indices[1] = n;
    indices[2] = m1;
    indices[3] = m2;
    indices[4] = m3;
    indices[5] = m4;
    int_type nCounted = m1 + (2 * m2) + (3 * m3) + (4*m4); // at most n; m4: counts of cases that occurred 4 or more times; m1 to m3: precice
    if (USE_MATRICES_MODE) {
      // try lookup table:
      //      if (matrices[k-1].size() == 0) {
      // matrices[k-1] = Vec<Vec<double> > (k, Vec<double>(k, UNDEFINED_PROB));
      // }
      result = matrices.get(indices); 
    }
    if (result != UNDEFINED_PROB) {
      ASSERT(result >= 0.0 && result <= 1.0);
      return result;
    }
    if ((n == 0) && (nCounted == 0)) {
      return 1.0;      
    }
    else if ((nCounted > n) || (nCounted == 0)) {
      return 0.0;
    } else if ((n == 1) && (m1 == 1) && (m2 == 0) && (m3 == 0) && (m4 == 0)) {
      result = 1.0;
    } else {
      result = 0.0;
      if ((m1 > 0) && (k > ((m1-1)+m2+m3+m4))) {
 	result += (exactProb(k,n-1,m1-1,m2,m3,m4) * (k-((m1-1)+m2+m3+m4))) / k ;
      }
      if (m2 > 0) {
      	result += (exactProb(k,n-1,m1+1,m2-1,m3,m4) * (m1 + 1)) / k;
      }
      if (m3 > 0) {
	result += (exactProb(k,n-1,m1,m2+1,m3-1,m4) * (m2 + 1)) /k;
      }
      if (m4 > 0) {
        result += (exactProb(k,n-1,m1,m2,m3+1,m4-1) * (m3 + 1)) / k;
      }
      if (m4 > 0)  {
        result += (exactProb(k,n-1,m1,m2,m3,m4-1) * m4) / k;
      }

    }
    ASSERT(result < 1.1); // rounding errors are possible but must not exceed 0.1
    if (result > 1.0) {
      result = 1.0;
    }
    ASSERT(result > -0.1); // rounding errors are possible but must not exceed 0.1
    if (result < 0.0) {
      result = 0.0;
    }
    if (USE_MATRICES_MODE) {
      matrices.set(indices, result); // [k-1][n-1][m-1] = result; // store result
    }
    // cout << "Finished exactProb with " << k << " " << n << " " << m1 << " " << m2 << " " << m3 << " " << m4 << " " << nCounted 
    // << " : " << result << endl; 
    POSTCOND(result >= 0.0 && (result <= 1.0)); 
    return result;
  }

  /** For k given possibilities (like 365 days), n trials (like n people),
   * what is the probability to observe only m1 different outcomes 1 times
   * m2 different outcomes 2 OR MORE times (like m different birthdays)
   * @param epsilon how accurate
   * counts[0] : number of cases that occurred never
   * counts[1] : m1 number of cases that occurred once
  */
  virtual double exactProb(const index_array& indices) {
    PRECOND(static_cast<int_type>(indices.size()) == (countDim + 1));
    cout << "Starting exactProb with Indices: " << indices << endl;
    double result = UNDEFINED_PROB;
    int_type nCounted = 0; 
    int_type n = indices[0];
    for (size_type i = 1; i  < indices.size(); ++i) {
      nCounted += i * indices[i];
    }
    if (exactMode) {
      ERROR_IF(nCounted != n, "In exact mode, number of cases cannot be exceed specified value.");
    }
    if (USE_MATRICES_MODE) {
      // try lookup table:
      result = matrices.get(indices); 
    }
    if (result != UNDEFINED_PROB) {
      ASSERT(result >= 0.0 && result <= 1.0);
      cout << "Returning saved result: " << result << endl;
      return result;
    }
    if (nCounted > n) {
      DEBUG_MSG("nCounted is greater n!");
      return 0.0;
    } else if (nCounted == 0) {
      DEBUG_MSG("nCounted is zero!");
      if (n == 0) {
	return 1.0;
      }  
      return 0.0;
    } else if ((n == 1) && (nCounted == 1)) {
      cout << "Basic case: " << indices << endl;
      result = 1.0;
    } else {
      cout << "Full case " << indices << endl;
      double p0 = 1.0;
      result = 0.0;
      index_array newCounts = indices;
      int_type k=outcomeMax;
      for (size_type i = 2; i < indices.size(); ++i) {
	newCounts = indices;
	newCounts[0] = indices[0] - 1; // n-1
	double p = 0.0;
	if (indices[i] > 0) {
	  p = static_cast<double>(indices[i-1]+1.0)/k; // last case landed on field with count i-1
	  ASSERT((p >= 0) && (p <= 1.0));
	  newCounts[i-1] = indices[i-1] + 1; 
          newCounts[i] = indices[i] - 1; 
	  result += p * exactProb(newCounts);
	  ASSERT(result >= 0.0); 
          p0 -= p; 
	  if (((i + 1) == indices.size()) && (indices[i] > 0)) {
	    p = static_cast<double>(indices[i])/k; // last case landed on highest counter
	    ASSERT((p >= 0) && (p <= 1.0));
	    newCounts[i-1] = indices[i-1];
	    newCounts[i] = indices[i] - 1;  
	    result += p * exactProb(newCounts);
	    ASSERT(result >= 0.0); 
	    p0 -= p; 
	  }
        }
      }
      newCounts = indices;
      cout << "p0 is " << p0 << endl;

      newCounts[0] = n-1;
      if (newCounts[1] > 0) {
	// imitating formula (k-((m1-1)+m2+m3+m4))) / k (see exactProb for 4 indices) for general case:
	int_type m = indices[1]-1;
	for (int_type mm = 2; mm < static_cast<int_type>(indices.size()); ++mm) {
	  m+= indices[mm];
	}
	if (m < outcomeMax) {
	  double p00 = (outcomeMax - m) / static_cast<double>(outcomeMax); 
	  newCounts[1] -= 1;
	  cout << "p00: prob that last case landed on new field: " << p00 << endl;  
	  result += p00 * exactProb(newCounts); // has landed on new field
	}
      } else {
	cout << "Ignoring weird case for p0: " << p0 << " : " << newCounts << " from " << indices << endl;
      }
    }
    cout << "current result: " << result << endl;
    ASSERT(result < 1.01); // rounding errors are possible but must not exceed 0.1
    if (result > 1.0) {
      result = 1.0;
    }
    ASSERT(result > -0.01); // rounding errors are possible but must not exceed 0.1
    if (result < 0.0) {
      result = 0.0;
    }
    if (USE_MATRICES_MODE) {
      matrices.set(indices, result);
    }
    POSTCOND(result >= 0.0 && (result <= 1.0)); 
    return result;
  }

  /** For k given possibilities (like 365 days), n trials (like n people),
   * what is the probability to observe only m OR LESS different outcomes (like m different birthdays)
   * @param epsilon how accurate
   */
  /*
  double exactP(int_type k, int_type n, int_type m) const {
    if (k > static_cast<int_type>(matrices.getDimension(0))) {
      cout << "# Warning: number of outcomes ( " << k << " ) exceeds internal data structure size of " << matrices.getDimension(0) << " . Cannot compute precise p-value." << endl; 
      return 1.0;
    }
    double result = 0.0;
    for (int_type i = m; i >= 1; --i) {
      result += 0.0; // exactProb(k,n,i);
    }
    if (result > 1.1) {
      cout << "# WARNING: Rounding error too large in BirthdayProbAll : " << k << " " << n << " " << m << " : " << result << endl;
    }
    ASSERT(result < 1.1); // rounding errors are possible but must not exceed 0.1
    if (result > 1.0) {
      result = 1.0;
    }
    ASSERT(result > -0.1); // rounding errors are possible but must not exceed 0.1
    if (result < 0.0) {
      result = 0.0;
    }
    POSTCOND(result >= 0.0 && (result <= 1.0)); 
    return result;
  }
  */

  int_type getCountDim() const { return countDim; }

  int_type getCountMax() const { return countMax; }

  const MultiValArray& getMatrices() const { return matrices; }

  int_type getOutcomeMax() const { return outcomeMax; }

  /** Returns number of possible outcomes. For a classical birthday problem, this would be 365 */
  size_type size() const { return matrices.getDimension(0); }

};

#endif
