use std::convert::{TryFrom, TryInto};
use indexmap::IndexMap;
use llvm_bitstream::parser::StreamEntry;
use llvm_bitstream::record::{Block, Record};
use llvm_bitstream::Bitstream;
use llvm_support::bitcodes::IrBlockId;
use thiserror::Error;
use crate::block::{BlockId, BlockMapError, Identification, Module, Strtab, Symtab};
use crate::error::Error;
use crate::map::{PartialCtxMappable, PartialMapCtx};
use crate::record::{RecordBlobError, RecordStringError};
/// A record taken out of the bitstream, wrapped so that field-decoding helpers can hang off it.
#[derive(Clone, Debug)]
pub struct UnrolledRecord(Record);

impl UnrolledRecord {
    /// Returns this record's code.
    pub fn code(&self) -> u64 {
        self.0.code
    }

    /// Returns the fields from `idx` to the end, or `None` when `idx` is not an acceptable
    /// starting point.
    ///
    /// A start is acceptable only when at least two fields remain from it. That matches the
    /// long-standing `idx >= len - 1` rejection rule, but is expressed through a checked slice
    /// lookup so that a record with no fields is reported as a bad index instead of
    /// underflowing `len - 1`.
    fn tail_from(&self, idx: usize) -> Option<&[u64]> {
        self.0.fields.get(idx..).filter(|tail| tail.len() > 1)
    }

    /// Decodes the fields from `idx` onwards as a UTF-8 string, one byte per field.
    ///
    /// # Errors
    ///
    /// * `RecordStringError::BadIndex` if fewer than two fields remain at `idx`; this includes
    ///   every index into a record that has no fields at all.
    /// * An error converted from the failed `u8` conversion if any field does not fit in a byte.
    /// * An error converted from `FromUtf8Error` if the bytes are not valid UTF-8.
    pub fn try_string(&self, idx: usize) -> Result<String, RecordStringError> {
        let tail = self
            .tail_from(idx)
            .ok_or_else(|| RecordStringError::BadIndex(idx, self.0.fields.len()))?;

        let raw = tail
            .iter()
            .map(|field| u8::try_from(*field))
            .collect::<Result<Vec<_>, _>>()?;

        String::from_utf8(raw).map_err(RecordStringError::from)
    }

    /// Decodes the fields from `idx` onwards as raw bytes, one byte per field.
    ///
    /// # Errors
    ///
    /// * `RecordBlobError::BadIndex` if fewer than two fields remain at `idx`; this includes
    ///   every index into a record that has no fields at all.
    /// * An error converted from the failed `u8` conversion if any field does not fit in a byte.
    pub fn try_blob(&self, idx: usize) -> Result<Vec<u8>, RecordBlobError> {
        let tail = self
            .tail_from(idx)
            .ok_or_else(|| RecordBlobError::BadIndex(idx, self.0.fields.len()))?;

        Ok(tail
            .iter()
            .map(|field| u8::try_from(*field))
            .collect::<Result<Vec<_>, _>>()?)
    }

    /// Returns all of this record's fields, undecoded.
    pub fn fields(&self) -> &[u64] {
        &self.0.fields
    }
}
/// Errors reported when a lookup that expects a specific number of blocks or records
/// finds a different number.
#[derive(Debug, Error)]
pub enum ConsistencyError {
/// A block with the given ID was required, but none was found.
#[error("expected a block with {0:?} but not present")]
MissingBlock(BlockId),
/// At most one block with the given ID was allowed, but more than one was found.
#[error("expected exactly one block with {0:?} but got more than one")]
TooManyBlocks(BlockId),
/// A record with the given code was required, but none was found.
#[error("expected a record of code {0} but not present")]
MissingRecord(u64),
/// At most one record with the given code was allowed, but more than one was found.
#[error("expected exactly one record of code {0} but got more than one")]
TooManyRecords(u64),
}
/// The records of a single block, kept in the order they were pushed.
#[derive(Clone, Debug, Default)]
pub struct UnrolledRecords(Vec<UnrolledRecord>);

