#include "cfl/Fit.hpp"
#include "cfl/Error.hpp"
#include <gsl/gsl_multifit.h>
#include <gsl/gsl_bspline.h>
#include <gsl/gsl_blas.h>
#include <algorithm>
#include <numeric>
#include <cmath>

using namespace cfl;
using namespace std;

//class Fit

cfl::Fit::Fit(IFit *pNewP) : m_uP(pNewP) {}

// Checking domains

bool belongs(const std::vector<Function> &rF, double dX)
{
  return std::all_of(rF.begin(), rF.end(),
                     [dX](const Function &rG) {
                       return rG.belongs(dX);
                     });
}

// True if every argument of rArg lies in the domain of every function of rF.
// Evaluation stops at the first argument that falls outside some domain.
bool belongs(const std::vector<Function> &rF, const std::vector<double> &rArg)
{
  return std::none_of(rArg.begin(), rArg.end(),
                      [&rF](double dArg) { return !belongs(rF, dArg); });
}

// Packs the result of a GSL least-squares fit into a FitParam:
//   fit  - the fitted coefficients pC,
//   cov  - the covariance matrix pCov flattened in row-major order,
//   chi2 - the weighted sum of squared residuals dChi2.
// Elements are read through the GSL accessors so that the copy is correct
// for any vector stride / matrix row stride (tda), not only for freshly
// allocated GSL objects.
FitParam fitParam(const gsl_vector *pC, const gsl_matrix *pCov, double dChi2)
{
  const size_t iSize = pC->size;
  // The covariance matrix must be square and match the coefficient vector.
  ASSERT((pCov->size1 == iSize) && (pCov->size2 == iSize));

  std::valarray<double> uFit(iSize);
  for (size_t i = 0; i < iSize; i++)
    {
      uFit[i] = gsl_vector_get(pC, i);
    }

  std::valarray<double> uCov(iSize * iSize);
  for (size_t i = 0; i < iSize; i++)
    {
      for (size_t j = 0; j < iSize; j++)
        {
          uCov[i * iSize + j] = gsl_matrix_get(pCov, i, j);
        }
    }

  FitParam uParam;
  uParam.fit = uFit;
  uParam.cov = uCov;
  uParam.chi2 = dChi2;
  return uParam;
}

// class LinFit

class LinFit : public cfl::IFit
{
public:
  LinFit(const std::vector<Function> &rBaseF)
    : m_uBaseF(rBaseF),
      m_uCov(gsl_matrix_alloc(rBaseF.size(), rBaseF.size()), &gsl_matrix_free),
      m_uC(gsl_vector_alloc(rBaseF.size()), &gsl_vector_free)
  {
    PRECONDITION(rBaseF.size() > 0);
  }

  LinFit(const std::vector<Function> &rBaseF, const std::vector<double> &rArg,
         const std::vector<double> &rVal,
         const std::vector<double> &rWt, bool bChi2)
    : LinFit(rBaseF)
  {
    PRECONDITION(belongs(rBaseF, rArg));
    PRECONDITION((rArg.size() == rVal.size()) && (rArg.size() > 0) && (rArg.size() == rWt.size()));
    PRECONDITION(rBaseF.size() < rArg.size());

    if (rArg.size() <= rBaseF.size())
      {
	throw(cfl::NError::size("not enough nodes for linear fit"));
      }

    std::unique_ptr<gsl_multifit_linear_workspace, decltype(&gsl_multifit_linear_free)>
      uFit(gsl_multifit_linear_alloc(rArg.size(), m_uBaseF.size()), &gsl_multifit_linear_free);

    std::unique_ptr<gsl_matrix, decltype(&gsl_matrix_free)>
      uX(gsl_matrix_alloc(rArg.size(), m_uBaseF.size()), &gsl_matrix_free);
    for (unsigned i = 0; i < rArg.size(); i++)
      {
	for (unsigned j = 0; j < m_uBaseF.size(); j++)
	  {
	    gsl_matrix_set(uX.get(), i, j, m_uBaseF[j](rArg[i]));
	  }
      }
    gsl_vector_const_view uW = gsl_vector_const_view_array(&rWt.front(), rWt.size());
    gsl_vector_const_view uY = gsl_vector_const_view_array(&rVal.front(), rVal.size());

    gsl_multifit_wlinear(uX.get(), &uW.vector, &uY.vector, m_uC.get(),
                         m_uCov.get(), &m_dChi2, uFit.get());

    if (bChi2)
      {
	double dVar = m_dChi2 / (rArg.size() - rBaseF.size());
	gsl_matrix_scale(m_uCov.get(), dVar);
      }
  }

