Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework SymbolTableCache::str2sym member #938

Merged
merged 1 commit into from
Dec 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 66 additions & 31 deletions src/elf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::ops::ControlFlow;
use std::ops::Deref as _;
use std::path::Path;
use std::path::PathBuf;
use std::str;

use crate::inspect::FindAddrOpts;
use crate::inspect::ForEachFn;
Expand Down Expand Up @@ -181,14 +182,36 @@ impl EhdrExt<'_> {
}


#[derive(Debug)]
struct SymName {
/// The index of the first byte of the name.
idx: usize,
/// The length of the name.
len: usize,
}

impl SymName {
fn bytes<'strs>(&self, strs: &'strs [u8]) -> &'strs [u8] {
&strs[self.idx..self.idx + self.len]
}

fn name<'strs>(&self, strs: &'strs [u8]) -> Result<&'strs str> {
let bytes = self.bytes(strs);
str::from_utf8(bytes)
.ok()
.ok_or_invalid_data(|| "symbol name `{bytes:?}` is invalid")
}
}


#[derive(Debug)]
struct SymbolTableCache<'mmap> {
/// The cached symbols (in address order).
syms: ElfN_BoxedSyms<'mmap>,
/// The string table.
strs: &'mmap [u8],
/// The cached name to symbol index table (in dictionary order).
str2sym: OnceCell<Box<[(&'mmap str, usize)]>>,
str2sym: OnceCell<Box<[(SymName, usize)]>>,
}

impl<'mmap> SymbolTableCache<'mmap> {
Expand All @@ -200,7 +223,7 @@ impl<'mmap> SymbolTableCache<'mmap> {
}
}

fn create_str2sym<F>(&self, mut filter: F) -> Result<Box<[(&'mmap str, usize)]>>
fn create_str2sym<F>(&self, mut filter: F) -> Result<Box<[(SymName, usize)]>>
where
F: FnMut(&ElfN_Sym<'_>) -> bool,
{
Expand All @@ -210,24 +233,28 @@ impl<'mmap> SymbolTableCache<'mmap> {
.filter(|sym| filter(sym))
.enumerate()
.map(|(i, sym)| {
let name = self
let idx = sym.name() as usize;
let cname = self
.strs
.get(sym.name() as usize..)
.get(idx..)
.ok_or_invalid_input(|| "ELF string table index out of bounds")?
.read_cstr()
.ok_or_invalid_input(|| "no valid string found in ELF string table")?
.to_str()
.map_err(Error::with_invalid_data)
.context("invalid ELF symbol name")?;
.ok_or_invalid_input(|| "no valid string found in ELF string table")?;
let name = SymName {
idx,
// TODO: May want to use `CStr::count_bytes` once
// our MSRV is >=1.79.
len: cname.to_bytes().len(),
};
Ok((name, i))
})
.collect::<Result<Box<[_]>>>()?;

let () = str2sym.sort_by_key(|&(name, _i)| name);
let () = str2sym.sort_by_key(|(name, _i)| name.bytes(self.strs));
Ok(str2sym)
}

fn ensure_str2sym<F>(&self, filter: F) -> Result<&[(&'mmap str, usize)]>
fn ensure_str2sym<F>(&self, filter: F) -> Result<&[(SymName, usize)]>
where
F: FnMut(&ElfN_Sym<'_>) -> bool,
{
Expand Down Expand Up @@ -685,16 +712,12 @@ impl<'mmap> Cache<'mmap> {
})
}

#[cfg(test)]
fn ensure_symtab(&self) -> Result<&ElfN_BoxedSyms<'_>> {
let symtab = self.ensure_symtab_cache()?;
Ok(&symtab.syms)
}

fn ensure_dynsym(&self) -> Result<&ElfN_BoxedSyms<'_>> {
let dynsym = self.ensure_dynsym_cache()?;
Ok(&dynsym.syms)
}

fn parse_strs(&self, section: &str) -> Result<&'mmap [u8]> {
let strs = if let Some(idx) = self.find_section(section)? {
self.section_data_raw(idx)?
Expand All @@ -704,13 +727,13 @@ impl<'mmap> Cache<'mmap> {
Ok(strs)
}

fn ensure_str2symtab(&self) -> Result<&[(&'mmap str, usize)]> {
fn ensure_str2symtab(&self) -> Result<&[(SymName, usize)]> {
let symtab = self.ensure_symtab_cache()?;
let str2sym = symtab.ensure_str2sym(|_sym| true)?;
Ok(str2sym)
}

fn ensure_str2dynsym(&self) -> Result<&[(&'mmap str, usize)]> {
fn ensure_str2dynsym(&self) -> Result<&[(SymName, usize)]> {
let symtab = self.ensure_symtab_cache()?;
let dynsym = self.ensure_dynsym_cache()?;
let str2sym = dynsym.ensure_str2sym(|sym| {
Expand Down Expand Up @@ -869,14 +892,17 @@ impl ElfParser {
opts: &FindAddrOpts,
shdrs: &ElfN_Shdrs<'_>,
syms: &ElfN_BoxedSyms,
str2sym: &'slf [(&'slf str, usize)],
strs: &'slf [u8],
str2sym: &'slf [(SymName, usize)],
) -> Result<Vec<SymInfo<'slf>>> {
let r = find_match_or_lower_bound_by_key(str2sym, name, |&(name, _i)| name);
let r = find_match_or_lower_bound_by_key(str2sym, name.as_bytes(), |(name, _i)| {
name.bytes(strs)
});
match r {
Some(idx) => {
let mut found = vec![];
for (name_visit, sym_i) in str2sym.iter().skip(idx) {
if *name_visit != name {
if name_visit.bytes(strs) != name.as_bytes() {
break
}
let sym_ref = &syms.get(*sym_i).ok_or_invalid_input(|| {
Expand All @@ -885,7 +911,7 @@ impl ElfParser {
let sym = sym_ref.to_64bit();
if sym.st_shndx != SHN_UNDEF {
found.push(SymInfo {
name: Cow::Borrowed(name_visit),
name: Cow::Borrowed(name_visit.name(strs)?),
addr: sym.st_value as Addr,
size: sym.st_size as usize,
// SANITY: We filter out all unsupported symbol
Expand Down Expand Up @@ -913,24 +939,29 @@ impl ElfParser {
opts: &FindAddrOpts,
) -> Result<Vec<SymInfo<'slf>>> {
let shdrs = self.cache.ensure_shdrs()?;
let symtab = self.cache.ensure_symtab()?;
let cache = self.cache.ensure_symtab_cache()?;
let symtab = &cache.syms;
let strs = &cache.strs;
let str2symtab = self.cache.ensure_str2symtab()?;
let syms = self.find_addr_impl(name, opts, shdrs, symtab, str2symtab)?;
let syms = self.find_addr_impl(name, opts, shdrs, symtab, strs, str2symtab)?;
if !syms.is_empty() {
return Ok(syms)
}

let dynsym = self.cache.ensure_dynsym()?;
let cache = self.cache.ensure_dynsym_cache()?;
let dynsym = &cache.syms;
let strs = &cache.strs;
let str2dynsym = self.cache.ensure_str2dynsym()?;
let syms = self.find_addr_impl(name, opts, shdrs, dynsym, str2dynsym)?;
let syms = self.find_addr_impl(name, opts, shdrs, dynsym, strs, str2dynsym)?;
Ok(syms)
}

fn for_each_sym_impl(
&self,
opts: &FindAddrOpts,
syms: &ElfN_BoxedSyms<'_>,
str2sym: &[(&str, usize)],
strs: &[u8],
str2sym: &[(SymName, usize)],
f: &mut ForEachFn<'_>,
) -> Result<()> {
let shdrs = self.cache.ensure_shdrs()?;
Expand All @@ -943,7 +974,7 @@ impl ElfParser {

if sym.matches(opts.sym_type) && sym.st_shndx != SHN_UNDEF {
let sym_info = SymInfo {
name: Cow::Borrowed(name),
name: Cow::Borrowed(name.name(strs)?),
addr: sym.st_value as Addr,
size: sym.st_size as usize,
// SANITY: We filter out all unsupported symbol
Expand All @@ -969,13 +1000,17 @@ impl ElfParser {
/// Perform an operation on each symbol.
#[allow(clippy::needless_borrows_for_generic_args)]
pub(crate) fn for_each(&self, opts: &FindAddrOpts, f: &mut ForEachFn) -> Result<()> {
let symtab = self.cache.ensure_symtab()?;
let cache = self.cache.ensure_symtab_cache()?;
let symtab = &cache.syms;
let strs = &cache.strs;
let str2symtab = self.cache.ensure_str2symtab()?;
let () = self.for_each_sym_impl(opts, symtab, str2symtab, f)?;
let () = self.for_each_sym_impl(opts, symtab, strs, str2symtab, f)?;

let dynsym = self.cache.ensure_dynsym()?;
let cache = self.cache.ensure_dynsym_cache()?;
let dynsym = &cache.syms;
let strs = &cache.strs;
let str2dynsym = self.cache.ensure_str2dynsym()?;
let () = self.for_each_sym_impl(opts, dynsym, str2dynsym, f)?;
let () = self.for_each_sym_impl(opts, dynsym, strs, str2dynsym, f)?;

Ok(())
}
Expand Down
Loading