Skip to content

Commit

Permalink
feat: Revamp file reading and writing (#4906)
Browse files Browse the repository at this point in the history
This PR:
- changes the implementation of `readBinFile` and `readFile` to only
require two system calls (`stat` + `read`) instead of one `read` per
1024 byte chunk.
- fixes a bug where `Handle.getLine` would get tripped up by a NUL
character in the line and cut the string off. This happened because the
original implementation used `strlen`, and `lean_mk_string` (which backs
`mk_string`) does so as well.
- fixes a bug where `Handle.putStr`, and thus by extension `writeFile`,
would get tripped up by a NUL character in the string and cut it off.
The cause is the use of `fputs`, which stops at the first NUL character.

Closes: #4891 
Closes: #3546
Closes: #3741
  • Loading branch information
hargoniX authored Aug 7, 2024
1 parent 574066b commit 473b345
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 46 deletions.
48 changes: 31 additions & 17 deletions src/Init/System/IO.lean
Original file line number Diff line number Diff line change
Expand Up @@ -470,31 +470,23 @@ def withFile (fn : FilePath) (mode : Mode) (f : Handle → IO α) : IO α :=
-- Writes `s` to `h` with a single '\n' appended, using `Handle.putStr`.
def Handle.putStrLn (h : Handle) (s : String) : IO Unit :=
  h.putStr <| s.push '\n'

partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do
partial def Handle.readBinToEndInto (h : Handle) (buf : ByteArray) : IO ByteArray := do
let rec loop (acc : ByteArray) : IO ByteArray := do
let buf ← h.read 1024
if buf.isEmpty then
return acc
else
loop (acc ++ buf)
loop ByteArray.empty
loop buf

partial def Handle.readToEnd (h : Handle) : IO String := do
let rec loop (s : String) := do
let line ← h.getLine
if line.isEmpty then
return s
else
loop (s ++ line)
loop ""

def readBinFile (fname : FilePath) : IO ByteArray := do
let h ← Handle.mk fname Mode.read
h.readBinToEnd
/--
Reads bytes from `h` until `Handle.readBinToEndInto` stops, starting from an
empty buffer, and returns everything that was read.
-/
def Handle.readBinToEnd (h : Handle) : IO ByteArray :=
  -- Not recursive itself, so no `partial` is needed; the loop lives in `readBinToEndInto`.
  h.readBinToEndInto .empty

def readFile (fname : FilePath) : IO String := do
let h ← Handle.mk fname Mode.read
h.readToEnd
/--
Reads everything `h.readBinToEnd` returns and decodes it as UTF-8.

Throws a `userError` when the bytes are not valid UTF-8.
-/
def Handle.readToEnd (h : Handle) : IO String := do
  let data ← h.readBinToEnd
  match String.fromUTF8? data with
  | some s => return s
  -- Plain literal: the message has nothing to interpolate, so `s!` is unnecessary.
  | none => throw <| .userError "Tried to read from handle containing non UTF-8 data."

partial def lines (fname : FilePath) : IO (Array String) := do
let h ← Handle.mk fname Mode.read
Expand Down Expand Up @@ -600,6 +592,28 @@ end System.FilePath

namespace IO

namespace FS

/--
Reads the contents of the file `fname` as a `ByteArray`.

The byte size reported by the file's metadata is used to request the contents
with a single `read`. `Handle.readBinToEndInto` is then called on the result, so
any data not covered by that first read (including the case where the reported
size is 0) is still appended.
-/
def readBinFile (fname : FilePath) : IO ByteArray := do
  -- Requires metadata so defined after metadata
  let mdata ← fname.metadata
  let size := mdata.byteSize.toUSize
  let handle ← IO.FS.Handle.mk fname .read
  let buf ←
    if size > 0 then
      -- Reuse `size` rather than recomputing the conversion.
      handle.read size
    else
      pure <| ByteArray.mkEmpty 0
  handle.readBinToEndInto buf

/--
Reads the file `fname` with `readBinFile` and decodes its bytes as UTF-8.

Throws a `userError` mentioning `fname` when the contents are not valid UTF-8.
-/
def readFile (fname : FilePath) : IO String := do
  let bytes ← readBinFile fname
  let some text := String.fromUTF8? bytes
    | throw <| .userError s!"Tried to read file '{fname}' containing non UTF-8 data."
  return text

end FS

/--
Runs `x` after calling `setStdin h`. Whether `x` succeeds or fails, the value
that `setStdin h` returned (presumably the previously installed stream) is
passed back to `setStdin` afterwards.
-/
def withStdin [Monad m] [MonadFinally m] [MonadLiftT BaseIO m] (h : FS.Stream) (x : m α) : m α := do
  let saved ← setStdin h
  try
    x
  finally
    discard <| setStdin saved
Expand Down
45 changes: 16 additions & 29 deletions src/runtime/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,43 +485,30 @@ extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_write(b_obj_arg h, b_obj_arg
}
}

