use crate::{Error, Readable, Result};
use arrayref::array_ref;
/// A cursor that decodes values from the front of a borrowed byte slice.
///
/// Bytes are consumed starting at the current offset. Methods that fail
/// with [`Error::Truncated`] do so before moving the offset.
pub struct Reader<'a> {
/// The bytes being decoded; `truncate` may shorten this slice, and nothing in this file lengthens it.
b: &'a [u8],
/// Index in `b` of the next unconsumed byte; methods keep this `<= b.len()` (see `advance`).
off: usize,
}
impl<'a> Reader<'a> {
    /// Construct a `Reader` positioned at the start of `slice`.
    pub fn from_slice(slice: &'a [u8]) -> Self {
        Reader { b: slice, off: 0 }
    }

    /// Construct a `Reader` over the contents of a [`bytes::Bytes`].
    pub fn from_bytes(b: &'a bytes::Bytes) -> Self {
        Self::from_slice(b.as_ref())
    }

    /// Return the length of the underlying slice, counting consumed bytes.
    ///
    /// After [`Reader::truncate`] this reflects the shortened length.
    pub fn total_len(&self) -> usize {
        self.b.len()
    }

    /// Return how many bytes have not been consumed yet.
    pub fn remaining(&self) -> usize {
        self.b.len() - self.off
    }

