Source code for pfp.bitwrap

#!/usr/bin/env python
# encoding: utf-8

import collections
from intervaltree import IntervalTree, Interval
import math
import os
import six
import sys

import pfp.utils as utils


class EOFError(Exception):
    """End-of-file error type declared by this module.

    NOTE: the name shadows Python's builtin ``EOFError`` inside this module
    (and in any module that star-imports it). It derives directly from
    ``Exception``, not from the builtin ``EOFError``, so ``except`` clauses
    written against the builtin will not catch it.
    """
def bits_to_bytes(bits):
    """Convert the bit list into bytes.

    Each consecutive group of 8 bits is interpreted most-significant-bit
    first and turned into one character; the characters are joined and
    handed to ``utils.binary`` (presumably yielding a binary/bytes string --
    confirm against ``pfp.utils``).

    :bits: a sliceable sequence of 0/1 values whose length is a multiple of 8
    :returns: the result of ``utils.binary`` on the assembled string
    :raises ValueError: if ``len(bits)`` is not a multiple of 8
    """
    if len(bits) % 8 != 0:
        # ValueError is still an Exception, so existing handlers keep working
        raise ValueError("num bits must be multiple of 8")

    # collect the characters and join once instead of repeated string +=
    chars = []
    for start in six.moves.range(0, len(bits), 8):
        byte_bits = bits[start : start + 8]
        byte_val = int("".join(map(str, byte_bits)), 2)
        chars.append(chr(byte_val))

    return utils.binary("".join(chars))
def bytes_to_bits(bytes_):
    """Convert bytes to a list of bits

    Iterating ``bytes_`` may yield ints or single-character strings; anything
    that is not an int is passed through ``ord`` first. Each value is then
    expanded with ``byte_to_bits`` and the results are concatenated in order.

    :bytes_: iterable of byte values (ints or 1-character strings)
    :returns: a flat list of 0/1 ints, 8 per input byte
    """
    bit_list = []
    for item in bytes_:
        value = item if isinstance(item, int) else ord(item)
        bit_list.extend(byte_to_bits(value))
    return bit_list
def byte_to_bits(b):
    """Convert a byte into bits

    :b: integer byte value; only its low 8 bits are examined
    :returns: list of eight 0/1 ints, most significant bit first
    """
    bit_list = []
    # shift amounts run 7..0 so the highest bit is emitted first
    for shift in range(7, -1, -1):
        bit_list.append((b >> shift) & 1)
    return bit_list
