/* Ergo, version 3.3, a program for linear scaling electronic structure
 * calculations.
 * Copyright (C) 2013 Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek.
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Primary academic reference:
 * Kohn−Sham Density Functional Theory Electronic Structure Calculations 
 * with Linearly Scaling Computational Time and Memory Usage,
 * Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek,
 * J. Chem. Theory Comput. 7, 340 (2011),
 * <http://dx.doi.org/10.1021/ct100611z>
 * 
 * For further information about Ergo, see <http://www.ergoscf.org>.
 */
#ifdef USE_CHUNKS_AND_TASKS

#include <cstddef>
#include <cstdio>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <vector>
#include "ComputeOverlapMatRecursive.h"
#include "integrals_general.h"

// Computes a single overlap-matrix element between basis function i1 of b1
// and basis function i2 of b2.
//
// The product of the two basis functions is expanded into simple primitives
// by get_product_simple_primitives (written into wrkBuf), and the element is
// the sum of compute_integral_of_simple_prim over those primitives.
//
// Parameters:
//   b1, i1      - first basis set and the index of a function in it.
//   b2, i2      - second basis set and the index of a function in it.
//   wrkBuf      - caller-provided scratch buffer that receives the primitives.
//   wrkBufCount - capacity of wrkBuf, forwarded to get_product_simple_primitives.
// Throws std::runtime_error if i1/i2 are out of range, or if the product
// expansion yields no primitives (a non-positive count).
ergo_real compute_one_element_of_overlap_mat_2bsets(const BasisInfoStruct & b1, int i1, const BasisInfoStruct & b2, int i2, DistributionSpecStruct* wrkBuf, int wrkBufCount) {
  const bool firstIndexOk  = (i1 >= 0 && i1 < b1.noOfBasisFuncs);
  const bool secondIndexOk = (i2 >= 0 && i2 < b2.noOfBasisFuncs);
  if(!(firstIndexOk && secondIndexOk))
    throw std::runtime_error("Error in compute_one_element_of_overlap_mat_2bsets: bad index.");

  const int primCount = get_product_simple_primitives(b1, i1, b2, i2, wrkBuf, wrkBufCount, 0);
  if(primCount <= 0)
    throw std::runtime_error("Error in compute_one_element_of_overlap_mat_2bsets: (nPrimitives <= 0).");

  // Accumulate the primitive integrals in buffer order.
  ergo_real element = 0;
  DistributionSpecStruct* const bufEnd = wrkBuf + primCount;
  for(DistributionSpecStruct* prim = wrkBuf; prim != bufEnd; ++prim)
    element += compute_integral_of_simple_prim(prim);
  return element;
}


// Task that adds two matrix chunks, treating CHUNK_ID_NULL as "no contribution".
// It combines the results of ComputeOverlapMatRecursive subtasks, any of
// which may return CHUNK_ID_NULL when its part of the matrix was skipped.
class CombineResultsAcceptingNull : public cht::Task {
public:
  cht::ID execute(const cht::ChunkID &, const cht::ChunkID &);
  CHT_TASK_INPUT((cht::ChunkID, cht::ChunkID));
  CHT_TASK_OUTPUT((CHTMLMatType));
  CHT_TASK_TYPE_DECLARATION;
};

CHT_TASK_TYPE_IMPLEMENTATION((CombineResultsAcceptingNull));

// Result, depending on which inputs are CHUNK_ID_NULL:
//   both non-null  -> ID of a registered MatrixAdd task computing A + B;
//   both null      -> CHUNK_ID_NULL;
//   exactly one    -> a copy of the non-null input chunk.
cht::ID CombineResultsAcceptingNull::execute(const cht::ChunkID & cid_matrix_A, 
                                             const cht::ChunkID & cid_matrix_B) {
  const bool aIsNull = (cid_matrix_A == cht::CHUNK_ID_NULL);
  const bool bIsNull = (cid_matrix_B == cht::CHUNK_ID_NULL);
  if(!aIsNull && !bIsNull)
    return registerTask<chtml::MatrixAdd<LeafMatType> >(cid_matrix_A, cid_matrix_B, cht::persistent);
  if(aIsNull && bIsNull)
    return cht::CHUNK_ID_NULL;
  // Exactly one input is null: pass the other one through as a fresh chunk.
  return copyChunk(aIsNull ? cid_matrix_B : cid_matrix_A);
}


