{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import numpy as np\n",
    "from scipy.special import xlogy\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy import log2\n",
    "import numpy as np\n",
    "import bitstring as bt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv('single_counts.csv', index_col=0)['Count']\n",
    "data_dict = {(x,): data[x] for x in data.index}              \n",
    " #Efficiency suggests removing codewords with zero counts, but this may make some messages impossible to encode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shannon_code(data, eps=1e-6):\n",
    "    # Compute the Shannon code using the count data\n",
    "    # Introduces a small tolerance eps in order to correct for zero probability entries\n",
    "    \n",
    "    probs= data.sort_values(ascending=False)+eps    #Calculate the probabilities, with a small tolerance to account for zero entries\n",
    "    probs = probs/probs.sum()\n",
    "\n",
    "    cumulative_probs=[0,*probs.cumsum()[:-1]]                      #Calculate the corresponding real values\n",
    "    lengths = [int(np.ceil(-log2(x))) for x in probs]       #Calculate the lengths of codewords\n",
    "\n",
    "    binary_cum_probs = [bt.Bits(float=1+x,length=32)[9:] for x in cumulative_probs]\n",
    "                                                            #Convert probabilities into binary representations\n",
    "    codewords = [binary_cum_probs[i][:lengths[i]] for i in range(len(lengths))]\n",
    "                                                            #Calculate the appropiate codeword\n",
    "\n",
    "    return {probs.index[i]:codewords[i] for i in range(len(lengths))} #Return the dictionary of codewords\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "shannon_codewords = shannon_code(data)\n",
    "print({x:shannon_codewords[x].bin for x in shannon_codewords})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Implements the Huffman algorithm to construct an optimal code\n",
    "#Not the most efficient implementation (due to the memory management of dicts)\n",
    "\n",
    "def huffman(data_dict):\n",
    "    working_dict = data_dict.copy()  # dict of tree segments that have already been processed\n",
    "    codewords = {x[0]:bt.Bits(bin='') for x in data_dict.keys()} # dict of codewords for each input symbol\n",
    "\n",
    "    while(len(working_dict)>1):\n",
    "        lowest_ind= min(working_dict, key=working_dict.get)  # find entry with smallest counts\n",
    "        lowest_count = working_dict.pop(lowest_ind)          # remove entry from dict\n",
    "        second_ind = min(working_dict, key=working_dict.get) # find entry with (second) smallest counts\n",
    "        second_count = working_dict.pop(second_ind)          # remove entry from dict\n",
    "\n",
    "        merge_ind = lowest_ind + second_ind                  # build concatenated codeword\n",
    "        merge_count = lowest_count+ second_count             # compute concatented probability\n",
    "        working_dict[merge_ind] = merge_count                # insert new codeword into dict\n",
    "\n",
    "        # Add prefixes to the codewords corresponding to the newly combined tree segments\n",
    "        for x in lowest_ind:\n",
    "            codewords[x] = bt.Bits(bin='1')+codewords[x] #Notice we prefix - this gives an instantaneous code\n",
    "        for x in second_ind:\n",
    "            codewords[x] = bt.Bits(bin='0')+codewords[x]\n",
    "    return(codewords)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "huffman_codewords = huffman(data_dict)\n",
    "print({x:huffman_codewords[x].bin for x in huffman_codewords})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "average_length = sum([data[x]*len(shannon_codewords[x].bin) for x in shannon_codewords])/sum(data)\n",
    "print('Shannon: ' + str(average_length))\n",
    "average_length = sum([data[x]*len(huffman_codewords[x].bin) for x in huffman_codewords])/sum(data)\n",
    "print('Huffman: ' + str(average_length))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encoder(message, codewords):\n",
    "    #Applies a fixed-to-variable dictionary code, assuming it has fixed length blocks\n",
    "    #Pads with space if necessary\n",
    "    #Inefficient use of memory, but easy to follow\n",
    "    #returns both coded message and message, to see if padded\n",
    "    \n",
    "    output = bt.Bits(bin='')                    #Initialize output string\n",
    "    block_len= len(list(codewords.keys())[0])   #Compute block length\n",
    "    initial_message = message+\"\"                #Store a copy of input string\n",
    "    while(len(message)>=block_len):             #Iterate until input string is processed\n",
    "        block = message[:block_len]             #Compute the current block\n",
    "        output = output + codewords[block]      #Add next block to output\n",
    "        message = message[block_len:]           #Drop the block that has been processed\n",
    "    if message == '':                       #If all message has been processed, send to output\n",
    "        return (output, initial_message)\n",
    "    else:                                   #Otherwise, pad with enough spaces to give a final block\n",
    "        return (output + codewords[message+'_'*(block_len-len(message))], initial_message+'_'*(block_len-len(message))) \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message = 'THE_RAIN_IN_SPAIN'\n",
    "huffman_coded_message, _ = encoder(message, huffman_codewords)\n",
    "shannon_coded_message, _ = encoder(message, shannon_codewords)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "huffman_coded_message.bin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#ASCII coding\n",
    "''.join(format(ord(x), 'b') for x in message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decoder(coded_message, codewords):\n",
    "    #Decodes using a variable-to-fixed code\n",
    "\n",
    "    decode_dict = {codewords[x]:x for x in codewords}   #Construct the reversed dictionary of codewords\n",
    "    output = \"\"                                         #Initialize output\n",
    "    max_codeword_len = max(len(codewords[x]) for x in codewords) #Find length of longest codeword (for error checking)\n",
    "    while(len(coded_message)>0):                        #Continue until all message processed\n",
    "        code_len = 1\n",
    "        while(code_len>0):                              #Iterate through the string until we find a codeword\n",
    "            assert(code_len <= max_codeword_len)        #Throw an error if code_len is too long...\n",
    "            test_word= coded_message[:code_len]\n",
    "            if test_word in decode_dict:                #If we've found a codeword\n",
    "                output = output+decode_dict[test_word]  #Decode this codeword\n",
    "                coded_message = coded_message[code_len:]    #Drop this codeword and continue\n",
    "                code_len = 0\n",
    "            else:\n",
    "                code_len = code_len+1                   #If we've not found a codeword, look one term longer\n",
    "\n",
    "    return(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder(shannon_coded_message, shannon_codewords)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder(huffman_coded_message, huffman_codewords)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_location = 1\n",
    "error_message = list(huffman_coded_message.bin)\n",
    "error_message[error_location] = str((int(error_message[error_location])+1)%2)\n",
    "error_message = bt.Bits(bin=''.join(error_message))\n",
    "decoder(error_message, huffman_codewords)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