class BitwrappedStream(object):
    """A stream that wraps other streams to provide bit-level access

    Individual bits are buffered in ``self._bits`` (a deque) between the
    caller and the wrapped byte stream. ``self.range_set`` records which
    byte ranges have been consumed through ``read``.
    """

    # class-level default; ``__init__`` sets the instance value to False
    closed = True

    def __init__(self, stream):
        """Init the bit-wrapped stream

        :stream: The normal byte stream
        """
        self._stream = stream
        self._bits = collections.deque()

        self.closed = False

        # assume that bitfields end on an even boundary,
        # otherwise the entire stream will be treated as
        # a bit stream with no padding
        self.padded = True

        self.range_set = IntervalTree()

    def is_eof(self):
        """Return if the stream has reached EOF or not
        without discarding any unflushed bits

        Peeks one byte from the underlying stream and then seeks back to
        the original position.

        :returns: True/False
        """
        pos = self._stream.tell()
        byte = self._stream.read(1)
        self._stream.seek(pos, 0)

        return utils.binary(byte) == utils.binary("")

    def close(self):
        """Close the stream

        Marks this wrapper as closed, flushes any buffered bits and closes
        the underlying stream.
        """
        self.closed = True
        try:
            self._flush_bits_to_stream()
        finally:
            # release the underlying stream even if the flush fails;
            # the flush error (if any) still propagates to the caller
            self._stream.close()

    def flush(self):
        """Flush the stream

        Writes any buffered bits (zero-padded to a whole byte) and then
        flushes the underlying stream.
        """
        self._flush_bits_to_stream()
        self._stream.flush()

    def isatty(self):
        """Return if the stream is a tty
        """
        return self._stream.isatty()

    def read(self, num):
        """Read ``num`` number of bytes from the stream. Note that this will
        automatically reset/end the current bit-reading if it does not
        end on an even byte AND ``self.padded`` is True. If ``self.padded`` is
        False, then the entire stream is treated as a bitstream.

        The byte range covered by the read is recorded in ``self.range_set``.

        :num: number of bytes to read
        :returns: the read bytes, or empty string if EOF has been reached
        """
        start_pos = self.tell()

        if self.padded:
            # we toss out any uneven bytes
            self._bits.clear()
            res = utils.binary(self._stream.read(num))
        else:
            bits = self.read_bits(num * 8)
            res = bits_to_bytes(bits)
            res = utils.binary(res)

        end_pos = self.tell()
        self._update_consumed_ranges(start_pos, end_pos)

        return res

    def read_bits(self, num):
        """Read ``num`` number of bits from the stream

        If fewer than ``num`` bits are buffered, just enough whole bytes are
        read from the underlying stream to cover the shortfall. Fewer than
        ``num`` bits are returned when the stream runs out.

        :num: number of bits to read
        :returns: a list of ``num`` bits, or an empty list if EOF has been reached
        """
        if num > len(self._bits):
            needed = num - len(self._bits)
            num_bytes = int(math.ceil(needed / 8.0))
            read_bytes = self._stream.read(num_bytes)

            self._bits.extend(bytes_to_bits(read_bytes))

        res = []
        while len(res) < num and len(self._bits) > 0:
            res.append(self._bits.popleft())

        return res

    def write(self, data):
        """Write data to the stream

        In padded mode any buffered bits are flushed (zero-padded) before
        ``data`` is written as-is; otherwise ``data`` is expanded to bits
        and appended to the bit buffer via ``write_bits``.

        :data: the data to write to the stream
        :returns: None
        """
        if self.padded:
            # flush out any remaining bits first
            if len(self._bits) > 0:
                self._flush_bits_to_stream()
            self._stream.write(data)
        else:
            # nothing to do here
            if len(data) == 0:
                return

            bits = bytes_to_bits(data)
            self.write_bits(bits)

    def write_bits(self, bits):
        """Write the bits to the stream.

        Add the bits to the existing unflushed bits and write
        complete bytes to the stream.

        :bits: iterable of 0/1 values
        """
        for bit in bits:
            self._bits.append(bit)

        while len(self._bits) >= 8:
            byte_bits = [self._bits.popleft() for x in six.moves.range(8)]
            byte = bits_to_bytes(byte_bits)
            self._stream.write(byte)

        # there may be unflushed bits leftover and THAT'S OKAY

    def tell(self):
        """Return the current position in the stream (ignoring bit
        position)

        When bits are buffered the underlying position is reduced by one.

        :returns: int for the position in the stream
        """
        res = self._stream.tell()
        if len(self._bits) > 0:
            res -= 1
        return res

    def tell_bits(self):
        """Return the number of bits into the stream since the last whole
        byte.

        :returns: int (0 when no bits are buffered, else 8 minus the number
            of buffered bits)
        """
        if len(self._bits) == 0:
            return 0
        return 8 - len(self._bits)

    def seek(self, pos, seek_type=0):
        """Seek to the specified position in the stream with seek_type.
        Unflushed bits will be discarded in the case of a seek.

        :pos: offset
        :seek_type: direction (passed through as the ``whence`` argument)
        :returns: whatever the underlying stream's ``seek`` returns
        """
        self._bits.clear()
        return self._stream.seek(pos, seek_type)

    def size(self):
        """Return the size of the underlying stream.

        The size is found by seeking to the end of the underlying stream and
        restoring the previous position afterwards. Any error raised by the
        underlying ``tell``/``seek`` propagates to the caller.

        :returns: int position reported at the end of the stream
        """
        pos = self._stream.tell()
        # seek to the end of the stream
        self._stream.seek(0, 2)
        size = self._stream.tell()
        self._stream.seek(pos, 0)

        return size

    def unconsumed_ranges(self):
        """Return an IntervalTree of unconsumed ranges, of the format
        [start, end) with the end value not being included

        Only gaps between consumed ranges, plus the gap between the last
        consumed range and the current position, are reported. If nothing
        has been consumed yet the returned tree is empty.
        """
        res = IntervalTree()

        # normal iteration is not in a predictable order
        ranges = sorted(self.range_set, key=lambda x: x.begin)
        if not ranges:
            # no consumed range exists to measure a gap from
            return res

        prev = ranges[0]
        for rng in ranges[1:]:
            # touching ranges leave no gap; skipping them avoids building an
            # empty interval
            if rng.begin > prev.end:
                res.add(Interval(prev.end, rng.begin))
            prev = rng

        # means we've seeked past the end
        pos = self.tell()
        if len(self.range_set[pos]) != 1 and pos > prev.end:
            res.add(Interval(prev.end, pos))

        return res

    # -----------------------------
    # PRIVATE FUNCTIONS
    # -----------------------------

    def _update_consumed_ranges(self, start_pos, end_pos):
        """Update ``self.range_set`` with which byte ranges have been
        consecutively consumed.

        The recorded interval is ``[start_pos, end_pos + 1)``; overlapping
        intervals are merged afterwards.
        """
        self.range_set.add(Interval(start_pos, end_pos + 1))
        self.range_set.merge_overlaps()

    def _flush_bits_to_stream(self):
        """Flush the bits to the stream. This is used when
        a few bits have been read and ``self._bits`` contains unconsumed/
        flushed bits when data is to be written to the stream

        The buffered bits are zero-padded up to the next byte boundary,
        written to the underlying stream and the buffer is cleared.

        :returns: 0 when there was nothing to flush, otherwise None
        """
        if len(self._bits) == 0:
            return 0

        bits = list(self._bits)

        # pad only up to the next byte boundary; an already aligned buffer
        # gets no padding rather than an extra zero byte
        padding = [0] * (-len(bits) % 8)
        bits = bits + padding

        self._stream.write(bits_to_bytes(bits))
        self._bits.clear()