/*
Handle.getLine : (@& Handle) → IO Unit
The line returned by `lean_io_prim_handle_get_line`
is truncated at the first '\0' character and the
rest of the line is discarded. */
/* Handle.getLine : (@& Handle) → IO String
   Reads one line from `h` with `getline`. The string is built from the byte
   count that `getline` reports rather than from `strlen`, so NUL bytes inside
   the line are preserved. Returns the empty string at end of file and an IO
   error (decoded from `errno`) on any other failure. */
extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_get_line(b_obj_arg h, obj_arg /* w */) {
    FILE * fp = io_get_handle(h);
    char * buf = nullptr;
    size_t n = 0;
    ssize_t read = getline(&buf, &n, fp);
    if (read != -1) {
        obj_res ret = io_result_mk_ok(mk_string_from_bytes(buf, read));
        free(buf);
        return ret;
    }
    // Capture errno before any cleanup call has a chance to overwrite it.
    int saved_errno = errno;
    // `getline` may have allocated `buf` even though it failed, so it must be
    // released on the EOF and error paths as well (free(nullptr) is a no-op).
    free(buf);
    if (std::feof(fp)) {
        clearerr(fp);
        return io_result_mk_ok(mk_string(""));
    } else {
        return io_result_mk_error(decode_io_error(saved_errno, nullptr));
    }
}

/* Handle.putStr : (@& Handle) → (@& String) → IO Unit */
extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_put_str(b_obj_arg h, b_obj_arg s, obj_arg /* w */) {
FILE * fp = io_get_handle(h);
if (std::fputs(lean_string_cstr(s), fp) != EOF) {
usize n = lean_string_size(s) - 1; // - 1 to ignore the terminal NULL byte.
usize m = std::fwrite(lean_string_cstr(s), 1, n, fp);
if (m == n) {
return io_result_mk_ok(box(0));
} else {
return io_result_mk_error(decode_io_error(errno, nullptr));
Expand Down
1 change: 1 addition & 0 deletions src/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ inline size_t string_capacity(object * o) { return lean_string_capacity(o); }
inline uint32 char_default_value() { return lean_char_default_value(); }
inline obj_res alloc_string(size_t size, size_t capacity, size_t len) { return lean_alloc_string(size, capacity, len); }
// Forwards to `lean_mk_string`. NOTE(review): the accompanying commit message states that
// `lean_mk_string` relies on `strlen`, so input is cut at the first NUL byte — confirm.
inline obj_res mk_string(char const * s) { return lean_mk_string(s); }
// Forwards to `lean_mk_string_from_bytes`, passing the byte count `sz` explicitly
// instead of relying on a terminating NUL.
inline obj_res mk_string_from_bytes(char const * s, size_t sz) { return lean_mk_string_from_bytes(s, sz); }
LEAN_EXPORT obj_res mk_ascii_string_unchecked(std::string const & s);
LEAN_EXPORT obj_res mk_string(std::string const & s);
LEAN_EXPORT std::string string_to_std(b_obj_arg o);
Expand Down
14 changes: 14 additions & 0 deletions tests/lean/run/3546.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/--
Regression test: a line containing a NUL character must round-trip through
`IO.FS.writeFile` and `Handle.getLine` without being cut off.
Prints `true` when the first line read back equals the line written.
-/
def test : IO Unit := do
  let tmpFile := "3546.tmp"
  let firstLine := "foo\u0000bar\n"
  let content := firstLine ++ "hello world\nbye"
  IO.FS.writeFile tmpFile content
  try
    let handle ← IO.FS.Handle.mk tmpFile .read
    let firstReadLine ← handle.getLine
    let cond := firstLine == firstReadLine && firstReadLine.length == 8 -- paranoid
    IO.println cond
  finally
    -- Remove the temporary file even when opening or reading it throws.
    IO.FS.removeFile tmpFile

/-- info: true -/
#guard_msgs in
#eval test

0 comments on commit 473b345

Please sign in to comment.