  IFit *
  newObject(const std::vector<double> &rArg, const std::vector<double> &rVal,
            const std::vector<double> &rWt, bool bChi2) const
  {
    return new LinFit(m_uBaseF, rArg, rVal, rWt, bChi2);
  }

  Function fit() const
  {
    std::function<double(double)> uFit =
      [uBase = m_uBaseF, uC = m_uC](double dX) {
	std::vector<double> uV(uBase.size());
	std::transform(uBase.begin(), uBase.end(), uV.begin(),
		       [dX](const Function &rF) {
			 return rF(dX);
		       });
	gsl_vector_const_view uU = gsl_vector_const_view_array(uV.data(), uV.size());
	gsl_blas_ddot(&uU.vector, uC.get(), &dX);
	return dX;
      };
    return Function(uFit,
                    [uBase = m_uBaseF](double dX) {
                      return belongs(uBase, dX);
                    });
  }

  Function err() const
  {
    std::function<double(double)> uErr =
      [uBase = m_uBaseF, uCov = m_uCov](double dX) {
	std::vector<double> uV(uBase.size());
	std::transform(uBase.begin(), uBase.end(), uV.begin(),
		       [dX](const Function &rF) {
			 return rF(dX);
		       });
	std::vector<double> uU(uV);
	gsl_vector_const_view uUView = gsl_vector_const_view_array(uU.data(), uU.size());
	gsl_vector_view uVView = gsl_vector_view_array(uV.data(), uV.size());
	gsl_blas_dsymv(CblasUpper, 1., uCov.get(), &uUView.vector, 0., &uVView.vector);
	gsl_blas_ddot(&uUView.vector, &uVView.vector, &dX);
	ASSERT(dX >= 0);
	return std::sqrt(dX);
      };
    return Function(uErr,
                    [uBase = m_uBaseF](double dX) {
                      return belongs(uBase, dX);
                    });
  }

  FitParam param() const
  {
    return fitParam(m_uC.get(), m_uCov.get(), m_dChi2);
  }

private:
  std::vector<Function> m_uBaseF;
  std::shared_ptr<gsl_matrix> m_uCov;
  std::shared_ptr<gsl_vector> m_uC;
  double m_dChi2;
};

// linear

// Unfitted linear least-squares fit over the basis functions rBaseF.
cfl::Fit cfl::NFit::linear(const std::vector<Function> &rBaseF)
{
  IFit *pLinFit = new LinFit(rBaseF);
  return Fit(pLinFit);
}

// Linear fit with a single basis function rF (fits the multiple c * rF).
cfl::Fit cfl::NFit::linear(const cfl::Function &rF)
{
  const std::vector<Function> uSingleBase(1, rF);
  return linear(uSingleBase);
}

// class BSpline

class BSpline : public IFit
{
public:
  // Unfitted B-spline basis of order iOrder with breakpoints rPoints.
  // Allocates the GSL workspace, coefficient vector and covariance matrix.
  BSpline(unsigned iOrder, const std::vector<double> &rPoints)
    : m_iOrder(iOrder), m_uPoints(rPoints)
  {
    // Validate before touching GSL: gsl_bspline_alloc reports a zero order or
    // fewer than two breakpoints through the GSL error handler, and the
    // unsigned expression for m_iF below would wrap around for empty rPoints.
    PRECONDITION(iOrder > 0);
    PRECONDITION(rPoints.size() > 1);
    PRECONDITION(std::adjacent_find(rPoints.begin(), rPoints.end(),
                                    [](double dA, double dB) { return dA >= dB; })
                 == rPoints.end());

    // Number of basis functions: nbreak + k - 2.
    m_iF = static_cast<unsigned>(iOrder + rPoints.size() - 2);
    m_uBSpline.reset(gsl_bspline_alloc(iOrder, rPoints.size()), &gsl_bspline_free);

    ASSERT(m_iF == gsl_bspline_ncoeffs(m_uBSpline.get()));

    m_uCov.reset(gsl_matrix_alloc(m_iF, m_iF), &gsl_matrix_free);
    m_uC.reset(gsl_vector_alloc(m_iF), &gsl_vector_free);
    gsl_vector_const_view uVec =
      gsl_vector_const_view_array(rPoints.data(), rPoints.size());
    gsl_bspline_knots(&uVec.vector, m_uBSpline.get());
  }