// Returns the Euclidean distance between the closest points of two boxes
// described per axis by minCoord/maxCoord. On each of the three axes the gap is
// zero when the boxes' projections overlap. When both boxes overlap on every
// axis the result is 0.
static ergo_real getMinDistBetweenBoxes(const BoxStruct & box1, const BoxStruct & box2) {
  ergo_real sumOfSquaredGaps = 0;
  for(int axis = 0; axis < 3; ++axis) {
    ergo_real gap = 0;
    // box2 lying entirely above box1 takes precedence, matching the
    // overwrite order of the original two independent checks.
    if(box2.minCoord[axis] > box1.maxCoord[axis])
      gap = box2.minCoord[axis] - box1.maxCoord[axis];
    else if(box1.minCoord[axis] > box2.maxCoord[axis])
      gap = box1.minCoord[axis] - box2.maxCoord[axis];
    sumOfSquaredGaps += gap * gap;
  }
  return sqrt(sumOfSquaredGaps);
}

CHT_TASK_TYPE_IMPLEMENTATION((ComputeOverlapMatRecursive));

// Computes the part of the overlap matrix whose rows come from the basis
// functions in part1 and whose columns come from the basis functions in part2.
//
// - If the bounding boxes of the basis-function centers are farther apart
//   than part1.maxExtent + part2.maxExtent, nothing is computed and
//   CHUNK_ID_NULL is returned.
// - If both chunks are at the lowest level, the surviving elements are computed
//   directly. A pair survives when its center distance is below the sum of the
//   two function extents. A MatrixAssignFromSparse task then builds an
//   info.x.n x info.x.n matrix from the (row, col, value) triplets. Row and
//   column indices are taken from basisFuncIndexList.
// - Otherwise each non-lowest-level chunk is split into its two child chunks.
//   One ComputeOverlapMatRecursive task is registered per child pair, and the
//   results are summed with CombineResultsAcceptingNull, which accepts
//   CHUNK_ID_NULL inputs from pruned subtasks.
//
// Throws std::runtime_error at the lowest level if a chunk's
// basisFuncIndexList size differs from noOfBasisFuncs, or if its
// basisFuncExtentList is shorter than noOfBasisFuncs.
cht::ID ComputeOverlapMatRecursive::execute(const DistrBasisSetChunk & part1, 
                                            const DistrBasisSetChunk & part2, 
                                            const chttl::ChunkBasic<MatrixInfoStruct> & info) {
  cht::ChunkID const & cid_part1 = getInputChunkID(part1);
  cht::ChunkID const & cid_part2 = getInputChunkID(part2);
  cht::ChunkID const & cid_info  = getInputChunkID(info);

  // Check if bounding boxes are far enough apart so that this can be skipped.
  ergo_real minDistBetweenBoxes = getMinDistBetweenBoxes(part1.boundingBoxForCenters, part2.boundingBoxForCenters);
  if(minDistBetweenBoxes > part1.maxExtent + part2.maxExtent)
    return cht::CHUNK_ID_NULL;

  // A chunk is treated as lowest level when its noOfBasisFuncs equals that of
  // its own basisInfo. Such chunks are not split further below.
  const bool part1isLowestLevel = (part1.noOfBasisFuncs == part1.basisInfo.noOfBasisFuncs);
  const bool part2isLowestLevel = (part2.noOfBasisFuncs == part2.basisInfo.noOfBasisFuncs);

  if(part1isLowestLevel && part2isLowestLevel) {
    const int n1 = part1.noOfBasisFuncs;
    const int n2 = part2.noOfBasisFuncs;
    // Validate the per-function lists before anything is allocated or indexed.
    // Both the index lists and the extent lists are read for every i1/i2 below.
    if(part1.basisFuncIndexList.size() != static_cast<std::size_t>(n1) ||
       part2.basisFuncIndexList.size() != static_cast<std::size_t>(n2))
      throw std::runtime_error("Error in ComputeOverlapMatRecursive::execute: basisFuncIndexList has wrong size.");
    if(part1.basisFuncExtentList.size() < static_cast<std::size_t>(n1) ||
       part2.basisFuncExtentList.size() < static_cast<std::size_t>(n2))
      throw std::runtime_error("Error in ComputeOverlapMatRecursive::execute: basisFuncExtentList is too short.");
    // (row, col, value) triplets for the sparse assignment. There are at most
    // n1*n2 of them; the product is computed in size_t to avoid int overflow.
    const std::size_t nElementsMax = static_cast<std::size_t>(n1) * static_cast<std::size_t>(n2);
    std::vector<int> rows;
    std::vector<int> cols;
    std::vector<ergo_real> values;
    rows.reserve(nElementsMax);
    cols.reserve(nElementsMax);
    values.reserve(nElementsMax);
    // Scratch buffer for the primitive products of one basis-function pair,
    // reused for every pair.
    const int wrkBufCount = 20000;
    std::vector<DistributionSpecStruct> wrkBuf(wrkBufCount);
    for(int i1 = 0; i1 < n1; i1++) {
      // Loop-invariant data for basis function i1.
      const Vector3D & pt1 = part1.basisInfo.basisFuncList[i1].centerCoords;
      const ergo_real extent1 = part1.basisFuncExtentList[i1];
      for(int i2 = 0; i2 < n2; i2++) {
        // Only compute the element if the two functions are close enough for
        // their extents to reach each other.
        const Vector3D & pt2 = part2.basisInfo.basisFuncList[i2].centerCoords;
        const ergo_real distance = pt1.dist(pt2);
        const ergo_real extent2 = part2.basisFuncExtentList[i2];
        if(distance < extent1 + extent2) {
          rows.push_back(part1.basisFuncIndexList[i1]);
          cols.push_back(part2.basisFuncIndexList[i2]);
          values.push_back(compute_one_element_of_overlap_mat_2bsets(part1.basisInfo, i1, part2.basisInfo, i2, &wrkBuf[0], wrkBufCount));
        }
      }
    }
    // Create corresponding chunk objects.
    cht::ChunkID cid_rows = registerChunk(new chttl::ChunkVector<int>(rows));
    cht::ChunkID cid_cols = registerChunk(new chttl::ChunkVector<int>(cols));
    cht::ChunkID cid_values = registerChunk(new chttl::ChunkVector<double>(values));
    // Create a matrix from the three vectors, and return that matrix chunk.
    // First prepare a params chunk.
    int M = info.x.n;
    int N = info.x.n;
    int leavesSizeMax = info.x.leavesSizeMax;
    typename LeafMatType::Params leaf_params;
    // NOTE: internal blocksize does not exist for basic matrix lib, so we cannot set it then.
#ifdef USE_CHUNKS_AND_TASKS_BSM
    leaf_params.blocksize = info.x.leafInternalBlocksize;
#endif
    cht::ChunkID cid_param = registerChunk(new chtml::MatrixParams<LeafMatType>(M, N, leavesSizeMax, 0, 0, leaf_params));
    // Now register task, and return resulting TaskID.
    return registerTask<chtml::MatrixAssignFromSparse<LeafMatType> >(cid_param, cid_rows, cid_cols, cid_values, cht::persistent);
  }
  if(part1isLowestLevel && !part2isLowestLevel) {
    // Split part2 only: register two new tasks and add up the results.
    cht::ID id1 = registerTask<ComputeOverlapMatRecursive>(cid_part1, part2.cid_child_chunks[0], cid_info);
    cht::ID id2 = registerTask<ComputeOverlapMatRecursive>(cid_part1, part2.cid_child_chunks[1], cid_info);
    return registerTask<CombineResultsAcceptingNull>(id1, id2, cht::persistent);
  }
  if(!part1isLowestLevel && part2isLowestLevel) {
    // Split part1 only: register two new tasks and add up the results.
    cht::ID id1 = registerTask<ComputeOverlapMatRecursive>(part1.cid_child_chunks[0], cid_part2, cid_info);
    cht::ID id2 = registerTask<ComputeOverlapMatRecursive>(part1.cid_child_chunks[1], cid_part2, cid_info);
    return registerTask<CombineResultsAcceptingNull>(id1, id2, cht::persistent);
  }
  // Now we know neither part1 nor part2 has reached lowest level.
  // Register four new tasks and add up the results. The two intermediate sums
  // are not persistent; only the final combined result is.
  cht::ID id00 = registerTask<ComputeOverlapMatRecursive>(part1.cid_child_chunks[0], part2.cid_child_chunks[0], cid_info);
  cht::ID id01 = registerTask<ComputeOverlapMatRecursive>(part1.cid_child_chunks[0], part2.cid_child_chunks[1], cid_info);
  cht::ID id10 = registerTask<ComputeOverlapMatRecursive>(part1.cid_child_chunks[1], part2.cid_child_chunks[0], cid_info);
  cht::ID id11 = registerTask<ComputeOverlapMatRecursive>(part1.cid_child_chunks[1], part2.cid_child_chunks[1], cid_info);
  cht::ID tmp1 = registerTask<CombineResultsAcceptingNull>(id00, id01);
  cht::ID tmp2 = registerTask<CombineResultsAcceptingNull>(id10, id11);
  return registerTask<CombineResultsAcceptingNull>(tmp1, tmp2, cht::persistent);
}


#endif