impl UnrolledRecords {
    /// Iterates over every record whose code equals `code`, in stored order.
    pub(crate) fn by_code<'a>(
        &'a self,
        code: impl Into<u64> + 'a,
    ) -> impl Iterator<Item = &UnrolledRecord> + 'a {
        let wanted: u64 = code.into();
        self.0.iter().filter(move |record| wanted == record.code())
    }

    /// Returns the first record whose code equals `code`, ignoring any later matches.
    pub(crate) fn one<'a>(&'a self, code: impl Into<u64> + 'a) -> Option<&UnrolledRecord> {
        let wanted: u64 = code.into();
        self.0.iter().find(|record| record.code() == wanted)
    }

    /// Returns the single record whose code equals `code`.
    ///
    /// Fails with `MissingRecord` when there is no such record and with `TooManyRecords`
    /// when there is more than one.
    pub(crate) fn exactly_one<'a>(
        &'a self,
        code: impl Into<u64> + 'a,
    ) -> Result<&UnrolledRecord, ConsistencyError> {
        let wanted: u64 = code.into();
        // The "too many" case is shared with `one_or_none`; only absence differs.
        self.one_or_none(wanted)?
            .ok_or(ConsistencyError::MissingRecord(wanted))
    }

    /// Returns the record whose code equals `code` if there is exactly one, or `None` if
    /// there is no such record.
    ///
    /// Fails with `TooManyRecords` when more than one record matches.
    pub(crate) fn one_or_none<'a>(
        &'a self,
        code: impl Into<u64> + 'a,
    ) -> Result<Option<&UnrolledRecord>, ConsistencyError> {
        let wanted: u64 = code.into();
        let mut matching = self.by_code(wanted);

        let first = matching.next();
        // Only look for a second match when a first one exists.
        if first.is_some() && matching.next().is_some() {
            return Err(ConsistencyError::TooManyRecords(wanted));
        }

        Ok(first)
    }
}
impl<'a> IntoIterator for &'a UnrolledRecords {
type Item = &'a UnrolledRecord;
type IntoIter = std::slice::Iter<'a, UnrolledRecord>;
fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
/// The sub-blocks of a single block, grouped by block ID.
#[derive(Clone, Debug, Default)]
pub struct UnrolledBlocks(IndexMap<BlockId, Vec<UnrolledBlock>>);

impl UnrolledBlocks {
    /// Returns the blocks stored under `id` as a slice; empty when the ID is absent.
    fn stored_under(&self, id: BlockId) -> &[UnrolledBlock] {
        self.0.get(&id).map(Vec::as_slice).unwrap_or_default()
    }

    /// Iterates over every block stored under `id`, in stored order.
    pub(crate) fn by_id(&self, id: BlockId) -> impl Iterator<Item = &UnrolledBlock> + '_ {
        self.stored_under(id).iter()
    }

    /// Returns the single block stored under `id`.
    ///
    /// Fails with `MissingBlock` when there is no such block and with `TooManyBlocks`
    /// when there is more than one.
    pub(crate) fn exactly_one(&self, id: BlockId) -> Result<&UnrolledBlock, ConsistencyError> {
        // The "too many" case is shared with `one_or_none`; only absence differs.
        self.one_or_none(id)?
            .ok_or(ConsistencyError::MissingBlock(id))
    }

    /// Returns the block stored under `id` if there is exactly one, or `None` if there is
    /// no such block.
    ///
    /// Fails with `TooManyBlocks` when more than one block is stored under `id`.
    pub(crate) fn one_or_none(
        &self,
        id: BlockId,
    ) -> Result<Option<&UnrolledBlock>, ConsistencyError> {
        match self.stored_under(id) {
            [] => Ok(None),
            [only] => Ok(Some(only)),
            _ => Err(ConsistencyError::TooManyBlocks(id)),
        }
    }
}
/// A block together with everything found inside it: its own records and its sub-blocks.
#[derive(Clone, Debug)]
pub struct UnrolledBlock {
    /// The ID of this block.
    pub id: BlockId,
    /// The records found directly inside this block.
    records: UnrolledRecords,
    /// The sub-blocks found directly inside this block.
    blocks: UnrolledBlocks,
}

impl UnrolledBlock {
    /// Creates a block with no records and no sub-blocks, converting the raw `id` into a
    /// [`BlockId`].
    pub(self) fn new(id: u64) -> Self {
        let id: BlockId = id.into();
        let records = UnrolledRecords::default();
        let blocks = UnrolledBlocks::default();

        Self { id, records, blocks }
    }

    /// Returns the records found directly inside this block.
    pub fn records(&self) -> &UnrolledRecords {
        &self.records
    }