  // Weighted least-squares fit (gsl_multifit_wlinear) of the values rVal at
  // the nodes rArg with weights rWt. When bChi2 is true the covariance matrix
  // is rescaled by chi^2 / (number of nodes - number of basis functions).
  // Throws NError::size if there are not more nodes than basis functions.
  BSpline(unsigned iOrder, const std::vector<double> &rPoints,
          const std::vector<double> &rArg, const std::vector<double> &rVal,
          const std::vector<double> &rWt, bool bChi2)
    : BSpline(iOrder, rPoints)
  {
    PRECONDITION((rArg.size() == rVal.size()) && (rArg.size() > 0));
    PRECONDITION(rArg.size() == rWt.size());
    PRECONDITION(rArg.size() > m_iF);
    // Nodes outside the breakpoint range cannot be evaluated by GSL
    // (mirrors the domain check performed by LinFit).
    PRECONDITION(std::all_of(rArg.begin(), rArg.end(), [this](double dX) {
                   return (m_uPoints.front() <= dX) && (dX <= m_uPoints.back());
                 }));

    // Run-time check kept in addition to the precondition above, which may
    // be compiled out.
    if (rArg.size() <= m_iF)
      {
        throw(NError::size("not enough nodes for fitting with B-splines"));
      }

    std::unique_ptr<gsl_multifit_linear_workspace, decltype(&gsl_multifit_linear_free)>
      uFit(gsl_multifit_linear_alloc(rArg.size(), m_iF), &gsl_multifit_linear_free);

    // Design matrix: row i holds all basis functions evaluated at rArg[i].
    std::unique_ptr<gsl_matrix, decltype(&gsl_matrix_free)>
      uX(gsl_matrix_alloc(rArg.size(), m_iF), &gsl_matrix_free);
    for (size_t i = 0; i < rArg.size(); i++)
      {
        gsl_vector_view uRow = gsl_matrix_row(uX.get(), i);
        gsl_bspline_eval(rArg[i], &uRow.vector, m_uBSpline.get());
      }
    gsl_vector_const_view uW = gsl_vector_const_view_array(rWt.data(), rWt.size());
    gsl_vector_const_view uY = gsl_vector_const_view_array(rVal.data(), rVal.size());

    gsl_multifit_wlinear(uX.get(), &uW.vector, &uY.vector, m_uC.get(),
                         m_uCov.get(), &m_dChi2, uFit.get());

    if (bChi2)
      {
        double dVar = m_dChi2 / (rArg.size() - m_iF);
        gsl_matrix_scale(m_uCov.get(), dVar);
      }
  }

  IFit *newObject(const std::vector<double> &rArg, const std::vector<double> &rVal,
                  const std::vector<double> &rWt, bool bChi2) const override
  {
    return new BSpline(m_iOrder, m_uPoints, rArg, rVal, rWt, bChi2);
  }

  // The fitted spline on [front breakpoint, back breakpoint]. Only the iK
  // basis functions that are non-zero at x contribute, so the value is the
  // dot product of those with the matching slice of the coefficients.
  // NOTE(review): every returned Function shares m_uBSpline, and GSL's
  // B-spline evaluation writes scratch data into its workspace, so concurrent
  // evaluation is presumably unsafe — confirm before multi-threaded use.
  Function fit() const override
  {
    std::function<double(double)> uFit =
      [uC = m_uC, uBS = m_uBSpline, iK = m_iOrder](double dX) {
        std::vector<double> uF(iK);
        gsl_vector_view uViewF = gsl_vector_view_array(uF.data(), uF.size());
        // GSL declares these as size_t*; passing unsigned long* does not
        // compile where the two types differ (e.g. LLP64 platforms).
        size_t iStart = 0, iEnd = 0;
        gsl_bspline_eval_nonzero(dX, &uViewF.vector, &iStart, &iEnd, uBS.get());
        ASSERT(iStart + iK == iEnd + 1);
        gsl_vector_view uViewG = gsl_vector_subvector(uC.get(), iStart, iK);
        double dY = 0.;
        gsl_blas_ddot(&uViewG.vector, &uViewF.vector, &dY);
        return dY;
      };
    return Function(uFit, m_uPoints.front(), m_uPoints.back());
  }

