diff --git a/src/fs.rs b/src/fs.rs index da79501..402719c 100644 --- a/src/fs.rs +++ b/src/fs.rs @@ -2,8 +2,9 @@ use std::{ cell::RefCell, cmp::max, collections::{BTreeMap, HashMap}, - ffi::OsString, - ffi::{CStr, OsStr}, + ffi::{CStr, OsStr, OsString}, + fs::File, + io::Write, mem::MaybeUninit, os::unix::ffi::{OsStrExt, OsStringExt}, path::Path, @@ -14,10 +15,10 @@ use anyhow::{ensure, Result}; use rustix::{ fd::{AsFd, OwnedFd}, fs::{ - fdatasync, fstat, getxattr, linkat, listxattr, mkdirat, mknodat, openat, readlinkat, - symlinkat, AtFlags, Dir, FileType, Mode, OFlags, CWD, + fstat, getxattr, linkat, listxattr, mkdirat, mknodat, openat, readlinkat, symlinkat, + AtFlags, Dir, FileType, Mode, OFlags, CWD, }, - io::{read_uninit, write, Errno}, + io::{read_uninit, Errno}, }; use zerocopy::IntoBytes; @@ -30,16 +31,19 @@ use crate::{ INLINE_CONTENT_MAX, }; +/// Attempt to use O_TMPFILE + rename to atomically set file contents. +/// Will fall back to a non-atomic write if the target doesn't support O_TMPFILE. fn set_file_contents(dirfd: &OwnedFd, name: &OsStr, stat: &Stat, data: &[u8]) -> Result<()> { match openat( dirfd, ".", - OFlags::WRONLY | OFlags::TMPFILE, + OFlags::WRONLY | OFlags::TMPFILE | OFlags::CLOEXEC, stat.st_mode.into(), ) { Ok(tmp) => { - write(&tmp, data)?; // TODO: make this better - fdatasync(&tmp)?; + let mut tmp = File::from(tmp); + tmp.write_all(data)?; + tmp.sync_data()?; linkat( CWD, proc_self_fd(&tmp), @@ -53,11 +57,12 @@ fn set_file_contents(dirfd: &OwnedFd, name: &OsStr, stat: &Stat, data: &[u8]) -> let fd = openat( dirfd, name, - OFlags::CREATE | OFlags::WRONLY, + OFlags::CREATE | OFlags::WRONLY | OFlags::CLOEXEC, stat.st_mode.into(), )?; - write(&fd, data)?; - fdatasync(&fd)?; + let mut f = File::from(fd); + f.write_all(data)?; + f.sync_data()?; } Err(e) => Err(e)?, } @@ -320,3 +325,32 @@ pub fn create_dumpfile(path: &Path) -> Result<()> { let fs = read_from_path(path, None)?; super::dumpfile::write_dumpfile(&mut std::io::stdout(), &fs) } + +#[cfg(test)] +mod tests { + use super::*; + use rustix::fs::{openat, CWD}; + + #[test] + fn test_write_contents() -> Result<()> { + let td = tempfile::tempdir()?; + let testpath = &td.path().join("testfile"); + let td = openat( + CWD, + td.path(), + OFlags::RDONLY | OFlags::DIRECTORY | OFlags::CLOEXEC, + Mode::from_raw_mode(0), + )?; + let st = Stat { + st_mode: 0o755, + st_uid: 0, + st_gid: 0, + st_mtim_sec: Default::default(), + xattrs: Default::default(), + }; + set_file_contents(&td, OsStr::new("testfile"), &st, b"new contents").unwrap(); + drop(td); + assert_eq!(std::fs::read(testpath)?, b"new contents"); + Ok(()) + } +}