#include <MAFSearchTables3.h>
#include <NucleotideTools.h>
#include <debug.h>
#include <limits.h>
#include <ContainerTools.h>

#define REMARK cout << "# "

 /** Builds the canonical lookup key for one assembly triple combined with one residue triple.
  *
  * The three (assembly, residue) pairs are first brought into ascending assembly-name order,
  * so every permutation of the same three pairs yields the same key. The residues are
  * upper-cased. The key has the form
  *   <assembly1>_<assembly2>_<assembly3>_<C1>_<C2>_<C3>
  * for example "hg18_mm8_panTro_A_C_A".
  *
  * @param assembly1 first assembly name; must be non-empty and differ from the other two
  * @param assembly2 second assembly name; must be non-empty and differ from the other two
  * @param assembly3 third assembly name; must be non-empty and differ from the other two
  * @param c1 residue belonging to assembly1
  * @param c2 residue belonging to assembly2
  * @param c3 residue belonging to assembly3 (the three residues must not all be equal)
  * @return the non-empty key string
  */
 string
 MAFSearchTables3::createHashTableHash(const string& assembly1, const string& assembly2, const string& assembly3,
                                       char c1, char c2, char c3) {
   ASSERT(assembly1.size() > 0);
   ASSERT(assembly2.size() > 0);
   ASSERT(assembly3.size() > 0);
   ASSERT(assembly1 != assembly2);
   ASSERT(assembly1 != assembly3);
   ASSERT(assembly2 != assembly3);
   ASSERT(!((c1 == c2) && (c1 == c3)));
   // Canonicalize the order: swap one out-of-order neighbouring pair (together with its
   // residues) and start over. Each swap removes one inversion and the comparisons are
   // strict, so the recursion is at most three levels deep and always terminates.
   if (assembly2 < assembly1) {
     return createHashTableHash(assembly2, assembly1, assembly3, c2, c1, c3);
   }
   if (assembly3 < assembly2) {
     return createHashTableHash(assembly1, assembly3, assembly2, c1, c3, c2);
   }
   ASSERT(assembly1 < assembly2);
   ASSERT(assembly2 < assembly3);
   const string sep = "_";
   string result = assembly1 + sep + assembly2 + sep + assembly3;
   const char motif[3] = { c1, c2, c3 };
   for (int i = 0; i < 3; ++i) {
     // Every residue gets its own separator, including the first one after the last assembly.
     result += sep;
     // toupper is only defined for values representable as unsigned char (or EOF).
     result += static_cast<char>(toupper(static_cast<unsigned char>(motif[i])));
   }
   ASSERT(result.size() > 0);
   return result;
 }

 void
 MAFSearchTables3::testCreateHashTableHash() {
   string assembly1 = "hg18";
   string assembly2 = "mm8";
   string assembly3 = "panTro";
   char c1 = 'A';
   char c2 = 'c';
   char c3 = 'A';
   ASSERT(createHashTableHash(assembly1,assembly2, assembly3, c1, c2, c3) == "hg18_mm8_panTro_A_C_A");
 }

 /** Generates the search tables that return, for keys produced by createHashTableHash
  * (one key per assembly triple and three-character motif), the ids of all alignment
  * columns in which the three assemblies carry that motif.
  *
  * Only triples consisting of refAssembly plus two other entries of `assemblies` are
  * considered. Each triple is scored with estimateAssemblyTripleHashSize, the triples are
  * sorted, and only the leading assemblyPairFraction share of them is turned into tables.
  * Within the range [searchRangeMin, searchRangeMax) a column is skipped if one of the three
  * characters has no residue index, if NucleotideTools::isConserved reports the column as
  * conserved, or (with sameNucShortcut) if the three primary residue indices are equal.
  * The tables are written to positionHashes, which must be empty on entry.
  *
  * NOTE(review): estimateAssemblyTripleHashSize works on a `maf` that is not this
  * function's parameter - confirm that both refer to the same alignment.
  *
  * @param maf alignment that supplies the alphabet, the per-assembly sequences and the column slices
  * @param assemblies names of all assemblies to consider; more than two are required
  * @param refAssembly assembly that is part of every generated triple
  */
 void
 MAFSearchTables3::createSearchHashTable(MAFAlignment *maf, const set<string>& assemblies, const string& refAssembly) {
   PRECOND(positionHashes.size() == 0);
   PRECOND(assemblies.size() > 2);
   if (verbose > 0) {
     REMARK << "Creating hash-tables for " <<  assemblies.size() << " assemblies: ";
     for (set<string>::const_iterator it = assemblies.begin(); it != assemblies.end(); it++) {
       cout << (*it) << " ";
     }
     cout << endl; // only finish the line; REMARK would print another "# " in front of the line break
   }
   const string& residues = maf->getResidues();
   if (verbose > 1) {
     REMARK << "Generating column index vectors..." << endl;
   }

   // Collect all triples (refAssembly, j, k) with j before k in set order, neither being refAssembly.
   Vec<RankedSolution6<string, string, string> > queue;
   queue.reserve((assemblies.size() * (assemblies.size() - 1) ) / 2); // upper bound on the number of triples
   for (set<string>::const_iterator j = assemblies.begin(); j != assemblies.end(); j++) {
     ASSERT((*j).size() > 0);
     if (*j == refAssembly) {
       continue;
     }
     for (set<string>::const_iterator k = j; k != assemblies.end(); k++) {
       ASSERT((*k).size() > 0);
       if ((j == k) || (*k == refAssembly)) {
         continue; // do not allow same assemblies
       }
       double score = estimateAssemblyTripleHashSize(refAssembly,*j, *k); // the smaller estimated size the better
       RankedSolution6<string, string, string> item(score, refAssembly, *j, *k);
       queue.push_back(item);
     }
   }
   sort(queue.begin(), queue.end());
   if (verbose > 2) {
     REMARK << "Successfully sorted " << queue.size() << " assembly triplets." << endl;
     for (size_type i = 0; i < queue.size(); ++i) {
       REMARK << (i+1) << " " << queue[i] << endl;
       if (i > 10) {
         REMARK << "..." << endl;
         break;
       }
     }
   }

   // Number of leading triples that actually get tables.
   size_type numAssemUsed = static_cast<size_type>(assemblyPairFraction * queue.size());
   if (numAssemUsed > queue.size()) {
     numAssemUsed = queue.size(); // a fraction above 1 must not index past the end of the queue
   }
   if ((numAssemUsed == 0) && (queue.size() > 0)) {
     numAssemUsed = 1; // truncating a small product must not discard every triple
   }
   ASSERT(numAssemUsed > 0);

   map<string, string> assemSeqs; // cache: assembly name -> sequence generated by maf
   set_type dummySet;
   // setReferences[a][b][c] collects the column ids of the current triple at which the three
   // assemblies carry residues[a], residues[b], residues[c]. The structure is allocated once
   // and cleared for every triple, which avoids building a key string for every column.
   Vec<Vec<Vec<set_type> > > setReferences(residues.size(),
       Vec<Vec<set_type> >(residues.size(),
           Vec<set_type>(residues.size(), dummySet)));

   // charHashes[ch] lists the residue indices under which a column showing character ch is
   // filed; an empty list means the character is not used. A direct table lookup avoids a
   // costly character-to-index mapping inside the column loop.
   Vec<Vec<int> > charHashes(256);
   int ai = static_cast<int>('A');
   int ap = 0;
   int ci = static_cast<int>('C');
   int cp = 1;
   int gi = static_cast<int>('G');
   int gp = 2;
   int ti = static_cast<int>('T');
   int tp = 3;
   switch (ambiguityMode) {
   case NO_AMBIGUITY:
     // every alphabet character maps to exactly its own index in `residues`
     for (string::size_type j = 0; j < residues.size(); ++j) {
       charHashes[static_cast<int>(residues[j])] = Vec<int>(1, j);
     }
     break;
   case COMPLEMENT_AMBIGUITY: // consequences of allowing G-U base-pairing: add positions into multiple hash tables!
     ERROR_IF(residues != "ACGT", "Complement-ambiguity only implemented for ACGT alphabet, sorry.");
     {
       charHashes[ai] = Vec<int>(1); // A is filed under A only
       charHashes[ai][0] = ap;
       charHashes[ci] = Vec<int>(1); // C is filed under C only
       charHashes[ci][0] = cp;
       charHashes[gi] = Vec<int>(2); // user searched with complement of G, that is C or T(U). So A is an ambiguity!
       charHashes[gi][0] = gp;
       charHashes[gi][1] = ap;
       charHashes[ti] = Vec<int>(2); // user searched with complement of T, that is A or G.  So C is an ambiguity!
       charHashes[ti][0] = tp;
       charHashes[ti][1] = cp;
     }
     break;
   case MATCH_AMBIGUITY:
     ERROR("Match-ambiguity not net implemented.");
     break;
   default:
     ERROR("Invalid ambiguity mode.");
   }
   ASSERT(charHashes[static_cast<int>('A')][0] == 0);
   ASSERT(charHashes[static_cast<int>('C')][0] == 1);
   ASSERT(charHashes[static_cast<int>('G')][0] == 2);
   ASSERT(charHashes[static_cast<int>('T')][0] == 3);

   const size_type rs = residues.size();
   for (size_type i = 0; i < numAssemUsed; ++i) { // loop over triplets
     string assem1 = queue[i].getSecond();
     string assem2 = queue[i].getThird();
     string assem3 = queue[i].getFourth();
     ASSERT(assem1.size() > 0);
     ASSERT(assem2.size() > 0);
     ASSERT(assem3.size() > 0);
     ASSERT(assem1 != assem2);
     ASSERT(assem1 != assem3);
     ASSERT(assem2 != assem3);
     // generate each assembly sequence only once, no matter in how many triples it occurs
     if (assemSeqs.find(assem1) == assemSeqs.end()) {
       assemSeqs[assem1] = maf->generateAssemblySequence(assem1);
     }
     if (assemSeqs.find(assem2) == assemSeqs.end()) {
       assemSeqs[assem2] = maf->generateAssemblySequence(assem2);
     }
     if (assemSeqs.find(assem3) == assemSeqs.end()) {
       assemSeqs[assem3] = maf->generateAssemblySequence(assem3);
     }
     ASSERT(static_cast<length_type>(assemSeqs[assem1].size()) == maf->getTotalLength());
     ASSERT(static_cast<length_type>(assemSeqs[assem2].size()) == maf->getTotalLength());
     ASSERT(static_cast<length_type>(assemSeqs[assem3].size()) == maf->getTotalLength());
     if (verbose > 1) {
       REMARK << "Working on assembly triple " << assem1 << " " << assem2 << " " << assem3 << " score: " << queue[i].getFirst() << endl;
     }
     const string& seq1 = assemSeqs[assem1];
     const string& seq2 = assemSeqs[assem2];
     const string& seq3 = assemSeqs[assem3];

     // Rough per-table size estimate: columns that are not identical in all three
     // sequences, spread evenly over the rs^3 residue combinations.
     length_type countDiff = 0;
     ASSERT(searchRangeMax <= maf->getTotalLength());
     ASSERT(searchRangeMin >= 0);
     for (length_type j = searchRangeMin; j < searchRangeMax; ++j) { // loop over whole searched region
       if (! ((seq1[j] == seq2[j]) && (seq1[j] == seq3[j]))) {
         ++countDiff;
       }
     }
     countDiff /= (rs * rs * rs); // estimate number of hits for each set
     for (size_type j = 0; j < rs; ++j) {
       ASSERT(setReferences[j].size() == rs);
       for (size_type k = 0; k < rs; ++k) {
         ASSERT(setReferences[j][k].size() == rs);
         for (size_type m = 0; m < rs; ++m) {
           if (countDiff > static_cast<length_type>(setReferences[j][k][m].size())) {
             setReferences[j][k][m].reserve(countDiff);
           }
           setReferences[j][k][m].clear(); // drop the column ids of the previous triple
         }
       }
     }

     int resId1,resId2,resId3;
     for (length_type j = searchRangeMin; j < searchRangeMax; ++j) { // loop over whole searched region
       const Vec<int>& resIds1 = charHashes[static_cast<int>(seq1[j])];
       if(resIds1.size() == 0) {
         continue; // character of assembly 1 is not used
       }
       ASSERT(resIds1.size() > 0 && (resIds1.size() <= 2));
       const Vec<int>& resIds2 = charHashes[static_cast<int>(seq2[j])];
       if(resIds2.size() == 0) {
         continue; // character of assembly 2 is not used
       }
       ASSERT(resIds2.size() > 0 && (resIds2.size() <= 2));
       const Vec<int>& resIds3 = charHashes[static_cast<int>(seq3[j])];
       if(resIds3.size() == 0) {
         continue; // character of assembly 3 is not used
       }
       ASSERT(resIds3.size() > 0 && (resIds3.size() <= 2));
       if (NucleotideTools::isConserved(maf->getSlice(j, assemblies))) {
         continue;
       }
       if (sameNucShortcut
           && (resIds1[0] == resIds2[0])
           && (resIds1[0] == resIds3[0])) {
         continue; // ignore triplets that without ambiguity considerations are conserved
       }
       // The nested loops are only necessary because of ambiguities: one character may be
       // filed under up to two residue indices (e.g. G under G and A).
       for (Vec<int>::size_type i1 = 0; i1 < resIds1.size(); ++i1) {
         resId1 = resIds1[i1];
         for (Vec<int>::size_type i2 = 0; i2 < resIds2.size(); ++i2) {
           resId2 = resIds2[i2];
           for (Vec<int>::size_type i3 = 0; i3 < resIds3.size(); ++i3) {
             resId3 = resIds3[i3];
             if ((resId1 >= 0) && (resId2 >= 0) && (resId3 >= 0) && (!((resId1 == resId2) && (resId1 == resId3)))) {
               // check each index against the dimension it is about to address
               ASSERT(resId1 < static_cast<int>(setReferences.size()));
               ASSERT(resId2 < static_cast<int>(setReferences[resId1].size()));
               ASSERT(resId3 < static_cast<int>(setReferences[resId1][resId2].size()));
               set_type& columnIds = setReferences[resId1][resId2][resId3];
               // Columns are visited in ascending order, so a duplicate can only be the last entry.
               if ((columnIds.size() == 0)
                   || (columnIds[columnIds.size()-1] != j)) {
                 columnIds.push_back(j);
                 if (verbose > 5) {
                   REMARK << "Added position " << (j+1) << " to "
                          << assem1 << " " << assem2 << " " << assem3 << " "
                          << residues[resId1]<<residues[resId2]<<residues[resId3] << endl;
                 }
               }
             } // else : not found
           }
         }
       }
     }

     // Store the collected column lists under their keys.
     for (size_type j = 0; j < rs; ++j) {
       for (size_type k = 0; k < rs; ++k) {
         for (size_type m = 0; m < rs; ++m) {
           if ((j == k) && (j == m)) { // not all the same residues
             continue;
           }
           ASSERT(!((residues[j] == residues[k]) && (residues[j] == residues[m])));
           positionHashes[createHashTableHash(assem1,assem2,assem3,residues[j],residues[k], residues[m])] =
             setReferences[j][k][m];
           ERROR_IF(!ContainerTools::isSorted(setReferences[j][k][m].begin(), setReferences[j][k][m].end()),
                    "Internal error: Set was not sorted.");
         }
       }
     }

     // some more detailed output in verbose > 1 mode:
     if (verbose > 1) {
       size_type sz = 0;
       size_type numSets = 0; // counted instead of computed, so the report always matches the loops
       for (size_type j = 0; j < rs; ++j) {
         for (size_type k = 0; k < rs; ++k) {
           for (size_type m = 0; m < rs; ++m) {
             if ((j == k) && (j == m)) {
               continue;
             }
             ASSERT(!((residues[j] == residues[k]) && (residues[j] == residues[m])));
             string hashhash = createHashTableHash(assem1,assem2,assem3, residues[j],residues[k], residues[m]);
             ASSERT(positionHashes.find(hashhash) != positionHashes.end());
             sz += positionHashes[hashhash].size();
             ++numSets;
             if (verbose > 4) {
               REMARK << "Uncompressing " << positionHashes[hashhash] << endl;
               set_type outset = uncompressSet(positionHashes[hashhash]);
               REMARK << "Content of search table set " << hashhash << " : ";
               for (set_type::iterator it = outset.begin(); it != outset.end(); ++it) {
                 cout << (*it) << " ";
               }
               cout << endl;
             }
           }
         }
       }
       REMARK << " Actual size of " << numSets << " sets: " << sz << endl;
     }
   }
   if (verbose > 0) {
     REMARK << "Generated " << positionHashes.size() << " hash tables." << endl;
   }
 }

/** Estimates the potential number of hash table entries of two assemblies
 */
double
MAFSearchTables3::estimateAssemblyTripleHashSize(const string& assem1, const string& assem2, const string& assem3) const {
  set<string> assems;
  assems.insert(assem1);
  assems.insert(assem2);
  assems.insert(assem3);
  length_type numDiv = 1000;
  length_type stride = maf->getTotalLength() / numDiv;
  if (stride == 0) {
    stride = 1;
  }
  length_type countStored = 0;
  length_type countTotal = 0;
  for (length_type colid = 0; colid < maf->getTotalLength(); colid += stride, ++countTotal) {
    string slice = maf->getSlice(colid, assems);
    ASSERT(slice.size() == 3);
    if (! ((slice[0] == slice[1]) && (slice[0] == slice[2]))) {
      if (!(NucleotideTools::isGap(slice[0]) || NucleotideTools::isGap(slice[1]) || NucleotideTools::isGap(slice[2]) ) ) {
	++countStored;
      }
    }
  }
  // use pseudocounts for fraction:
  double result = (static_cast<double>(countStored+1)/static_cast<double>(countTotal+2)) * maf->getTotalLength();
  ASSERT((result >= 0) && (result <= maf->getTotalLength()));
  return result;
}
