#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
#include "cfl/BlackModel.hpp"
#include "cfl/Similar.hpp"
#include "cfl/Error.hpp"
#include "cfl/Data.hpp"

using namespace cfl::Black;
using namespace cfl;

// class Black::Data
/// Data with a volatility curve and a constant (unit) shape.
/// @param rDiscount discount curve
/// @param rForward forward curve
/// @param rVolatility volatility curve
/// @param dInitialTime initial time of the model
/// Forwards to the general constructor with the shape fixed to 1.
cfl::Black::Data::Data(const Function &rDiscount,
                       const Function &rForward,
                       const Function &rVolatility,
                       double dInitialTime)
    : Data(rDiscount, rForward, rVolatility, Function(1.), dInitialTime)
{
}

/// Data with a constant volatility and a constant (unit) shape.
/// @param rDiscount discount curve
/// @param rForward forward curve
/// @param dSigma constant volatility, turned into a constant Function
/// @param dInitialTime initial time of the model
/// Forwards to the general constructor with the shape fixed to 1.
cfl::Black::Data::Data(const Function &rDiscount,
                       const Function &rForward,
                       double dSigma,
                       double dInitialTime)
    : Data(rDiscount, rForward, Function(dSigma), Function(1.), dInitialTime)
{
}

cfl::Black::Data::Data(const Function &rDiscount,
                       const Function &rForward,
                       const Function &rVolatility,
                       const Function &rShape,
                       double dInitialTime)
    : m_uDiscount(rDiscount),
      m_uForward(rForward),
      m_uVol(rVolatility),
      m_uShape(rShape),
      m_dInitialTime(dInitialTime) {}

/// Data parametrized by a volatility level and a mean-reversion-style rate.
/// @param rDiscount discount curve
/// @param rForward forward curve
/// @param dSigma volatility parameter passed to cfl::Data::volatility
/// @param dLambda parameter passed to cfl::Data::volatility and cfl::Data::assetShape
/// @param dInitialTime initial time of the model
/// The volatility and shape curves are built by the cfl::Data helpers and
/// handed to the general constructor.
cfl::Black::Data::Data(const Function &rDiscount,
                       const Function &rForward,
                       double dSigma,
                       double dLambda,
                       double dInitialTime)
    : Data(rDiscount, rForward,
           cfl::Data::volatility(dSigma, dLambda, dInitialTime),
           cfl::Data::assetShape(dLambda, dInitialTime),
           dInitialTime)
{
}

// construction of Black model
namespace cflBlack
{
  class Rollback : public ISimilarRollback
  {
  public:
    Rollback(const IModel &rBModel, const Function &rDiscount)
        : m_rBModel(rBModel), m_rDiscount(rDiscount) {}

    void rollback(Slice &rSlice, unsigned iTime) const
    {
      PRECONDITION(rSlice.timeIndex() >= iTime);
      PRECONDITION(&rSlice.model() == &m_rBModel);

      double dMaturity = m_rBModel.eventTimes()[rSlice.timeIndex()];
      double dToday = m_rBModel.eventTimes()[iTime];
      double dFactor = m_rDiscount(dMaturity) / m_rDiscount(dToday);
      rSlice.rollback(iTime);
      rSlice *= dFactor;
    }

    const IModel &model() const
    {
      return m_rBModel;
    }

  private:
    const IModel &m_rBModel;
    const Function &m_rDiscount;
  };

  class Model : public IAssetModel
  {
  public:
    Model(const Black::Data &rData, const std::vector<double> &rEventTimes,
          double dInterval, const Brownian &rBrownian)
        : m_uData(rData), m_dInterval(dInterval), m_uBrownian(rBrownian)
    {
      ASSERT(rEventTimes.front() == rData.initialTime());

      std::vector<double> uVar(rEventTimes.size());
      std::transform(rEventTimes.begin(), rEventTimes.end(), uVar.begin(),
                     [&rData](double dTime) {
                       return std::pow(rData.volatility()(dTime), 2);
                     });
      m_uBrownian.assign(uVar, rEventTimes, dInterval);
      m_uModel.assign(new Rollback(m_uBrownian.model(), m_uData.discount()));
    }

    IAssetModel *newModel(const std::vector<double> &rEventTimes) const
    {
      return new Model(m_uData, rEventTimes, m_dInterval, m_uBrownian);
    }

    const IModel &model() const
    {
      return m_uModel.model();
    }

    Slice discount(unsigned iTime, double dMaturity) const
    {
      double dTime = model().eventTimes()[iTime];
      double dFactor = m_uData.discount()(dMaturity) / m_uData.discount()(dTime);
      return Slice(&model(), iTime, dFactor);
    }

    Slice forward(unsigned iTime, double dForwardMaturity) const
    {
      PRECONDITION(iTime < model().eventTimes().size());
      double dRefTime = model().eventTimes()[iTime];
      PRECONDITION(dForwardMaturity >= dRefTime);

      //forward price = exp(shape * state + c);
      double dForward = m_uData.forward()(dForwardMaturity);
      double dVol = m_uData.volatility()(dRefTime);
      double dShape = m_uData.shape()(dForwardMaturity);
      double dC = std::log(dForward) - 0.5 * std::pow(dVol * dShape, 2) * (dRefTime - m_uData.initialTime());
      Slice uState = model().state(iTime, 0);
      return exp(uState * dShape + dC);
    }

  private:
    Black::Data m_uData;
    double m_dInterval;
    Brownian m_uBrownian;
    Similar m_uModel;
  };
} // namespace cflBlack

// function cfl::Black::model
/// Builds the Black asset model.
/// @param rData parameters of the model.
/// @param dInterval interval parameter forwarded to the state model.
/// @param rBrownian prototype of the state model.
/// @return an AssetModel whose only event time is rData.initialTime().
AssetModel cfl::Black::model(const Data &rData, double dInterval,
                             const Brownian &rBrownian)
{
  // The freshly built model starts with a single event time: the initial time.
  std::vector<double> uInitialTimes{rData.initialTime()};
  IAssetModel *pModel =
      new cflBlack::Model(rData, uInitialTimes, dInterval, rBrownian);
  return AssetModel(pModel);
}
