Sunday, October 9, 2011

Huffman coding done in Scala

Huffman coding is an entropy-based encoding algorithm. Take an arbitrary language, and you will find that certain characters are used far more often than others. Huffman coding lets you take advantage of that: it uses fewer bits for characters that occur often, and more bits for the rare ones.

Having Huffman encoding for Preon has been on my wish list for a while, but since all I do is Scala these days, I just couldn't resist trying to do it in Scala. And it works:

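// Thrown when the input runs out of bits in the middle of a character.
class DecodingException(val depth: Int) extends Exception("Ran out of bits at tree depth " + depth)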

class HuffmanCodec(in: Map[Char,Double]) {

  // Build the tree bottom-up: keep merging the two lightest nodes until a single root is left.
  private val tree = {
    val list = (for ((char, weight) <- in) yield CharacterNode(char, weight)).toList
    def createTree(list: List[Node]): Node = list match {
      case top :: Nil => top
      case first :: second :: Nil => ChoiceNode(first.push(true), second.push(false))
      case multiple => 
        val first :: second :: others = multiple.sortBy(_.weight)
        createTree(ChoiceNode(first.push(true), second.push(false)) :: others)
    }
    createTree(list)
  }
  
  private val index = {
    tree.all.collect{ 
      case node: CharacterNode => (node.char, node.code)
    }.toMap
  }

  def decode(in: Seq[Boolean], length: Int, builder: StringBuilder = new StringBuilder): (String, Seq[Boolean]) = {
    if (length > 0) {
      val (char, remaining) = tree.decode(in)
      builder += char
      decode(remaining, length - 1, builder)
    } else {
      (builder.result, in)
    }
  }

  def encode(in: String): Seq[Boolean] = {
    in.flatMap(index(_))
  }
  
  private abstract sealed class Node {
    
    def decode(in: Seq[Boolean]): (Char, Seq[Boolean])

    def weight: Double
    
    def push(code: Boolean): Node
    
    def all: List[Node]
    
  }

  private case class CharacterNode(char: Char, weight: Double, code: List[Boolean] = Nil) extends Node {
    
    def decode(in: Seq[Boolean]): (Char, Seq[Boolean]) = (char, in)
      
    def push(code: Boolean) = CharacterNode(char, weight, code :: this.code)
    
    val all = List(this)
    
    override def toString = "CharacterNode(%s)".format(char)

  }

  private case class ChoiceNode(left: Node, right: Node, code: List[Boolean] = Nil) extends Node {
    
    val weight = left.weight + right.weight

    def push(code: Boolean) = ChoiceNode(left.push(code), right.push(code), code :: this.code)

    def decode(in: Seq[Boolean]): (Char, Seq[Boolean]) = in.headOption match {
      case Some(true) => left.decode(in.tail)
      case Some(false) => right.decode(in.tail)
      case None => throw new DecodingException(code.length)
    }

    val all = this :: (left.all ++ right.all)
    
    override def toString = "ChoiceNode(%s, %s)".format(left, right)

  }

}

With these definitions, creating a new instance of the codec is as easy as this:

scala> val codec = new HuffmanCodec(Map('a' -> 2.0, 'b' -> 3.0, 'c' -> 0.5, 'd' -> 0.7))
codec: HuffmanCodec = HuffmanCodec@52c54b3b

In the above case, the only characters we can encode are the characters 'a' to 'd'. They are passed in as a map, with the character as the key and the weight as the value. (The weight expresses how often a character occurs relative to all of the other characters. You could normalize the weights to fractions between 0 and 1 that add up to 1, but as long as the numbers correctly reflect relative frequency, it's all good.)
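If you don't have weights at hand, one way to get them is to count characters in a representative sample of text. The following is only a minimal sketch: the `Weights` object and its method names are made up for illustration and are not part of the codec above.

```scala
object Weights {
  // Count how often each character occurs in a sample text and use the
  // counts as weights. Raw counts are enough: only the ratios matter.
  def fromSample(sample: String): Map[Char, Double] =
    sample.groupBy(identity).map { case (char, occurrences) =>
      char -> occurrences.length.toDouble
    }

  // Optional: scale the weights so that they add up to 1.
  def normalize(weights: Map[Char, Double]): Map[Char, Double] = {
    val total = weights.values.sum
    weights.map { case (char, weight) => char -> weight / total }
  }
}
```

The result of `Weights.fromSample(...)` has the same type as the map passed to `HuffmanCodec`, so it could be handed to the constructor directly.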

The underlying tree constructed from these arguments places 'b' directly below the root, 'a' one level further down, and 'c' and 'd' together at the deepest level.

That means you can encode the 'b' character with just a single bit, and the 'a' character with just two bits:

scala> codec.encode("b")
res3: Seq[Boolean] = Vector(false)

scala> codec.encode("a")
res4: Seq[Boolean] = Vector(true, false)
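To get a feel for what this buys you, you can compute the weighted average number of bits per character. The sketch below does that for the example weights. The code lengths for 'a' and 'b' come from the REPL output above; the lengths for 'c' and 'd' (three bits each) are an assumption, derived by tracing the tree construction by hand. The `CodeLength` object is hypothetical and only there for illustration.

```scala
object CodeLength {
  // Weights from the example codec.
  val weights: Map[Char, Double] = Map('a' -> 2.0, 'b' -> 3.0, 'c' -> 0.5, 'd' -> 0.7)

  // Code lengths in bits. 'a' and 'b' are taken from the REPL output;
  // 'c' and 'd' are assumed, based on tracing the tree construction by hand.
  val lengths: Map[Char, Int] = Map('a' -> 2, 'b' -> 1, 'c' -> 3, 'd' -> 3)

  // Weighted average number of bits spent per character.
  def averageBits(weights: Map[Char, Double], lengths: Map[Char, Int]): Double = {
    val total = weights.values.sum
    weights.map { case (char, weight) => weight * lengths(char) }.sum / total
  }
}
```

For these numbers the average works out to 10.6 / 6.2, or roughly 1.71 bits per character, against the 2 bits per character a fixed-length code for four characters would need.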

As you can see, the bit values are represented as booleans. That's certainly not ideal, but it was the fastest way to get something working for now. (Ideally, the bits should be layered on top of some existing BitBuffer/BitStream abstraction.)
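Short of a real BitBuffer, the booleans could be packed into bytes along these lines. This is a minimal sketch, not what Preon does: the `BitPacking` object, the most-significant-bit-first order and the zero padding are all assumptions made for the example.

```scala
object BitPacking {
  // Pack bits into bytes, most significant bit first. The last byte is
  // padded with zero bits, so the caller has to remember the bit count.
  def pack(bits: Seq[Boolean]): Array[Byte] =
    bits.grouped(8).map { group =>
      group.padTo(8, false)
        .foldLeft(0) { (acc, bit) => (acc << 1) | (if (bit) 1 else 0) }
        .toByte
    }.toArray

  // Unpack the first `count` bits again, dropping the padding.
  def unpack(bytes: Array[Byte], count: Int): Seq[Boolean] =
    bytes.toSeq.flatMap { byte =>
      (7 to 0 by -1).map(shift => ((byte >> shift) & 1) == 1)
    }.take(count)
}
```

With this, `BitPacking.pack(codec.encode("a"))` would take a single byte instead of two separate boolean values.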

If you want to decode something, then in its current incarnation, you have to pass in the number of characters you expect to be encoded in the sequence of booleans. The result is the decoded string, together with whatever bits were left unread.

scala> codec.decode(Seq(true, false), 1)
res8: (String, Seq[Boolean]) = (a,List())
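Returning the unread bits makes it possible to decode a longer sequence in several steps. The sketch below shows the idea with a plain lookup table instead of the tree, so that it is self-contained. The `PrefixDecoding` object is hypothetical, and the codes for 'c' and 'd' are again assumed from tracing the tree construction by hand; only 'a' and 'b' are confirmed by the REPL output above.

```scala
object PrefixDecoding {
  // Code table for the example codec ('c' and 'd' are assumed values).
  val codes: Map[List[Boolean], Char] = Map(
    List(false)             -> 'b',
    List(true, false)       -> 'a',
    List(true, true, true)  -> 'c',
    List(true, true, false) -> 'd'
  )

  // Decode `length` characters; return them together with the unread bits.
  // No code is a prefix of another one, so the first match is the right one.
  @annotation.tailrec
  def decode(in: List[Boolean], length: Int, pending: List[Boolean] = Nil,
             out: StringBuilder = new StringBuilder): (String, List[Boolean]) =
    if (length == 0) (out.result(), in)
    else codes.get(pending) match {
      case Some(char) =>
        out += char
        decode(in, length - 1, Nil, out)
      case None => in match {
        case bit :: tail => decode(tail, length, pending :+ bit, out)
        case Nil => throw new IllegalArgumentException("Ran out of bits")
      }
    }
}
```

Decoding two characters from `List(true, false, false, true, true, true)` yields "ab" plus the three remaining bits, which can then be passed to a second `decode` call.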
