Skip to content

Commit

Permalink
feat: synchronous execution of task continuations
Browse files Browse the repository at this point in the history
  • Loading branch information
Kha committed Dec 2, 2023
1 parent 92f1755 commit 8348d0f
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 36 deletions.
7 changes: 5 additions & 2 deletions src/Init/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,10 @@ set_option linter.unusedVariables.funArgs false in
be available and then calls `f` on the result.
`prio`, if provided, is the priority of the task.
If `sync` is set to true, `f` is executed on the current thread if `x` has already finished.
-/
@[noinline, extern "lean_task_map"]
protected def map {α : Type u} {β : Type v} (f : α → β) (x : Task α) (prio := Priority.default) : Task β :=
protected def map (f : α → β) (x : Task α) (prio := Priority.default) (sync := false) : Task β :=
⟨f x.get⟩

set_option linter.unusedVariables.funArgs false in
Expand All @@ -424,9 +425,11 @@ for the value of `x` to be available and then calls `f` on the result,
resulting in a new task which is then run for a result.
`prio`, if provided, is the priority of the task.
If `sync` is set to true, `f` is executed on the current thread if `x` has already finished.
-/
@[noinline, extern "lean_task_bind"]
protected def bind {α : Type u} {β : Type v} (x : Task α) (f : α → Task β) (prio := Priority.default) : Task β :=
protected def bind (x : Task α) (f : α → Task β) (prio := Priority.default) (sync := false) :
Task β :=
⟨(f x.get).get⟩