    /// Consume the reader, returning every byte it has not yet consumed.
    pub fn into_rest(self) -> &'a [u8] {
        &self.b[self.off..]
    }

    /// Return how many bytes have been consumed so far.
    pub fn consumed(&self) -> usize {
        self.off
    }

    /// Skip over the next `n` bytes.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Truncated`], consuming nothing, if fewer than `n`
    /// bytes remain.
    pub fn advance(&mut self, n: usize) -> Result<()> {
        if n > self.remaining() {
            return Err(Error::Truncated);
        }
        self.off += n;
        Ok(())
    }

    /// Check that every byte has been consumed.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ExtraneousBytes`] if any bytes remain.
    pub fn should_be_exhausted(&self) -> Result<()> {
        if self.remaining() != 0 {
            return Err(Error::ExtraneousBytes);
        }
        Ok(())
    }

    /// Shorten the reader so that at most `n` unconsumed bytes remain.
    ///
    /// Does nothing if `n` is not smaller than [`Reader::remaining`].
    pub fn truncate(&mut self, n: usize) {
        if n < self.remaining() {
            // Here `off + n < b.len()`, so the addition and the slice are in range.
            self.b = &self.b[..self.off + n];
        }
    }

    /// Return the next `n` bytes without consuming them.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Truncated`] if fewer than `n` bytes remain.
    pub fn peek(&self, n: usize) -> Result<&'a [u8]> {
        if self.remaining() < n {
            return Err(Error::Truncated);
        }
        Ok(&self.b[self.off..(n + self.off)])
    }

    /// Consume and return the next `n` bytes.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Truncated`], consuming nothing, if fewer than `n`
    /// bytes remain.
    pub fn take(&mut self, n: usize) -> Result<&'a [u8]> {
        let b = self.peek(n)?;
        self.advance(n)?;
        Ok(b)
    }

    /// Fill `buf` completely with the next `buf.len()` bytes.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Truncated`], consuming nothing, if too few bytes remain.
    pub fn take_into(&mut self, buf: &mut [u8]) -> Result<()> {
        let n = buf.len();
        let b = self.take(n)?;
        buf.copy_from_slice(b);
        Ok(())
    }

    /// Consume one byte.
    pub fn take_u8(&mut self) -> Result<u8> {
        let b = self.take(1)?;
        Ok(b[0])
    }

    /// Consume a big-endian `u16`.
    pub fn take_u16(&mut self) -> Result<u16> {
        let b = self.take(2)?;
        let r = u16::from_be_bytes(*array_ref![b, 0, 2]);
        Ok(r)
    }

    /// Consume a big-endian `u32`.
    pub fn take_u32(&mut self) -> Result<u32> {
        let b = self.take(4)?;
        let r = u32::from_be_bytes(*array_ref![b, 0, 4]);
        Ok(r)
    }

    /// Consume a big-endian `u64`.
    pub fn take_u64(&mut self) -> Result<u64> {
        let b = self.take(8)?;
        let r = u64::from_be_bytes(*array_ref![b, 0, 8]);
        Ok(r)
    }

    /// Consume a big-endian `u128`.
    pub fn take_u128(&mut self) -> Result<u128> {
        let b = self.take(16)?;
        let r = u128::from_be_bytes(*array_ref![b, 0, 16]);
        Ok(r)
    }

    /// Consume bytes up to and including the first `term`, returning the
    /// bytes before it (the terminator is not included in the result).
    ///
    /// # Errors
    ///
    /// Returns [`Error::Truncated`], consuming nothing, if `term` does not
    /// occur in the remaining bytes.
    pub fn take_until(&mut self, term: u8) -> Result<&'a [u8]> {
        let pos = self.b[self.off..]
            .iter()
            .position(|&byte| byte == term)
            .ok_or(Error::Truncated)?;
        let result = self.take(pos)?;
        // Step over the terminator itself.
        self.advance(1)?;
        Ok(result)
    }

    /// Consume and return every remaining byte (possibly none).
    pub fn take_rest(&mut self) -> &'a [u8] {
        // Copy out the `&'a` slice so the result is not tied to `&mut self`.
        let all: &'a [u8] = self.b;
        let rest = &all[self.off..];
        self.off = all.len();
        rest
    }

    /// Decode one `E` from the front of the reader.
    ///
    /// # Errors
    ///
    /// Propagates any error from `E::take_from`; on failure the read
    /// position is restored to where it was before the call.
    pub fn extract<E: Readable>(&mut self) -> Result<E> {
        let off_orig = self.off;
        let result = E::take_from(self);
        if result.is_err() {
            self.off = off_orig;
        }
        result
    }

    /// Decode exactly `n` consecutive `E` values.
    ///
    /// # Errors
    ///
    /// Propagates the first error from `E::take_from`; on failure the read
    /// position is restored to where it was before the call.
    pub fn extract_n<E: Readable>(&mut self, n: usize) -> Result<Vec<E>> {
        // `n` frequently comes from a count field in untrusted input, so it
        // must not drive preallocation directly: `Vec::with_capacity` panics
        // (or aborts) on huge values. Capping by the bytes left bounds the
        // allocation by the input size; if `E` can encode in zero bytes the
        // cap only makes the first allocation too small, and the Vec grows.
        let mut result = Vec::with_capacity(n.min(self.remaining()));
        let off_orig = self.off;
        for _ in 0..n {
            match E::take_from(self) {
                Ok(item) => result.push(item),
                Err(e) => {
                    self.off = off_orig;
                    return Err(e);
                }
            }
        }
        Ok(result)
    }

    /// Read a `u8` length prefix, then run `f` on a sub-reader over exactly
    /// that many bytes, requiring `f` to consume all of them.
    ///
    /// # Errors
    ///
    /// Fails if the prefix or body is truncated, if `f` fails, or with
    /// [`Error::ExtraneousBytes`] if `f` leaves bytes unconsumed.
    pub fn read_nested_u8len<F, T>(&mut self, f: F) -> Result<T>
    where
        F: FnOnce(&mut Reader) -> Result<T>,
    {
        read_nested_generic::<u8, _, _>(self, f)
    }

    /// Like [`Reader::read_nested_u8len`], but with a big-endian `u16`
    /// length prefix.
    pub fn read_nested_u16len<F, T>(&mut self, f: F) -> Result<T>
    where
        F: FnOnce(&mut Reader) -> Result<T>,
    {
        read_nested_generic::<u16, _, _>(self, f)
    }

    /// Like [`Reader::read_nested_u8len`], but with a big-endian `u32`
    /// length prefix.
    pub fn read_nested_u32len<F, T>(&mut self, f: F) -> Result<T>
    where
        F: FnOnce(&mut Reader) -> Result<T>,
    {
        read_nested_generic::<u32, _, _>(self, f)
    }
}
fn read_nested_generic<L, F, T>(r: &mut Reader, f: F) -> Result<T>
where
F: FnOnce(&mut Reader) -> Result<T>,
L: Readable + Copy + Sized + TryInto<usize>,
{
let length: L = r.extract()?;
let length: usize = length.try_into().map_err(|_| Error::BadLengthValue)?;
let slice = r.take(length)?;
let mut inner = Reader::from_slice(slice);
let out = f(&mut inner)?;
inner.should_be_exhausted()?;
Ok(out)
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cognitive_complexity)]
    use super::*;

    /// Fixed-width reads, peeking, advancing, and `into_rest` on a text buffer.
    #[test]
    fn bytecursor_read_ok() {
        let text = b"On a mountain halfway between Reno and Rome";

        let mut cursor = Reader::from_slice(&text[..]);
        assert_eq!(cursor.consumed(), 0);
        assert_eq!(cursor.remaining(), 43);
        assert_eq!(cursor.total_len(), 43);

        assert_eq!(cursor.take(3).unwrap(), &b"On "[..]);
        assert_eq!(cursor.consumed(), 3);

        assert_eq!(cursor.take_u16().unwrap(), 0x6120);
        assert_eq!(cursor.take_u8().unwrap(), 0x6d);
        assert_eq!(cursor.take_u64().unwrap(), 0x6f756e7461696e20);
        assert_eq!(cursor.take_u32().unwrap(), 0x68616c66);
        assert_eq!(cursor.consumed(), 18);
        assert_eq!(cursor.remaining(), 25);
        assert_eq!(cursor.total_len(), 43);

        // Peeking never moves the cursor, however often it is repeated.
        assert_eq!(cursor.peek(7).unwrap(), &b"way bet"[..]);
        assert_eq!(cursor.consumed(), 18);
        assert_eq!(cursor.remaining(), 25);
        assert_eq!(cursor.total_len(), 43);
        assert_eq!(cursor.peek(7).unwrap(), &b"way bet"[..]);
        assert_eq!(cursor.consumed(), 18);

        cursor.advance(12).unwrap();
        assert_eq!(cursor.consumed(), 30);
        assert_eq!(cursor.remaining(), 13);
        assert_eq!(cursor.into_rest(), &b"Reno and Rome"[..]);

        // A fresh reader can be drained exactly; zero-length takes still succeed.
        let mut cursor = Reader::from_slice(&text[..]);
        cursor.advance(22).unwrap();
        assert_eq!(cursor.remaining(), 21);
        assert_eq!(cursor.take(21).unwrap(), &b"between Reno and Rome"[..]);
        assert_eq!(cursor.consumed(), 43);
        assert_eq!(cursor.remaining(), 0);
        assert_eq!(cursor.take(0).unwrap(), &b""[..]);
    }

    /// A `Reader` built from `bytes::Bytes` decodes a big-endian `u128`.
    #[test]
    fn read_u128() {
        let owned = bytes::Bytes::from(&b"irreproducibility?"[..]);
        let mut reader = Reader::from_bytes(&owned);
        assert_eq!(reader.take_u8().unwrap(), b'i');
        assert_eq!(
            reader.take_u128().unwrap(),
            0x72726570726f6475636962696c697479
        );
        assert_eq!(reader.remaining(), 1);
    }

    /// Reads past the end fail with `Truncated` and leave the position intact.
    #[test]
    fn bytecursor_read_missing() {
        let digits = b"1234567";
        let mut cursor = Reader::from_slice(&digits[..]);
        assert_eq!(cursor.consumed(), 0);
        assert_eq!(cursor.remaining(), 7);
        assert_eq!(cursor.total_len(), 7);

        assert_eq!(cursor.take_u64(), Err(Error::Truncated));
        assert_eq!(cursor.take(8), Err(Error::Truncated));
        assert_eq!(cursor.peek(8), Err(Error::Truncated));
        assert_eq!(cursor.consumed(), 0);
        assert_eq!(cursor.remaining(), 7);
        assert_eq!(cursor.total_len(), 7);

        assert_eq!(cursor.take_u32().unwrap(), 0x31323334);
        assert_eq!(cursor.take_u32(), Err(Error::Truncated));
        assert_eq!(cursor.consumed(), 4);
        assert_eq!(cursor.remaining(), 3);
        assert_eq!(cursor.total_len(), 7);

        assert_eq!(cursor.take_u16().unwrap(), 0x3536);
        assert_eq!(cursor.take_u16(), Err(Error::Truncated));
        assert_eq!(cursor.consumed(), 6);
        assert_eq!(cursor.remaining(), 1);
        assert_eq!(cursor.total_len(), 7);

        assert_eq!(cursor.take_u8().unwrap(), 0x37);
        assert_eq!(cursor.take_u8(), Err(Error::Truncated));
        assert_eq!(cursor.consumed(), 7);
        assert_eq!(cursor.remaining(), 0);
        assert_eq!(cursor.total_len(), 7);
    }

    /// A failed `advance` consumes nothing; an exact `advance` drains the reader.
    #[test]
    fn advance_too_far() {
        let mut reader = Reader::from_slice(&b"12345"[..]);
        assert_eq!(reader.remaining(), 5);
        assert_eq!(reader.advance(6), Err(Error::Truncated));
        assert_eq!(reader.remaining(), 5);
        assert_eq!(reader.advance(5), Ok(()));
        assert_eq!(reader.remaining(), 0);
    }

    /// `truncate` limits what remains without moving the read position.
    #[test]
    fn truncate() {
        let mut reader = Reader::from_slice(&b"Hello universe!!!1!"[..]);
        assert_eq!(reader.take(5).unwrap(), &b"Hello"[..]);
        assert_eq!(reader.remaining(), 14);
        assert_eq!(reader.consumed(), 5);

        reader.truncate(9);
        assert_eq!(reader.remaining(), 9);
        assert_eq!(reader.consumed(), 5);
        assert_eq!(reader.take_u8().unwrap(), 0x20);
        assert_eq!(reader.into_rest(), &b"universe"[..]);
    }

    /// `should_be_exhausted` succeeds only once every byte is consumed.
    #[test]
    fn exhaust() {
        let empty = Reader::from_slice(&b""[..]);
        assert_eq!(empty.should_be_exhausted(), Ok(()));

        let mut reader = Reader::from_slice(&b"outis"[..]);
        assert_eq!(reader.should_be_exhausted(), Err(Error::ExtraneousBytes));
        reader.take(4).unwrap();
        assert_eq!(reader.should_be_exhausted(), Err(Error::ExtraneousBytes));
        reader.take(1).unwrap();
        assert_eq!(reader.should_be_exhausted(), Ok(()));
    }

    /// `take_rest` returns everything left, then an empty slice.
    #[test]
    fn take_rest() {
        let mut reader = Reader::from_slice(b"si vales valeo");
        assert_eq!(reader.take(3).unwrap(), b"si ");
        assert_eq!(reader.take_rest(), b"vales valeo");
        assert_eq!(reader.take_rest(), b"");
    }

    /// `take_until` splits on a terminator and fails when none is left.
    #[test]
    fn take_until() {
        let mut reader = Reader::from_slice(&b"si vales valeo"[..]);
        assert_eq!(reader.take_until(b' ').unwrap(), &b"si"[..]);
        assert_eq!(reader.take_until(b' ').unwrap(), &b"vales"[..]);
        // "valeo" contains no further space.
        assert_eq!(reader.take_until(b' '), Err(Error::Truncated));
    }

    /// Truncating to more than what remains changes nothing.
    #[test]
    fn truncate_badly() {
        let mut reader = Reader::from_slice(&b"abcdefg"[..]);
        reader.truncate(1000);
        assert_eq!(reader.total_len(), 7);
        assert_eq!(reader.remaining(), 7);
    }

    /// Nested reads with an empty u16-length body and a 4-byte u8-length body.
    #[test]
    fn nested_good() {
        let mut reader = Reader::from_slice(b"abc\0\0\x04defghijkl");
        assert_eq!(reader.take(3).unwrap(), b"abc");

        reader
            .read_nested_u16len(|inner| {
                assert!(inner.should_be_exhausted().is_ok());
                Ok(())
            })
            .unwrap();

        reader
            .read_nested_u8len(|inner| {
                assert_eq!(inner.take(4).unwrap(), b"defg");
                assert!(inner.should_be_exhausted().is_ok());
                Ok(())
            })
            .unwrap();

        assert_eq!(reader.take(2).unwrap(), b"hi");
    }

    /// Oversized or truncated length prefixes never reach the callback.
    #[test]
    fn nested_bad() {
        let mut reader = Reader::from_slice(b"................");
        let failure = read_nested_generic::<u128, _, ()>(&mut reader, |_| panic!())
            .err()
            .unwrap();
        assert_eq!(failure, Error::BadLengthValue);

        let mut reader = Reader::from_slice(b"................");
        let failure = reader
            .read_nested_u32len::<_, ()>(|_| panic!())
            .err()
            .unwrap();
        assert_eq!(failure, Error::Truncated);
    }

    /// `extract` / `extract_n` decode `Readable` items and roll back on failure.
    #[test]
    fn extract() {
        /// A byte string prefixed by its one-byte length.
        #[derive(Debug)]
        struct LenEnc(Vec<u8>);

        impl Readable for LenEnc {
            fn take_from(r: &mut Reader<'_>) -> Result<Self> {
                let len = r.take_u8()?;
                let body = r.take(len as usize)?.to_vec();
                Ok(LenEnc(body))
            }
        }

        let encoded = b"\x04this\x02is\x09sometimes\x01a\x06string!";

        let mut reader = Reader::from_slice(&encoded[..]);
        let first: LenEnc = reader.extract().unwrap();
        assert_eq!(&first.0[..], &b"this"[..]);

        let rest: Vec<LenEnc> = reader.extract_n(4).unwrap();
        assert_eq!(&rest[3].0[..], &b"string"[..]);
        assert_eq!(reader.remaining(), 1);

        // The trailing "!" claims a 0x21-byte body that is not there.
        let bad: Result<LenEnc> = reader.extract();
        assert_eq!(bad.unwrap_err(), Error::Truncated);
        assert_eq!(reader.remaining(), 1);

        // Asking for too many items fails and leaves the reader untouched.
        let mut reader = Reader::from_slice(&encoded[..]);
        assert_eq!(reader.remaining(), 28);
        let too_many: Result<Vec<LenEnc>> = reader.extract_n(10);
        assert_eq!(too_many.unwrap_err(), Error::Truncated);
        assert_eq!(reader.remaining(), 28);
    }
}