{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "import re\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from random import choices\n",
    "import tqdm\n",
    "%matplotlib qt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-state hidden Markov model used for testing.\n",
    "# P[i, j] = Pr(X_{t+1} = j | X_t = i): a sticky chain that stays put with probability STAY_PROB.\n",
    "# M[i, k] = Pr(Y_t = k | X_t = i): each state emits its own label with probability 4/5.\n",
    "N = 2\n",
    "STAY_PROB = 0.95\n",
    "# The remaining mass is split evenly over the other N - 1 states so every row sums to 1.\n",
    "P = np.full((N, N), (1 - STAY_PROB) / (N - 1))\n",
    "np.fill_diagonal(P, STAY_PROB)\n",
    "M = np.array([[4/5, 1/5], [1/5, 4/5]])\n",
    "\n",
    "# Both matrices must be row-stochastic and agree on the number of hidden states.\n",
    "assert np.allclose(P.sum(axis=1), 1) and np.allclose(M.sum(axis=1), 1)\n",
    "assert M.shape[0] == N\n",
    "\n",
    "# One-hot basis vectors for hidden states and observations.\n",
    "y_basis = np.eye(M.shape[1])\n",
    "x_basis = np.eye(N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate a path (X, Y) of the hidden Markov model for testing:\n",
    "# X[t] is the hidden state and Y[t] the observation emitted from X[t].\n",
    "\n",
    "T = 1000\n",
    "SEED = 0  # fixed seed so the simulated path, and every plot below, is reproducible\n",
    "rng = np.random.default_rng(SEED)\n",
    "mu_0 = [1/2, 1/2]\n",
    "\n",
    "# rng.choice needs probabilities that sum to 1, so normalise the rows defensively.\n",
    "trans = P / P.sum(axis=1, keepdims=True)\n",
    "emit = M / M.sum(axis=1, keepdims=True)\n",
    "n_obs = M.shape[1]\n",
    "\n",
    "X = np.empty(T, dtype=int)\n",
    "Y = np.empty(T, dtype=int)\n",
    "# Y[0] is drawn from the emission model like every other observation;\n",
    "# note that the filters below start from mu_0 and do not condition on Y[0].\n",
    "X[0] = rng.choice(N, p=mu_0)\n",
    "Y[0] = rng.choice(n_obs, p=emit[X[0]])\n",
    "for t in range(1, T):\n",
    "    X[t] = rng.choice(N, p=trans[X[t - 1]])\n",
    "    Y[t] = rng.choice(n_obs, p=emit[X[t]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use an explicit figure: bare plt.plot draws onto the current figure, so successive\n",
    "# plotting cells could pile onto the same axes.\n",
    "# Y is squeezed into [0.05, 0.95] so its trace does not hide X at 0 and 1.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(Y * 0.9 + 0.05, label='observation Y (rescaled)')\n",
    "ax.plot(X, label='hidden state X')\n",
    "ax.set(title='Simulated hidden state and observations', xlabel='t', ylabel='state')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Wonham_filter(Y, mu_0, M, P):\n",
    "    \"\"\"Filter an HMM: estimate the distribution of X_t given the observations so far.\n",
    "\n",
    "    This is the discrete-time analogue of the Wonham filter: each step predicts one\n",
    "    step ahead with P and then reweights by the likelihood of the new observation.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    Y : sequence of int\n",
    "        Observation indices; Y[t] selects a column of M. Y[0] is not used:\n",
    "        row 0 of the result is mu_0 itself.\n",
    "    mu_0 : sequence of float\n",
    "        Distribution of the hidden state at time 0.\n",
    "    M : ndarray, shape (N, K)\n",
    "        Emission matrix, M[i, k] = Pr(Y_t = k | X_t = i).\n",
    "    P : ndarray, shape (N, N)\n",
    "        Transition matrix, P[i, j] = Pr(X_{t+1} = j | X_t = i).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray, shape (len(Y), N)\n",
    "        Row 0 is mu_0; row t >= 1 is the normalised filtered distribution of X_t\n",
    "        given Y[1], ..., Y[t]. An empty Y gives an empty (0, N) array.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If an observation has zero probability under the model (the\n",
    "        normalisation would otherwise divide by zero and fill the rest with NaN).\n",
    "    \"\"\"\n",
    "    mu_0 = np.asarray(mu_0, dtype=float)\n",
    "    mu = np.zeros((len(Y), mu_0.size))\n",
    "    if len(Y) == 0:\n",
    "        return mu\n",
    "    mu[0] = mu_0\n",
    "\n",
    "    for t in range(1, len(Y)):\n",
    "        # Predict with P, then weight each state by the likelihood of Y[t].\n",
    "        # M[:, Y[t]] is the likelihood column; no global one-hot basis is needed.\n",
    "        unnormalised = (mu[t - 1] @ P) * M[:, Y[t]]\n",
    "        total = unnormalised.sum()\n",
    "        if total == 0:\n",
    "            raise ValueError(f'Observation Y[{t}] = {Y[t]} has zero probability under the model')\n",
    "        # Normalising at every step keeps the recursion from underflowing.\n",
    "        mu[t] = unnormalised / total\n",
    "\n",
    "    return mu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mu = Wonham_filter(Y, mu_0, M, P)\n",
    "\n",
    "# Compare the filtered probability of state 1 with the true hidden state.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(mu[:, 1], label='filtered Pr(X[t] = 1)')\n",
    "ax.plot(X, label='hidden state X')\n",
    "ax.set(title='Filter estimate vs. true hidden state', xlabel='t', ylabel='probability / state')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Viterbi_estimator(Y, mu_0, M, P):\n",
    "    \"\"\"Most likely hidden path given the observations Y (Viterbi algorithm).\n",
    "\n",
    "    Stores one back-pointer per (time, state) instead of copying whole candidate\n",
    "    paths at every step, so time and memory are linear in len(Y).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    Y : sequence of int\n",
    "        Observation indices; Y[t] selects a column of M. Y[0] is not used:\n",
    "        the time-0 score of each state is mu_0 itself.\n",
    "    mu_0 : sequence of float\n",
    "        Distribution of the hidden state at time 0.\n",
    "    M : ndarray, shape (N, K)\n",
    "        Emission matrix, M[i, k] = Pr(Y_t = k | X_t = i).\n",
    "    P : ndarray, shape (N, N)\n",
    "        Transition matrix, P[i, j] = Pr(X_{t+1} = j | X_t = i).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray of int, shape (len(Y),)\n",
    "        The maximising state sequence; ties go to the lowest state index.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If every path has zero probability at some step.\n",
    "    \"\"\"\n",
    "    n_steps = len(Y)\n",
    "    if n_steps == 0:\n",
    "        return np.zeros(0, dtype=int)\n",
    "    n_states = P.shape[0]\n",
    "    # Float dtype matters: with an integer mu_0 such as [1, 0] the scores would be truncated to 0.\n",
    "    pi = np.asarray(mu_0, dtype=float)\n",
    "    # back[t, x] = best predecessor of state x at time t (row 0 is unused).\n",
    "    back = np.zeros((n_steps, n_states), dtype=int)\n",
    "\n",
    "    for t in range(1, n_steps):\n",
    "        # scores[i, x] = pi[i] * P[i, x] * M[x, Y[t]]: extend the best path ending in i by x.\n",
    "        scores = (P * M[:, Y[t]]) * pi[:, None]\n",
    "        back[t] = np.argmax(scores, axis=0)\n",
    "        pi = scores.max(axis=0)\n",
    "        total = pi.sum()\n",
    "        if total == 0:\n",
    "            raise ValueError(f'No path has positive probability at time {t}')\n",
    "        # Renormalise for numerical stability; this does not change any argmax.\n",
    "        pi = pi / total\n",
    "\n",
    "    # Backtrack from the most likely final state.\n",
    "    path = np.zeros(n_steps, dtype=int)\n",
    "    path[-1] = np.argmax(pi)\n",
    "    for t in range(n_steps - 1, 0, -1):\n",
    "        path[t - 1] = back[t, path[t]]\n",
    "    return path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_x = Viterbi_estimator(Y, mu_0, M, P)\n",
    "agreement = np.mean(fit_x == X)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(X, label='hidden state X')\n",
    "# Rescaled into [0.05, 0.95] so the estimate stays visible where it matches X.\n",
    "ax.plot(fit_x * 0.9 + 0.05, label='Viterbi path (rescaled)')\n",
    "ax.set(title=f'Viterbi path vs. true hidden state ({agreement:.1%} agreement)', xlabel='t', ylabel='state')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