end Task
Expand Down
42 changes: 26 additions & 16 deletions src/Init/System/IO.lean
Original file line number Diff line number Diff line change
Expand Up @@ -117,20 +117,23 @@ opaque asTask (act : BaseIO α) (prio := Task.Priority.default) : BaseIO (Task

/-- See `BaseIO.asTask`. -/
@[extern "lean_io_map_task"]
opaque mapTask (f : α → BaseIO β) (t : Task α) (prio := Task.Priority.default) : BaseIO (Task β) :=
opaque mapTask (f : α → BaseIO β) (t : Task α) (prio := Task.Priority.default) (sync := false) :
BaseIO (Task β) :=
Task.pure <$> f t.get

/-- See `BaseIO.asTask`. -/
@[extern "lean_io_bind_task"]
opaque bindTask (t : Task α) (f : α → BaseIO (Task β)) (prio := Task.Priority.default) : BaseIO (Task β) :=
opaque bindTask (t : Task α) (f : α → BaseIO (Task β)) (prio := Task.Priority.default)
(sync := false) : BaseIO (Task β) :=
f t.get

def mapTasks (f : List α → BaseIO β) (tasks : List (Task α)) (prio := Task.Priority.default) : BaseIO (Task β) :=
def mapTasks (f : List α → BaseIO β) (tasks : List (Task α)) (prio := Task.Priority.default)
(sync := false) : BaseIO (Task β) :=
go tasks []
where
go
| t::ts, as =>
BaseIO.bindTask t (fun a => go ts (a :: as)) prio
BaseIO.bindTask t (fun a => go ts (a :: as)) prio sync
| [], as => f as.reverse |>.asTask prio

end BaseIO
Expand All @@ -142,16 +145,20 @@ namespace EIO
act.toBaseIO.asTask prio

/-- `EIO` specialization of `BaseIO.mapTask`. -/
@[inline] def mapTask (f : α → EIO ε β) (t : Task α) (prio := Task.Priority.default) : BaseIO (Task (Except ε β)) :=
BaseIO.mapTask (fun a => f a |>.toBaseIO) t prio
@[inline] def mapTask (f : α → EIO ε β) (t : Task α) (prio := Task.Priority.default)
(sync := false) : BaseIO (Task (Except ε β)) :=
BaseIO.mapTask (fun a => f a |>.toBaseIO) t prio sync

/-- `EIO` specialization of `BaseIO.bindTask`. -/
@[inline] def bindTask (t : Task α) (f : α → EIO ε (Task (Except ε β))) (prio := Task.Priority.default) : BaseIO (Task (Except ε β)) :=
BaseIO.bindTask t (fun a => f a |>.catchExceptions fun e => return Task.pure <| Except.error e) prio
@[inline] def bindTask (t : Task α) (f : α → EIO ε (Task (Except ε β)))
(prio := Task.Priority.default) (sync := false) : BaseIO (Task (Except ε β)) :=
BaseIO.bindTask t (fun a => f a |>.catchExceptions fun e => return Task.pure <| Except.error e)
prio sync

/-- `EIO` specialization of `BaseIO.mapTasks`. -/
@[inline] def mapTasks (f : List α → EIO ε β) (tasks : List (Task α)) (prio := Task.Priority.default) : BaseIO (Task (Except ε β)) :=
BaseIO.mapTasks (fun as => f as |>.toBaseIO) tasks prio
@[inline] def mapTasks (f : List α → EIO ε β) (tasks : List (Task α))
(prio := Task.Priority.default) (sync := false) : BaseIO (Task (Except ε β)) :=
BaseIO.mapTasks (fun as => f as |>.toBaseIO) tasks prio sync

end EIO

Expand Down Expand Up @@ -184,16 +191,19 @@ def sleep (ms : UInt32) : BaseIO Unit :=
EIO.asTask act prio

/-- `IO` specialization of `EIO.mapTask`. -/
@[inline] def mapTask (f : α → IO β) (t : Task α) (prio := Task.Priority.default) : BaseIO (Task (Except IO.Error β)) :=
EIO.mapTask f t prio
@[inline] def mapTask (f : α → IO β) (t : Task α) (prio := Task.Priority.default) (sync := false) :
BaseIO (Task (Except IO.Error β)) :=
EIO.mapTask f t prio sync

/-- `IO` specialization of `EIO.bindTask`. -/
@[inline] def bindTask (t : Task α) (f : α → IO (Task (Except IO.Error β))) (prio := Task.Priority.default) : BaseIO (Task (Except IO.Error β)) :=
EIO.bindTask t f prio
@[inline] def bindTask (t : Task α) (f : α → IO (Task (Except IO.Error β)))
(prio := Task.Priority.default) (sync := false) : BaseIO (Task (Except IO.Error β)) :=
EIO.bindTask t f prio sync

/-- `IO` specialization of `EIO.mapTasks`. -/
@[inline] def mapTasks (f : List α → IO β) (tasks : List (Task α)) (prio := Task.Priority.default) : BaseIO (Task (Except IO.Error β)) :=
EIO.mapTasks f tasks prio
@[inline] def mapTasks (f : List α → IO β) (tasks : List (Task α)) (prio := Task.Priority.default)
(sync := false) : BaseIO (Task (Except IO.Error β)) :=
EIO.mapTasks f tasks prio sync

/-- Check if the task's cancellation flag has been set by calling `IO.cancel` or dropping the last reference to the task. -/
@[extern "lean_io_check_canceled"] opaque checkCanceled : BaseIO Bool
Expand Down
12 changes: 6 additions & 6 deletions src/include/lean/lean.h
Original file line number Diff line number Diff line change
Expand Up @@ -1092,12 +1092,12 @@ LEAN_SHARED lean_obj_res lean_task_spawn_core(lean_obj_arg c, unsigned prio, boo
static inline lean_obj_res lean_task_spawn(lean_obj_arg c, lean_obj_arg prio) { return lean_task_spawn_core(c, lean_unbox(prio), false); }
/* Convert a value `a : A` into `Task A` */
LEAN_SHARED lean_obj_res lean_task_pure(lean_obj_arg a);
LEAN_SHARED lean_obj_res lean_task_bind_core(lean_obj_arg x, lean_obj_arg f, unsigned prio, bool keep_alive);
/* Task.bind (x : Task A) (f : A -> Task B) (prio : Nat) : Task B */
static inline lean_obj_res lean_task_bind(lean_obj_arg x, lean_obj_arg f, lean_obj_arg prio) { return lean_task_bind_core(x, f, lean_unbox(prio), false); }
LEAN_SHARED lean_obj_res lean_task_map_core(lean_obj_arg f, lean_obj_arg t, unsigned prio, bool keep_alive);
/* Task.map (f : A -> B) (t : Task A) (prio : Nat) : Task B */
static inline lean_obj_res lean_task_map(lean_obj_arg f, lean_obj_arg t, lean_obj_arg prio) { return lean_task_map_core(f, t, lean_unbox(prio), false); }
LEAN_SHARED lean_obj_res lean_task_bind_core(lean_obj_arg x, lean_obj_arg f, unsigned prio, bool sync, bool keep_alive);
/* Task.bind (x : Task A) (f : A -> Task B) (prio : Nat) (sync : Bool) : Task B */
static inline lean_obj_res lean_task_bind(lean_obj_arg x, lean_obj_arg f, lean_obj_arg prio, uint8_t sync) { return lean_task_bind_core(x, f, lean_unbox(prio), sync, false); }
LEAN_SHARED lean_obj_res lean_task_map_core(lean_obj_arg f, lean_obj_arg t, unsigned prio, bool sync, bool keep_alive);
/* Task.map (f : A -> B) (t : Task A) (prio : Nat) (sync : Bool) : Task B */
static inline lean_obj_res lean_task_map(lean_obj_arg f, lean_obj_arg t, lean_obj_arg prio, uint8_t sync) { return lean_task_map_core(f, t, lean_unbox(prio), sync, false); }
LEAN_SHARED b_lean_obj_res lean_task_get(b_lean_obj_arg t);
/* Primitive for implementing Task.get : Task A -> A */
static inline lean_obj_res lean_task_get_own(lean_obj_arg t) {
Expand Down
14 changes: 8 additions & 6 deletions src/runtime/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1011,19 +1011,21 @@ static obj_res lean_io_bind_task_fn(obj_arg f, obj_arg a) {
return object_ref(io_result_get_value(r.raw()), true).steal();
}

/* mapTask {α : Type u} {β : Type} (f : α → BaseIO β) (t : Task α) (prio : Nat) : BaseIO (Task β) */
extern "C" LEAN_EXPORT obj_res lean_io_map_task(obj_arg f, obj_arg t, obj_arg prio, obj_arg) {
/* mapTask (f : α → BaseIO β) (t : Task α) (prio : Nat) (sync : Bool) : BaseIO (Task β) */
extern "C" LEAN_EXPORT obj_res lean_io_map_task(obj_arg f, obj_arg t, obj_arg prio, uint8 sync,
obj_arg) {
object * c = lean_alloc_closure((void*)lean_io_bind_task_fn, 2, 1);
lean_closure_set(c, 0, f);
object * t2 = lean_task_map_core(c, t, lean_unbox(prio), /* keep_alive */ true);
object * t2 = lean_task_map_core(c, t, lean_unbox(prio), sync, /* keep_alive */ true);
return io_result_mk_ok(t2);
}

/* bindTask {α : Type u} {β : Type} (t : Task α) (f : α → BaseIO (Task β)) (prio : Nat) : BaseIO (Task β) */
extern "C" LEAN_EXPORT obj_res lean_io_bind_task(obj_arg t, obj_arg f, obj_arg prio, obj_arg) {
/* bindTask (t : Task α) (f : α → BaseIO (Task β)) (prio : Nat) (sync : Bool) : BaseIO (Task β) */
extern "C" LEAN_EXPORT obj_res lean_io_bind_task(obj_arg t, obj_arg f, obj_arg prio, uint8 sync,
obj_arg) {
object * c = lean_alloc_closure((void*)lean_io_bind_task_fn, 2, 1);
lean_closure_set(c, 0, f);
object * t2 = lean_task_bind_core(t, c, lean_unbox(prio), /* keep_alive */ true);
object * t2 = lean_task_bind_core(t, c, lean_unbox(prio), sync, /* keep_alive */ true);
return io_result_mk_ok(t2);
}

Expand Down
10 changes: 6 additions & 4 deletions src/runtime/object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -953,8 +953,9 @@ static obj_res task_map_fn(obj_arg f, obj_arg t, obj_arg) {
return lean_apply_1(f, v);
}

extern "C" LEAN_EXPORT obj_res lean_task_map_core(obj_arg f, obj_arg t, unsigned prio, bool keep_alive) {
if (!g_task_manager) {
extern "C" LEAN_EXPORT obj_res lean_task_map_core(obj_arg f, obj_arg t, unsigned prio,
bool sync, bool keep_alive) {
if (!g_task_manager || (sync && lean_to_task(t)->m_value)) {
return lean_task_pure(apply_1(f, lean_task_get_own(t)));
} else {
lean_task_object * new_task = alloc_task(mk_closure_3_2(task_map_fn, f, t), prio, keep_alive);
Expand Down Expand Up @@ -995,8 +996,9 @@ static obj_res task_bind_fn1(obj_arg x, obj_arg f, obj_arg) {
return nullptr; /* notify queue that task did not finish yet. */
}

extern "C" LEAN_EXPORT obj_res lean_task_bind_core(obj_arg x, obj_arg f, unsigned prio, bool keep_alive) {
if (!g_task_manager) {
extern "C" LEAN_EXPORT obj_res lean_task_bind_core(obj_arg x, obj_arg f, unsigned prio,
bool sync, bool keep_alive) {
if (!g_task_manager || (sync && lean_to_task(x)->m_value)) {
return apply_1(f, lean_task_get_own(x));
} else {
lean_task_object * new_task = alloc_task(mk_closure_3_2(task_bind_fn1, x, f), prio, keep_alive);
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ class scoped_task_manager {

inline obj_res task_spawn(obj_arg c, unsigned prio = 0, bool keep_alive = false) { return lean_task_spawn_core(c, prio, keep_alive); }
inline obj_res task_pure(obj_arg a) { return lean_task_pure(a); }
inline obj_res task_bind(obj_arg x, obj_arg f, unsigned prio = 0, bool keep_alive = false) { return lean_task_bind_core(x, f, prio, keep_alive); }
inline obj_res task_map(obj_arg f, obj_arg t, unsigned prio = 0, bool keep_alive = false) { return lean_task_map_core(f, t, prio, keep_alive); }
inline obj_res task_bind(obj_arg x, obj_arg f, unsigned prio = 0, bool sync = false, bool keep_alive = false) { return lean_task_bind_core(x, f, prio, sync, keep_alive); }
inline obj_res task_map(obj_arg f, obj_arg t, unsigned prio = 0, bool sync = false, bool keep_alive = false) { return lean_task_map_core(f, t, prio, sync, keep_alive); }
inline b_obj_res task_get(b_obj_arg t) { return lean_task_get(t); }

inline bool io_check_canceled_core() { return lean_io_check_canceled_core(); }
Expand Down

0 comments on commit 8348d0f

Please sign in to comment.