  // Standard error of the fitted spline: sqrt(b' Cov b), where b holds the
  // iK non-zero basis functions at x and Cov is restricted to the matching
  // iK x iK block of the covariance matrix.
  Function err() const override
  {
    std::function<double(double)> uErr =
      [uCov = m_uCov, uB = m_uBSpline, iK = m_iOrder](double dX) {
        PRECONDITION(uCov->size1 == uCov->size2);
        std::vector<double> uF(iK);
        gsl_vector_view uFView = gsl_vector_view_array(uF.data(), uF.size());
        size_t iStart = 0, iEnd = 0;
        gsl_bspline_eval_nonzero(dX, &uFView.vector, &iStart, &iEnd, uB.get());
        ASSERT(iStart + iK == iEnd + 1);
        // dsymv overwrites uF, so the input vector is a separate copy.
        std::vector<double> uG(uF);
        gsl_vector_view uGView = gsl_vector_view_array(uG.data(), uG.size());
        gsl_matrix_view uCovView = gsl_matrix_submatrix(uCov.get(), iStart, iStart, iK, iK);
        gsl_blas_dsymv(CblasUpper, 1., &uCovView.matrix, &uGView.vector, 0., &uFView.vector);
        double dVar = 0.;
        gsl_blas_ddot(&uFView.vector, &uGView.vector, &dVar);
        // b' Cov b >= 0 in exact arithmetic; clamp tiny negative rounding
        // errors so that sqrt never returns NaN.
        return std::sqrt(std::max(dVar, 0.));
      };
    return Function(uErr, m_uPoints.front(), m_uPoints.back());
  }

  FitParam param() const override
  {
    return fitParam(m_uC.get(), m_uCov.get(), m_dChi2);
  }

private:
  unsigned m_iOrder, m_iF;
  std::vector<double> m_uPoints;
  std::shared_ptr<gsl_bspline_workspace> m_uBSpline;
  std::shared_ptr<gsl_matrix> m_uCov;
  std::shared_ptr<gsl_vector> m_uC;
  // Zero until a fit is performed, so that param() on an unfitted object
  // does not read an uninitialized value.
  double m_dChi2 = 0.;
};

//bspline

// Unfitted B-spline fit of order iOrder with breakpoints rPoints.
cfl::Fit cfl::NFit::bspline(unsigned iOrder, const std::vector<double> &rPoints)
{
  IFit *pSpline = new BSpline(iOrder, rPoints);
  return Fit(pSpline);
}

// Unfitted B-spline fit of order iOrder on iPoints uniformly spaced
// breakpoints from dL to dR (both endpoints included exactly).
cfl::Fit
cfl::NFit::bspline(unsigned iOrder, double dL, double dR, unsigned iPoints)
{
  PRECONDITION(dL < dR);
  PRECONDITION(iPoints > 1);
  double dS = (dR - dL) / (iPoints - 1.);
  std::vector<double> uPoints(iPoints);
  // Each breakpoint is computed directly from its index. The previous running
  // sum via std::transform over overlapping ranges accumulated rounding error
  // and relied on an application order that std::transform does not specify.
  for (unsigned i = 0; i + 1 < iPoints; i++)
    {
      uPoints[i] = dL + i * dS;
    }
  uPoints.back() = dR;
  POSTCONDITION(std::equal(uPoints.begin(), uPoints.end() - 1,
                           uPoints.begin() + 1, std::less<double>()));
  POSTCONDITION(uPoints.front() == dL);
  POSTCONDITION(uPoints.back() == dR);

  return Fit(new BSpline(iOrder, uPoints));
}
