//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Device/Histo/DiffUtil.cpp
//! @brief     Implements namespace DiffUtil.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Device/Histo/DiffUtil.h"
#include "Base/Axis/Frame.h"
#include "Base/Axis/Scale.h"
#include "Base/Math/Numeric.h"
#include "Base/Util/Assert.h"
#include "Device/Data/Datafield.h"
#include "Device/Histo/SimulationResult.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <stdexcept>

//! Returns the mean of the element-wise relative differences between two data vectors:
//!     (1/N) * sum_i Numeric::relativeDifference(dat[i], ref[i])
//! Preconditions (checked with ASSERT): dat and ref have the same, non-zero size,
//! and the resulting mean is not NaN.
double DiffUtil::meanRelVecDiff(const std::vector<double>& dat, const std::vector<double>& ref)
{
    ASSERT(dat.size() == ref.size());
    // Without this, an empty input would silently reach 0/0 below and only be caught
    // indirectly by the NaN check; state the precondition explicitly instead.
    ASSERT(!dat.empty());
    double diff = 0;
    for (size_t i = 0; i < dat.size(); ++i)
        diff += Numeric::relativeDifference(dat[i], ref[i]);
    diff /= dat.size();
    ASSERT(!std::isnan(diff));
    return diff;
}

//! Builds a new Datafield whose value in each bin is
//! Numeric::relativeDifference(dat[bin], ref[bin]).
//! The frame of the result is a clone of dat's frame. The returned pointer is freshly
//! allocated with `new` and is owned by the caller.
//! Precondition (ASSERT): the frames of dat and ref have the same sizes.
Datafield* DiffUtil::relativeDifferenceField(const Datafield& dat, const Datafield& ref)
{
    ASSERT(dat.frame().hasSameSizes(ref.frame()));
    const size_t nBins = dat.size();
    std::vector<double> values;
    values.reserve(nBins);
    for (size_t bin = 0; bin < nBins; ++bin)
        values.push_back(Numeric::relativeDifference(dat[bin], ref[bin]));
    return new Datafield(dat.frame().clone(), values);
}

//! Returns the mean, over all elements, of the pairwise relative difference
//! (a, b) -> 2*abs(a - b)/(|a| + |b|)      ( and zero if  a=b=0 within epsilon )
//! as computed by Numeric::relativeDifference.
//!
//! @throws std::runtime_error if dat and ref differ in size, if they are empty,
//!         or if dat and/or ref contain only zeroes.
double DiffUtil::meanRelativeDifference(const SimulationResult& dat, const SimulationResult& ref)
{
    if (dat.size() != ref.size())
        throw std::runtime_error("Invalid call to meanRelativeDifference: "
                                 "different number of elements in dat and ref datasets");
    if (dat.empty())
        throw std::runtime_error("Invalid call to meanRelativeDifference: "
                                 "empty dat and ref datasets");

    double sum_of_diff = 0.;
    double sum_of_fdat = 0.;
    double sum_of_fref = 0.;
    for (size_t i = 0; i < dat.size(); ++i) {
        sum_of_diff += Numeric::relativeDifference(dat[i], ref[i]);
        sum_of_fdat += std::fabs(dat[i]);
        sum_of_fref += std::fabs(ref[i]);
    }
    // Check the "both all-zero" case first, so that its message names both datasets.
    // (Previously the second operand lacked "== 0", which inverted the condition.)
    if (sum_of_fdat == 0 && sum_of_fref == 0)
        throw std::runtime_error("Invalid call to meanRelativeDifference: "
                                 "dat and ref only contain zeroes");
    if (sum_of_fdat == 0)
        throw std::runtime_error("Invalid call to meanRelativeDifference: "
                                 "dat only contains zeroes");
    if (sum_of_fref == 0)
        throw std::runtime_error("Invalid call to meanRelativeDifference: "
                                 "ref only contains zeroes");
    return sum_of_diff / dat.size();
}

//! Returns true if the mean relative difference between dat and ref (see meanRelVecDiff)
//! does not exceed threshold; prints informative output.
//!
//! Returns false, after printing a "FAILED" message to std::cerr, if dat is empty or
//! contains only zeroes, if the difference exceeds threshold, or if meanRelVecDiff throws.
//! NOTE(review): the throwing case presumably comes from a failing ASSERT (e.g. size
//! mismatch) — confirm that ASSERT throws rather than aborts.
bool DiffUtil::checkRelativeDifference(const std::vector<double>& dat,
                                       const std::vector<double>& ref, const double threshold)
{
    // std::all_of is true for an empty range, so this single check covers both an empty
    // and an all-zero dat. Dereferencing min_element/max_element of an empty range
    // would be undefined behaviour.
    if (std::all_of(dat.begin(), dat.end(), [](double x) { return x == 0; })) {
        std::cerr << "FAILED: simulated data set is empty" << std::endl;
        return false;
    }

    try {
        const double diff = DiffUtil::meanRelVecDiff(dat, ref);
        if (diff > threshold) {
            std::cerr << "FAILED: relative deviation of dat from ref is " << diff
                      << ", above given threshold " << threshold << std::endl;
            return false;
        }
        if (diff)
            std::cerr << "- OK: relative deviation of dat from ref is " << diff
                      << ", within given threshold " << threshold << std::endl;
        else
            std::cout << "- OK: dat = ref\n";
        return true;
    } catch (const std::exception& ex) {
        // Report instead of silently swallowing, so a failed comparison is diagnosable.
        std::cerr << "FAILED: " << ex.what() << std::endl;
        return false;
    } catch (...) {
        std::cerr << "FAILED: unknown error while comparing dat and ref" << std::endl;
        return false;
    }
}
