-
Notifications
You must be signed in to change notification settings - Fork 2
/
async.go
105 lines (91 loc) · 2.03 KB
/
async.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package async
import (
	"errors"
	"sync"

	"github.com/yuin/gopher-lua"
)
// AsyncState is the scheduler bookkeeping that Init stores, wrapped in
// userdata, in the Lua global "__state". AsyncRun and Schedule look it up
// from there.
type AsyncState struct {
// channel carries finished jobs from AsyncRun goroutines to Schedule.
channel chan *AsyncResult
// wg counts jobs started by AsyncRun whose coroutines Schedule has not yet
// resumed; Schedule waits on it to decide when no work is left.
wg *sync.WaitGroup
}
// AsyncResult is one finished AsyncRun job waiting to be handed back to Lua.
type AsyncResult struct {
// co is the Lua state that was passed to AsyncRun; Schedule resumes it.
co *lua.LState
// result holds the values returned by the job's fn; Schedule passes them
// as the arguments of the resume.
result []lua.LValue
}
// AsyncRun runs fn on a new goroutine and registers it as a pending job with
// the scheduler state installed by Init. When fn returns, its values are sent
// to Schedule, which resumes L with them (so they become the results of the
// coroutine.yield inside the Lua helper await).
//
// L should be the coroutine that will yield while waiting for the result.
// fn runs concurrently with the Lua VM, so it presumably must not touch L or
// any other Lua state — TODO confirm for each caller.
//
// AsyncRun panics if the "__state" global set up by Init is missing.
func AsyncRun(fn func() []lua.LValue, L *lua.LState) {
	state := L.GetGlobal("__state").(*lua.LUserData).Value.(*AsyncState)
	state.wg.Add(1)
	go func(co *lua.LState) {
		state.channel <- &AsyncResult{result: fn(), co: co}
	}(L)
}
// WrapAsyncFunc wraps the Lua function fn in a Go-backed Lua function. Each
// call runs fn on a fresh coroutine with the caller's arguments and then
// drives all pending AsyncRun jobs to completion via Schedule before
// returning.
//
// Return values:
//   - If fn finishes or yields with a result list that is non-empty and not a
//     single nil, that list is returned. Schedule is still run so outstanding
//     jobs are drained.
//   - Otherwise (fn yielded inside await, or returned nothing or a single nil),
//     the wrapper returns whatever Schedule returns. That is the last non-nil
//     result list from any coroutine it resumed, not necessarily fn's own.
//
// If fn raises an error before its first yield, pending jobs are drained and
// the error is re-raised to the Lua caller, so it can be caught with pcall.
// Errors raised after a yield surface inside Schedule, which ignores them.
func WrapAsyncFunc(L *lua.LState, fn *lua.LFunction) *lua.LFunction {
	return L.NewFunction(func(L *lua.LState) int {
		co, cancel := L.NewThread()
		if cancel != nil {
			// NewThread only returns a cancel func when L carries a context.
			// Releasing it prevents leaking one child context per call.
			// Nothing resumes co after Schedule has drained all jobs.
			defer cancel()
		}

		top := L.GetTop()
		args := make([]lua.LValue, 0, top)
		for i := 1; i <= top; i++ {
			args = append(args, L.Get(i))
		}

		_, err, rets := L.Resume(co, fn, args...)
		if err != nil {
			// Drain jobs that are already in flight so their goroutines are
			// not left blocked, then propagate the failure to the caller.
			Schedule(L)
			var apiErr *lua.ApiError
			if errors.As(err, &apiErr) && apiErr.Object != nil {
				// Re-raise the original error value unchanged (level 0 adds
				// no extra position prefix).
				L.Error(apiErr.Object, 0)
			}
			L.RaiseError("%s", err.Error())
			return 0 // unreachable: the raise calls above panic
		}

		if len(rets) > 0 && !(len(rets) == 1 && rets[0] == lua.LNil) {
			// fn produced a concrete result; still drain outstanding jobs.
			Schedule(L)
		} else {
			// fn is presumably parked in await; the scheduler's last result
			// stands in for its return value.
			rets = Schedule(L)
		}
		for _, ret := range rets {
			L.Push(ret)
		}
		return len(rets)
	})
}
// Init prepares L for the async helpers in this package. Call it before
// AsyncRun, Schedule, or any function returned by WrapAsyncFunc is used
// with L.
//
// It stores a fresh AsyncState in the Lua global "__state" and defines two
// Lua globals:
//
//	await(fn, ...)  calls fn(...), then yields the current coroutine and
//	                returns whatever values it is resumed with.
//	async(fn, ...)  runs fn(...) on a new coroutine until it first yields
//	                or finishes.
//
// Init panics if the embedded helper script fails to compile. That can only
// happen through a bug in this package, not at runtime.
func Init(L *lua.LState) {
	const awaitScript = `
function await(fn, ...)
fn(...)
return coroutine.yield()
end
function async(fn, ...)
coroutine.resume(coroutine.create(fn), ...)
end
`
	// Compile the helpers in a throwaway state and rebuild them in L from
	// their prototypes. They reference only their parameters and globals, so
	// no upvalues are lost. Defining the functions needs no standard library,
	// so the scratch state skips opening them. It is closed when Init returns.
	tmpL := lua.NewState(lua.Options{SkipOpenLibs: true})
	defer tmpL.Close()
	if err := tmpL.DoString(awaitScript); err != nil {
		panic("async: compiling await/async helpers: " + err.Error())
	}
	awaitProto := tmpL.GetGlobal("await").(*lua.LFunction).Proto
	asyncProto := tmpL.GetGlobal("async").(*lua.LFunction).Proto

	s := &AsyncState{
		channel: make(chan *AsyncResult),
		wg:      new(sync.WaitGroup),
	}
	ud := L.NewUserData()
	ud.Value = s
	L.SetGlobal("__state", ud)
	L.SetGlobal("await", L.NewFunctionFromProto(awaitProto))
	L.SetGlobal("async", L.NewFunctionFromProto(asyncProto))
}
// Schedule runs the scheduler loop for L. It waits for every job registered
// with AsyncRun to finish and resumes each job's coroutine with that job's
// results. Coroutines resumed here may register further jobs, and those are
// waited for as well. Schedule returns once no jobs remain.
//
// The return value is the last result list, from any resumed coroutine, that
// was non-empty and not a single nil. It is nil if there was none. Errors
// raised by resumed coroutines are ignored.
//
// Schedule may be called any number of times on the same state. It must not
// be called re-entrantly from inside a coroutine it is currently resuming.
// The outer call still counts that job as pending, so the inner call would
// wait forever.
func Schedule(L *lua.LState) []lua.LValue {
	s := L.GetGlobal("__state").(*lua.LUserData).Value.(*AsyncState)

	// done is closed once the job counter drops to zero. A per-call channel is
	// used instead of closing s.channel. Closing the shared channel would make
	// every later AsyncRun send, and every later Schedule close, panic.
	done := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(done)
	}()

	var vals []lua.LValue
	for {
		select {
		case <-done:
			return vals
		case a := <-s.channel:
			// The job is marked done only after its coroutine has been
			// resumed. While that coroutine may still call AsyncRun, the
			// counter therefore stays above zero.
			_, _, ret := L.Resume(a.co, nil, a.result...)
			if len(ret) > 0 && !(len(ret) == 1 && ret[0] == lua.LNil) {
				vals = ret
			}
			s.wg.Done()
		}
	}
}