    /// Returns the sub-blocks found directly inside this block.
    pub fn blocks(&self) -> &UnrolledBlocks {
        &self.blocks
    }
}
/// The result of unrolling a bitstream: every bitcode module found in it.
///
/// Built through the `TryFrom` implementations for raw bytes and for an already-parsed
/// `Bitstream`.
#[derive(Debug)]
pub struct UnrolledBitcode {
/// The modules recovered from the stream, in the order their IDENTIFICATION blocks appeared.
pub modules: Vec<BitcodeModule>,
}
/// Parses `buf` as a bitstream and unrolls it in one step.
impl TryFrom<&[u8]> for UnrolledBitcode {
    type Error = Error;

    /// Builds a `Bitstream` from `buf` and hands it to the `Bitstream`-based conversion.
    ///
    /// The first element of the pair returned by `Bitstream::from` is not used here.
    fn try_from(buf: &[u8]) -> Result<UnrolledBitcode, Self::Error> {
        let (_, stream) = Bitstream::from(buf)?;
        let unrolled: UnrolledBitcode = stream.try_into()?;
        Ok(unrolled)
    }
}
/// Unrolls an already-parsed [`Bitstream`] into the bitcode modules it contains.
impl<T: AsRef<[u8]>> TryFrom<Bitstream<T>> for UnrolledBitcode {
    type Error = Error;

    /// Walks every top-level block of `bitstream` and groups the blocks into modules.
    ///
    /// Grouping rules, as applied to each top-level block in stream order:
    ///
    /// * an IDENTIFICATION block starts a new partial module;
    /// * a MODULE block is attached to the most recently started partial module;
    /// * a STRTAB or SYMTAB block is cloned into every partial module, walking from the newest
    ///   backwards, until one that already has such a table is reached. If no partial module
    ///   is waiting for one, the table is dropped without error.
    ///
    /// # Errors
    ///
    /// Returns `Error::Unroll` when a top-level entry is not a block, when the stream ends in
    /// the middle of a block, when a MODULE block has no preceding IDENTIFICATION block, when
    /// two MODULE blocks follow the same IDENTIFICATION block, or when any other kind of block
    /// appears at the top level. Errors from the underlying stream and from reifying each
    /// module are propagated.
    fn try_from(mut bitstream: Bitstream<T>) -> Result<UnrolledBitcode, Self::Error> {
        /// Consumes entries up to and including the end-of-block marker that closes `block`,
        /// collecting its records and recursively unrolling its sub-blocks.
        fn enter_block<T: AsRef<[u8]>>(
            bitstream: &mut Bitstream<T>,
            block: Block,
        ) -> Result<UnrolledBlock, Error> {
            let mut unrolled_block = UnrolledBlock::new(block.block_id);

            loop {
                // Running dry before the block is closed means the stream is truncated.
                let entry = bitstream
                    .next()
                    .ok_or_else(|| Error::Unroll("unexpected stream end during unroll".into()))?;

                match entry? {
                    StreamEntry::Record(record) => {
                        unrolled_block.records.0.push(UnrolledRecord(record))
                    }
                    StreamEntry::SubBlock(block) => {
                        let unrolled_child = enter_block(bitstream, block)?;
                        unrolled_block
                            .blocks
                            .0
                            .entry(unrolled_child.id)
                            .or_insert_with(Vec::new)
                            .push(unrolled_child);
                    }
                    StreamEntry::EndBlock => {
                        break;
                    }
                }
            }

            Ok(unrolled_block)
        }

        let mut partial_modules = Vec::new();

        // `while let` rather than `for`: the loop body needs `&mut bitstream` to descend
        // into each block, so the stream cannot stay borrowed by an iterator adaptor.
        while let Some(entry) = bitstream.next() {
            let block = entry?.as_block().ok_or_else(|| {
                Error::Unroll("bitstream has non-blocks at the top-level scope".into())
            })?;
            let top_block = enter_block(&mut bitstream, block)?;

            match top_block.id {
                BlockId::Ir(IrBlockId::Identification) => {
                    partial_modules.push(PartialBitcodeModule::new(top_block));
                }
                BlockId::Ir(IrBlockId::Module) => {
                    let last_partial = partial_modules.last_mut().ok_or_else(|| {
                        Error::Unroll("malformed bitstream: MODULE_BLOCK with no preceding IDENTIFICATION_BLOCK".into())
                    })?;

                    if last_partial.module.is_some() {
                        return Err(Error::Unroll(
                            "malformed bitstream: adjacent MODULE_BLOCKs".into(),
                        ));
                    }
                    last_partial.module = Some(top_block);
                }
                BlockId::Ir(IrBlockId::Strtab) => {
                    // Newest first; stop at the first module that already has a table.
                    for prev_partial in partial_modules
                        .iter_mut()
                        .rev()
                        .take_while(|p| p.strtab.is_none())
                    {
                        prev_partial.strtab = Some(top_block.clone());
                    }
                }
                BlockId::Ir(IrBlockId::Symtab) => {
                    // Same back-filling rule as for STRTAB above.
                    for prev_partial in partial_modules
                        .iter_mut()
                        .rev()
                        .take_while(|p| p.symtab.is_none())
                    {
                        prev_partial.symtab = Some(top_block.clone());
                    }
                }
                _ => {
                    return Err(Error::Unroll(format!(
                        "unexpected top-level block: {:?}",
                        top_block.id
                    )))
                }
            }
        }

        let modules = partial_modules
            .into_iter()
            .map(|p| p.reify())
            .collect::<Result<Vec<_>, _>>()?;

        Ok(UnrolledBitcode { modules })
    }
}
#[derive(Debug)]
struct PartialBitcodeModule {
identification: UnrolledBlock,
module: Option<UnrolledBlock>,
strtab: Option<UnrolledBlock>,
symtab: Option<UnrolledBlock>,
}
impl PartialBitcodeModule {
pub(self) fn new(identification: UnrolledBlock) -> Self {
Self {
identification: identification,
module: None,
strtab: None,
symtab: None,
}
}
pub(self) fn reify(self) -> Result<BitcodeModule, Error> {
let mut ctx = PartialMapCtx::default();
let strtab = Strtab::try_map(
&self
.strtab
.ok_or_else(|| Error::Unroll("missing STRTAB_BLOCK for bitcode module".into()))?,
&mut ctx,
)
.map_err(BlockMapError::Strtab)?;
ctx.strtab = strtab;
let identification = Identification::try_map(&self.identification, &mut ctx)
.map_err(BlockMapError::Identification)?;
let module = Module::try_map(
&self
.module
.ok_or_else(|| Error::Unroll("missing MODULE_BLOCK for bitcode module".into()))?,
&mut ctx,
)
.map_err(BlockMapError::Module)?;
let symtab = self
.symtab
.map(|s| Symtab::try_map(&s, &mut ctx))
.transpose()
.map_err(BlockMapError::Symtab)?;
#[allow(clippy::unwrap_used)]
Ok(BitcodeModule {
identification: identification,
module: module,
strtab: ctx.strtab,
symtab: symtab,
})
}
}
/// A bitcode module whose top-level blocks have all been mapped.
#[derive(Debug)]
pub struct BitcodeModule {
/// The mapped `Identification` block.
pub identification: Identification,
/// The mapped `Module` block.
pub module: Module,
/// The mapped `Strtab` block.
pub strtab: Strtab,
/// The mapped `Symtab` block, if the stream supplied one for this module.
pub symtab: Option<Symtab>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Field bytes shared by both tests: two bytes that are not valid UTF-8, then ASCII text.
    const RAW_FIELDS: &[u8] = b"\xff\xffvalid string!";

