#include <cmath>

#include "test/Main.hpp"
#include "test/Data.hpp"
#include "test/Print.hpp"
#include "test/Black.hpp"
#include "Session2/Output.hpp"
#include "Session2/Session2.hpp"

using namespace test;
using namespace cfl;
using namespace std;
using namespace test::Data;

// INTERPOLATION OF DATA CURVES

/// Test of prb::forwardLogLinInterp.
///
/// Takes the forward prices returned by test::Data::getForward for a
/// spot of 100 and initial time 1, builds the forward curve with
/// prb::forwardLogLinInterp and prints it from the initial time up to
/// the last delivery time of the data.
void forwardLogLinInterp()
{
  test::print("LOG LINEAR INTERPOLATION OF FORWARD CURVE");

  double dSpot = 100;
  double dInitialTime = 1.;

  auto uF = test::Data::getForward(dSpot, dInitialTime);

  Function uForward =
    prb::forwardLogLinInterp(dSpot, uF.first, uF.second,
			     dInitialTime);

  // Show the curve over the whole range of delivery times (as in
  // forwardBlackFit); front() would cover only the part before the
  // first data point.
  double dInterval = uF.first.back() - dInitialTime;
  test::Data::print(uForward, dInitialTime, dInterval);
}

/// Test of prb::discountYieldInterp with the Akima interpolation method.
///
/// Takes the discount factors returned by test::Data::getDiscount for
/// initial time 1. It estimates the initial short-term rate from the first
/// discount factor and builds the discount curve with
/// prb::discountYieldInterp. The curve is then printed from the initial
/// time up to the last maturity of the data.
void discountYieldInterp()
{
  test::print("DISCOUNT CURVE BY INTERPOLATION OF YIELDS");

  double dInitialTime = 1.;

  auto uDF = test::Data::getDiscount(dInitialTime);

  // Initial short rate estimated as the continuously compounded yield of
  // the first discount factor, -ln(D1) / (t1 - t0). This keeps it on the
  // same convention as the interpolated yields. The simple rate
  // (1/D1 - 1)/(t1 - t0) would mix compounding conventions.
  // NOTE(review): assumes prb::discountYieldInterp uses continuously
  // compounded yields (the cfl convention) — confirm.
  double dR = -std::log(uDF.second.front()) / (uDF.first.front() - dInitialTime);
  test::print(dR, "initial short-term rate", true);

  Function uDiscount =
    prb::discountYieldInterp(uDF.first, uDF.second, dR, dInitialTime, cfl::NInterp::akima());

  // Show the curve over the whole range of maturities, not just up to the
  // first data point.
  double dInterval = uDF.first.back() - dInitialTime;
  print("interpolation with Akima method:");
  test::Data::print(uDiscount, dInitialTime, dInterval);
}

// LEAST-SQUARES FITTING OF DATA CURVES

// Captions shared by the least-squares fitting tests.
// Heading printed above the fitted forward prices and their errors.
const std::string c_sForward = "Fitted forward prices and their errors:";
// Note printed when the cost-of-carry rate is fitted as a single constant.
const std::string c_sConstCarry = "We fit with constant cost-of-carry rate.";

/// Test of prb::forwardCarryFit with a constant cost-of-carry rate.
///
/// Takes the forward prices returned by test::Data::getForward for a
/// spot of 100 and initial time 1. It fits them with the linear fit whose
/// only basis function is the constant 1, then prints:
/// - the fitted parameters;
/// - the fitted forwards with their errors, from the initial time up to
///   the last delivery time of the data.
void forwardCarryFit()
{
  test::print("FORWARD CURVE BY LEAST-SQUARES FIT OF COST-OF-CARRY RATES");

  double dSpot = 100.;
  double dInitialTime = 1.;

  auto uF = test::Data::getForward(dSpot, dInitialTime);

  print(c_sConstCarry);
  // A single constant basis function, i.e. a constant cost-of-carry rate.
  Fit uFit = NFit::linear(Function(1.));

  // Filled by prb::forwardCarryFit with the fit errors.
  Function uErr;
  Function uForward =
      prb::forwardCarryFit(dSpot, uF.first, uF.second, dInitialTime, uFit, uErr);

  // Show the fit over the whole range of delivery times (as in
  // forwardBlackFit), not just up to the first data point.
  double dInterval = uF.first.back() - dInitialTime;
  test::Data::printFit(uFit.param());
  test::Data::print(c_sForward, uForward, uErr, dInitialTime, dInterval);
}

void forwardBlackFit()
{
  test::print("LEAST-SQUARES FIT OF FORWARD CURVE IN BLACK MODEL");

  double dSpot = test::Black::c_dSpot;
  double dLambda = test::Black::c_dLambda;
  double dSigma = test::Black::c_dSigma;
  double dInitialTime = test::Black::c_dInitialTime;

  print(dLambda, "lambda");
  print(dSigma, "sigma");
  auto uF = test::Data::getForward(dSpot, dInitialTime);

  Function uErr;
  FitParam uParam;
  Function uForward =
      prb::forwardBlackFit(dSpot, uF.first, uF.second, dLambda, dSigma, dInitialTime, uErr, uParam);

  double dInterval = uF.first.back() - dInitialTime;
  test::Data::printFit(uParam);
  test::Data::print(c_sForward, uForward, uErr, dInitialTime, dInterval);
}

/// Returns the Session 2 test driver. When called, it prints the
/// interpolation heading and runs the two interpolation tests, then prints
/// the fitting heading and runs the two least-squares fitting tests.
std::function<void()> test_Session2()
{
  auto fRunAll = []() {
    print("INTERPOLATION OF DATA CURVES");

    forwardLogLinInterp();
    discountYieldInterp();

    print("LEAST-SQUARES FITTING OF DATA CURVES");

    forwardCarryFit();
    forwardBlackFit();
  };
  return fRunAll;
}

/// Program entry point. It hands the Session 2 test driver to `project`,
/// together with PROJECT_NAME (presumably defined by the build; used for
/// both name arguments) and the title "Session 2".
int main()
{
  project(test_Session2(), PROJECT_NAME, PROJECT_NAME, "Session 2");
  return 0;
}
