#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
#include "cfl/HullWhiteModel.hpp"
#include "cfl/Error.hpp"
#include "cfl/Similar.hpp"
#include "cfl/Data.hpp"

using namespace cfl::HullWhite;
using namespace cfl;

//class HullWhite::Data
/**
 * Creates Hull-White model data from explicitly supplied functions.
 *
 * @param rDiscount Initial discount curve, evaluated at maturities.
 * @param rVolatility Volatility function of the model.
 * @param rShape Shape function of the model.
 * @param dInitialTime Initial time of the model.
 */
cfl::HullWhite::Data::Data(const cfl::Function &rDiscount,
                           const Function &rVolatility,
                           const Function &rShape,
                           double dInitialTime)
    : m_uDiscount(rDiscount),
      m_uVol(rVolatility),
      m_uShape(rShape),
      m_dInitialTime(dInitialTime)
{
  // Every member is set in the initializer list; nothing else to do.
}

/**
 * Creates Hull-White model data from constant parameters.
 *
 * The volatility and shape functions are produced by the library helpers
 * cfl::Data::volatility(dSigma, dLambda, dInitialTime) and
 * cfl::Data::bondShape(dLambda, dInitialTime).
 *
 * @param rDiscount Initial discount curve, evaluated at maturities.
 * @param dSigma Volatility parameter.
 * @param dLambda Mean-reversion parameter shared by both helpers.
 * @param dInitialTime Initial time of the model.
 */
cfl::HullWhite::Data::Data(const cfl::Function &rDiscount, double dSigma,
                           double dLambda, double dInitialTime)
    : m_uDiscount(rDiscount),
      m_uVol(cfl::Data::volatility(dSigma, dLambda, dInitialTime)),
      m_uShape(cfl::Data::bondShape(dLambda, dInitialTime)),
      m_dInitialTime(dInitialTime)
{
  // Every member is set in the initializer list; nothing else to do.
}

// Implementation details of the Hull-White interest rate model
namespace cflHullWhite
{
  Slice discount(unsigned iTime, double dMaturity, const cfl::HullWhite::Data &rData,
                 const IModel &rModel)
  {
    PRECONDITION(iTime < rModel.eventTimes().size());
    double dRefTime = rModel.eventTimes()[iTime];
    PRECONDITION(dMaturity >= rModel.eventTimes()[iTime]);

    double dA = rData.shape()(dRefTime);
    double dB = rData.shape()(dMaturity);
    double dC = rData.shape()(rModel.eventTimes().back());
    double dVar = std::pow(rData.volatility()(dRefTime), 2) *
                  (dRefTime - rData.initialTime());
    double dForwardDiscount =
        rData.discount()(dMaturity) / rData.discount()(dRefTime);
    Slice uDiscount = exp(rModel.state(iTime, 0) * (dB - dA));
    uDiscount *= (dForwardDiscount * std::exp(-0.5 * (dB - dA) * (dA + dB - 2. * dC) * dVar));
    return uDiscount;
  }

  class Rollback : public ISimilarRollback
  {
  public:
    Rollback(const IModel &rBModel, const HullWhite::Data &rData)
        : m_rBModel(rBModel), m_rData(rData) {}

    void rollback(Slice &rSlice, unsigned iTime) const
    {
      PRECONDITION(rSlice.timeIndex() >= iTime);
      PRECONDITION(&rSlice.model() == &m_rBModel);

      double dMaturity = model().eventTimes().back();
      rSlice /= discount(rSlice.timeIndex(), dMaturity, m_rData, m_rBModel);
      rSlice.rollback(iTime);
      rSlice *= discount(iTime, dMaturity, m_rData, m_rBModel);
    }

    const IModel &model() const
    {
      return m_rBModel;
    }

  private:
    const IModel &m_rBModel;
    const HullWhite::Data &m_rData;
  };

  class Model : public IInterestRateModel
  {
  public:
    Model(const HullWhite::Data &rData, const std::vector<double> &rEventTimes,
          double dInterval, const Brownian &rBrownian)
        : m_uData(rData), m_dInterval(dInterval), m_uBrownian(rBrownian)
    {
      PRECONDITION(rEventTimes.front() == rData.initialTime());
      std::vector<double> uVar(rEventTimes.size());
      std::transform(rEventTimes.begin(), rEventTimes.end(), uVar.begin(),
                     [&rData](double dTime) {
                       return std::pow(rData.volatility()(dTime), 2);
                     });
      m_uBrownian.assign(uVar, rEventTimes, dInterval);
      m_uModel.assign(new Rollback(m_uBrownian.model(), m_uData));
    }

    IInterestRateModel *newModel(const std::vector<double> &rEventTimes) const
    {
      return new Model(m_uData, rEventTimes, m_dInterval, m_uBrownian);
    }

    const IModel &model() const
    {
      return m_uModel.model();
    }

    Slice discount(unsigned iTime, double dMaturity) const
    {
      double dTime = model().eventTimes()[iTime];
      ASSERT(dTime <= dMaturity);
      if (dTime == dMaturity)
      {
        return Slice(&model(), iTime, 1.);
      }
      return cflHullWhite::discount(iTime, dMaturity, m_uData, model());
    }

  private:
    HullWhite::Data m_uData;
    double m_dInterval;
    Brownian m_uBrownian;
    Similar m_uModel;
  };
} // namespace cflHullWhite

// function cfl::HullWhite::model
/**
 * Creates the Hull-White interest rate model.
 *
 * The returned model starts with a single event time, rData.initialTime().
 *
 * @param rData Parameters of the Hull-White model.
 * @param dInterval Interval parameter forwarded to the Brownian model.
 * @param rBrownian Brownian model used for the state process.
 * @return The Hull-White interest rate model.
 */
InterestRateModel cfl::HullWhite::model(const HullWhite::Data &rData,
                                        double dInterval, const Brownian &rBrownian)
{
  // the initial time is the only event time of the freshly built model
  std::vector<double> uInitialEventTimes{rData.initialTime()};
  return InterestRateModel(
      new cflHullWhite::Model(rData, uInitialEventTimes, dInterval, rBrownian));
}
