{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import numpy as np\n",
    "from scipy.special import xlogy\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy import log2\n",
    "import bitstring as bt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the symbol counts; the first CSV column becomes the index (presumably the symbol)\n",
    "data = pd.read_csv('single_counts.csv', index_col=0).loc[:, 'Count']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise counts to probabilities, in ascending order (the BAC functions also sort internally)\n",
    "probs = data.div(sum(data)).sort_values()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of bits in each fixed-length output block; the codebook has 2**output_len output values.\n",
    "output_len = 12 # As we use 64 bit integers, this needs to be at most 62"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def _BAC_splitter(K, probs, probs_sorted = True):\n",
    "    #Uses Boncelet's 1993 heuristic rule for splitting K messages into subsets using the probabilities\n",
    "    if not probs_sorted:\n",
    "        probs = probs.sort_values()\n",
    "\n",
    "    m = len(probs)\n",
    "    L = int(np.floor((K-1)/(m-1)))\n",
    "    tL = L\n",
    "    q=1.\n",
    "    #Construct Li series, to preserve labelling in pandas\n",
    "    Li= pd.Series(index = probs.index, dtype = np.int64, data = np.ones(len(probs))) \n",
    "\n",
    "    for i in range(len(probs)):\n",
    "        tp = probs[i]/q\n",
    "        Li[i] = np.int64(max(0, np.floor(tp*tL + (tp*(1-i)-1)/(m-1) +0.5))) #Slightly different to in lectures, to account for indices starting at 0\n",
    "        tL = tL - Li[i]\n",
    "        q = q-probs[i]\n",
    "    return 1+Li*(m-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def BAC_codebook_builder(output_len, probs):\n",
    "    # Builds the full BAC codebook using Boncelet's heuristic (see _BAC_splitter).\n",
    "    #\n",
    "    # Parameters:\n",
    "    #   output_len -- number of bits in each output codeword\n",
    "    #   probs      -- pandas Series of symbol probabilities, indexed by symbol (sorted here)\n",
    "    # Returns:\n",
    "    #   dict mapping each input block (a leaf of the BAC tree) to a bitstring.Bits of length output_len\n",
    "    #\n",
    "    # Each leaf is numbered by adding up the offsets of the branches taken to reach it, the same\n",
    "    # numbering BAC_encoder_fly uses, so the two encoders agree. This numbering is deterministic;\n",
    "    # it does not depend on set iteration order, which varies between runs (string hash randomisation).\n",
    "\n",
    "    K = 2**output_len\n",
    "    probs = probs.sort_values()\n",
    "\n",
    "    codebook = {}\n",
    "    # Tree nodes still to split: (input prefix, outputs available, first output number)\n",
    "    pending = [('', K, 0)]\n",
    "\n",
    "    while pending:\n",
    "        prefix, block_K, offset = pending.pop()\n",
    "        split = _BAC_splitter(block_K, probs)\n",
    "        # Outputs taken by the symbols that come before each symbol\n",
    "        starts = split.cumsum().shift(fill_value=0)\n",
    "        for symbol in split.index:\n",
    "            child_K = int(split[symbol])\n",
    "            child_offset = offset + int(starts[symbol])\n",
    "            if child_K == 1:\n",
    "                # A single output left: prefix + symbol is a complete input block\n",
    "                codebook[prefix + symbol] = bt.Bits(uint=child_offset, length=output_len)\n",
    "            else:\n",
    "                pending.append((prefix + symbol, child_K, child_offset))\n",
    "\n",
    "    return codebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Codebook contains 4083 codewords, out of a maximum of 4096\n",
      "Suggests an inefficiency of 0.004586147774849891 bits per block of 12 bits\n"
     ]
    }
   ],
   "source": [
    "# Build the codebook and report how many of the 2**output_len output values it uses\n",
    "codewords = BAC_codebook_builder(output_len, probs)\n",
    "print('Codebook contains ', len(codewords), ' codewords, out of a maximum of ', 2**output_len, sep='')\n",
    "# Unused output values cost output_len - log2(number of codewords) bits per block\n",
    "print('Suggests an inefficiency of ', output_len - log2(len(codewords)), ' bits per block of ', output_len, ' bits', sep='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def BAC_encoder(message, codewords):\n",
    "    # Applies a BAC codebook (e.g. from BAC_codebook_builder) to message.\n",
    "    # At each position the shortest input block that is a key of codewords is replaced by its\n",
    "    # fixed-length output; a block cut short by the end of the message is padded with '_'.\n",
    "    #\n",
    "    # Parameters:\n",
    "    #   message   -- string to encode\n",
    "    #   codewords -- dict mapping input strings to bitstring.Bits outputs\n",
    "    # Returns:\n",
    "    #   bitstring.Bits, the concatenated outputs\n",
    "    # Raises:\n",
    "    #   ValueError if codewords is empty or no codeword matches at some position\n",
    "    #   (for example when message contains a symbol the codebook does not know)\n",
    "\n",
    "    if not codewords:\n",
    "        raise ValueError('codewords is empty')\n",
    "\n",
    "    # Input blocks have lengths between min_block_len and max_block_len\n",
    "    min_block_len = min(len(x) for x in codewords)\n",
    "    max_block_len = max(len(x) for x in codewords)\n",
    "\n",
    "    # Appending this many '_' always completes the final block\n",
    "    pad = '_'*(max_block_len)\n",
    "\n",
    "    outputs = []\n",
    "    pos = 0  # Index of the first unprocessed character\n",
    "    while pos < len(message):\n",
    "        # Longest candidate block at this position, padded if the message runs out\n",
    "        window = (message[pos:pos+max_block_len] + pad)[:max_block_len]\n",
    "\n",
    "        # Try increasing block lengths until one is a codeword\n",
    "        for block_len in range(min_block_len, max_block_len+1):\n",
    "            message_block = window[:block_len]\n",
    "            if message_block in codewords:\n",
    "                outputs.append(codewords[message_block])\n",
    "                pos += block_len\n",
    "                break\n",
    "        else:\n",
    "            # Raised explicitly: an assert would be skipped under python -O and the\n",
    "            # search would never terminate\n",
    "            raise ValueError('No codeword matches the message at position ' + str(pos))\n",
    "\n",
    "    # Join once instead of copying the growing output for every block\n",
    "    return bt.Bits().join(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def BAC_decoder(coded_message, codewords):\n",
    "    # Decodes a message encoded using BAC\n",
    "    # Similar to Huffman decoder, but fixed-length coded messages makes this easier\n",
    "    \n",
    "    # Construct reverse dictionary\n",
    "    decode_dict = {codewords[x]:x for x in codewords}\n",
    "    \n",
    "    # Compute the size of a block from an example codeword\n",
    "    block_len = len(codewords[sorted(codewords.keys())[0]])\n",
    "    \n",
    "    # If the coded message doesn't break into blocks, then throw an error\n",
    "    assert(len(coded_message)% block_len == 0)\n",
    "\n",
    "    # Decode each block and output\n",
    "    output = [decode_dict[coded_message[(block_len*i):(block_len*(i+1))]] for i in range(len(coded_message)//block_len)]\n",
    "    return ''.join(output)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "message = 'THE_RAIN_IN_SPAIN_FALLS_MAINLY_ON_THE_PLAIN'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "011010101000101101111101100111001010100111001010010101011011011011001011110001110101110111001100110100011010011011001011011100111010001010110011011010001100101010100101110001000110000011001100\n",
      "THE_RAIN_IN_SPAIN_FALLS_MAINLY_ON_THE_PLAIN__\n"
     ]
    }
   ],
   "source": [
    "# Round trip through the codebook encoder and decoder\n",
    "coded_message = BAC_encoder(message, codewords)\n",
    "print(coded_message.bin)\n",
    "decoded_message = BAC_decoder(coded_message, codewords)\n",
    "print(decoded_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def BAC_encoder_fly(message, output_len, probs):\n",
    "    # Implements Boncelet's BAC encoder in an on-the-fly manner (no codebook saved).\n",
    "    # Each input block is mapped to a number by walking down the BAC tree one symbol at a\n",
    "    # time; that number is emitted as an output_len-bit unsigned integer.\n",
    "    #\n",
    "    # Parameters:\n",
    "    #   message    -- string to encode; a block cut short by the end is padded with '_'\n",
    "    #   output_len -- number of bits per output block\n",
    "    #   probs      -- pandas Series of symbol probabilities, indexed by symbol (sorted here)\n",
    "    # Returns:\n",
    "    #   bitstring.Bits, the concatenated output blocks\n",
    "\n",
    "    # Compute total number of outputs and sort probabilities\n",
    "    K = 2**output_len\n",
    "    probs = probs.sort_values()\n",
    "\n",
    "    blocks = []\n",
    "    pos = 0  # Index of the first character of the current block\n",
    "    while pos < len(message):\n",
    "        cur_K = K        # Number of outputs available below the current tree node\n",
    "        cur_sum_K = 0    # Number of outputs numbered before the current tree node\n",
    "        block_len = 0    # Symbols consumed so far in this block\n",
    "\n",
    "        while True:\n",
    "            # Next symbol of the block, or '_' padding once the message is exhausted\n",
    "            idx = pos + block_len\n",
    "            symbol = message[idx] if idx < len(message) else '_'\n",
    "            block_len += 1\n",
    "\n",
    "            # Split this node's outputs among the symbols and descend into this symbol's share\n",
    "            split = _BAC_splitter(cur_K, probs)\n",
    "            cur_K = split[symbol]\n",
    "            cur_sum_K = cur_sum_K + split.cumsum().shift(fill_value=0)[symbol]\n",
    "\n",
    "            # A share of exactly one output completes the block; its number is cur_sum_K\n",
    "            if cur_K == 1:\n",
    "                break\n",
    "\n",
    "        blocks.append(bt.Bits(uint=int(cur_sum_K), length=output_len))\n",
    "        # Advance past the block (past the end of the message if it was padded)\n",
    "        pos += block_len\n",
    "\n",
    "    # Join once instead of copying the growing output (and re-slicing the message) per block\n",
    "    return bt.Bits().join(blocks)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _BAC_decode_single(coded_message, probs):\n",
    "    # Uses the BAC algorithm, in on-the-fly mode, to decode a single output block.\n",
    "    # Retraces the walk BAC_encoder_fly makes down the BAC tree: at each node the outputs are\n",
    "    # split among the symbols and the symbol whose range contains the coded value is chosen.\n",
    "    #\n",
    "    # Parameters:\n",
    "    #   coded_message -- bitstring.Bits holding exactly one output block\n",
    "    #   probs         -- pandas Series of symbol probabilities; expected to be sorted ascending,\n",
    "    #                    as BAC_decoder_fly does before calling\n",
    "    # Returns:\n",
    "    #   the decoded input block as a string (it may end in '_' padding)\n",
    "    # Raises:\n",
    "    #   ValueError if the coded value lies in no symbol's range (an unused output value)\n",
    "\n",
    "    cur_K = 2**len(coded_message)     # Number of outputs below the current tree node\n",
    "    cur_message = coded_message.uint  # Coded message expressed as an integer\n",
    "    cur_output = ''                   # Decoded symbols so far\n",
    "    cur_output_loc = 0                # First output number belonging to the current node\n",
    "\n",
    "    # Descend until the node owns a single output, i.e. a complete input block. Testing\n",
    "    # cur_output_loc against cur_message instead cannot tell leaves from internal nodes whose\n",
    "    # range starts at (or just below) the coded value, and truncates the block.\n",
    "    while cur_K > 1:\n",
    "        # Calculate the BAC heuristic split of outputs among symbols\n",
    "        split = _BAC_splitter(cur_K, probs)\n",
    "        split_end = split.cumsum()\n",
    "        split_start = split_end.shift(fill_value=0)\n",
    "\n",
    "        # The next symbol is the one whose range [start, end) contains the coded value\n",
    "        target = cur_message - cur_output_loc\n",
    "        matches = split_end[(split_end > target) & (split_start <= target)].index\n",
    "        if len(matches) == 0:\n",
    "            raise ValueError('Coded value ' + str(cur_message) + ' is not a valid BAC output')\n",
    "        cur_char = matches[0]\n",
    "\n",
    "        cur_output = cur_output + cur_char\n",
    "        cur_K = split[cur_char]\n",
    "        cur_output_loc = cur_output_loc + split_start[cur_char]\n",
    "    return cur_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def BAC_decoder_fly(coded_message, output_len, probs):\n",
    "    # Decodes a message encoded by BAC_encoder_fly, in on-the-fly mode (no codebook stored).\n",
    "    #\n",
    "    # Parameters:\n",
    "    #   coded_message -- bitstring.Bits made of whole output_len-bit blocks\n",
    "    #   output_len    -- number of bits per block\n",
    "    #   probs         -- pandas Series of symbol probabilities, indexed by symbol\n",
    "    # Returns:\n",
    "    #   the decoded string (including any '_' padding added by the encoder)\n",
    "\n",
    "    # The single-block decoder expects ascending probabilities, so sort them once here\n",
    "    sorted_probs = probs.sort_values()\n",
    "\n",
    "    # A coded message that is not a whole number of blocks is an error\n",
    "    assert len(coded_message) % output_len == 0\n",
    "\n",
    "    pieces = []\n",
    "    for start in range(0, len(coded_message), output_len):\n",
    "        pieces.append(_BAC_decode_single(coded_message[start:start+output_len], sorted_probs))\n",
    "    return ''.join(pieces)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "101001010001110111000100011010101001011010101001010101100010100101111110110100000111001011011100011000011110100101111110001011010110111010010111111011111110110010011110001011111000100000000001\n"
     ]
    }
   ],
   "source": [
    "# Encode the same message without storing a codebook\n",
    "coded_message_fly = BAC_encoder_fly(message=message, output_len=output_len, probs=probs)\n",
    "print(coded_message_fly.bin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'THE_RAIN_IN_SPAIN_FALLS_MAINLY_ON_THE_PLAIN__'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode the on-the-fly encoding; the result is displayed as the cell's output\n",
    "BAC_decoder_fly(coded_message=coded_message_fly, output_len=output_len, probs=probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.1"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