    /// Builds a record with code 0 whose fields are `RAW_FIELDS`, one byte per field.
    fn sample_record() -> UnrolledRecord {
        UnrolledRecord(Record {
            abbrev_id: None,
            code: 0,
            fields: RAW_FIELDS.iter().map(|&byte| u64::from(byte)).collect(),
        })
    }

    #[test]
    fn test_unrolled_record_try_string() {
        let record = sample_record();
        let field_count = record.0.fields.len();

        assert_eq!(record.try_string(2).unwrap(), "valid string!");
        assert_eq!(record.try_string(8).unwrap(), "string!");

        // Starting at 0 includes the two leading bytes that are not valid UTF-8.
        assert!(record.try_string(0).is_err());

        // Starting at or just before the end is rejected as a bad index.
        assert!(record.try_string(field_count).is_err());
        assert!(record.try_string(field_count - 1).is_err());
    }

    #[test]
    fn test_unrolled_record_try_blob() {
        let record = sample_record();
        let field_count = record.0.fields.len();

        assert_eq!(record.try_blob(0).unwrap(), b"\xff\xffvalid string!");
        assert_eq!(record.try_blob(8).unwrap(), b"string!");

        // Starting at or just before the end is rejected as a bad index.
        assert!(record.try_blob(field_count).is_err());
        assert!(record.try_blob(field_count - 1).is_err());
    }
}