-
Notifications
You must be signed in to change notification settings - Fork 0
/
pyrela.pyi
156 lines (91 loc) · 3.24 KB
/
pyrela.pyi
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""A stub file for rela module."""
from typing import List, Optional, Dict, Tuple, overload
import torch
# Type alias: rela exchanges data with the C++ extension as flat
# name -> tensor dictionaries.
TensorDict = Dict[str, torch.Tensor]
class FutureReply:
    """Handle to a pending reply from a batched call (backed by C++)."""

    def get(self) -> TensorDict:
        """Retrieve the reply as a TensorDict (blocking semantics, if any,
        are defined by the C++ binding)."""
        ...

    def is_null(self) -> bool:
        """Whether this handle carries no reply."""
        ...
class Batcher:
    """Stub for rela's batching primitive (implemented in C++)."""

    def __init__(self, batch_size: int):
        """Create a batcher that aggregates up to ``batch_size`` requests."""
        ...

    def send(self, t: TensorDict) -> FutureReply:
        """Submit one request; returns a handle to the eventual reply."""
        ...

    def get(self) -> TensorDict:
        """Fetch the next aggregated batch."""
        ...

    def set(self, t: TensorDict):
        """Provide the reply for a previously fetched batch."""
        ...
class BatchRunner:
    """Stub for the C++ runner that serves a TorchScript model on a device,
    batching concurrent calls to its methods."""

    @overload
    def __init__(
        self,
        py_model: torch.jit.ScriptModule,
        device: str,
        max_batch_size: int,
        methods: List[str],
    ): ...
    @overload
    def __init__(self, py_model: torch.jit.ScriptModule, device: str): ...

    def add_method(self, method: str, batch_size: int):
        """Register ``method`` to be served with the given batch size."""
        ...

    def start(self): ...
    def stop(self): ...

    def update_model(self, py_model: torch.jit.ScriptModule):
        """Swap in a new copy of the model's weights."""
        ...

    def set_log_freq(self, log_freq: int): ...

    # NOTE(review): the original stub gives no return annotation here;
    # presumably this is the blocking variant of call() and returns a
    # TensorDict -- confirm against the C++ binding before annotating.
    def block_call(self, method: str, t: TensorDict): ...

    def call(self, method: str, d: TensorDict) -> FutureReply:
        """Asynchronously invoke ``method``; result arrives via FutureReply."""
        ...
class ThreadLoop:
    """Opaque base class for C++ worker loops; exposes no Python-side API."""
class Context:
    """Stub for the C++ execution context that owns a set of thread loops."""

    def __init__(self): ...

    def push_thread_loop(self, env: ThreadLoop) -> int:
        """Register a worker loop; returns its index."""
        ...

    def start(self): ...
    def pause(self): ...
    def resume(self): ...
    def join(self): ...

    # NOTE(review): no return annotation in the original; likely a bool
    # status query -- confirm against the C++ binding.
    def terminated(self): ...
class RNNTransition:
    """Batch of RNN transitions sampled from the replay buffer.

    Consistency fix: the dict-of-tensor fields use the module-level
    ``TensorDict`` alias, matching ``FFTransition`` and the
    ``tensor_dict_*`` utilities (the original spelled out
    ``Dict[str, torch.Tensor]`` here only).
    """

    obs: TensorDict     # observation dict
    action: TensorDict  # action dict
    h0: TensorDict      # h0: presumably the initial RNN hidden state -- confirm
    reward: torch.Tensor
    terminal: torch.Tensor
    bootstrap: torch.Tensor
    seq_len: torch.Tensor

    def to_dict(self) -> TensorDict:
        """Flatten the transition into a single TensorDict."""
        ...

    def to_device(self, device: str):
        """Move all contained tensors to ``device``."""
        ...
class RNNPrioritizedReplay:
    """Prioritized replay buffer of ``RNNTransition`` entries (C++ impl)."""

    def __init__(
        self, capacity: int, seed: int, alpha: float, beta: float, prefetch: int
    ) -> None:
        """alpha/beta look like the usual prioritized-replay exponents --
        presumably; confirm against the C++ binding."""
        ...

    def clear(self): ...
    def terminate(self): ...
    def size(self) -> int: ...
    def num_add(self) -> int: ...

    def sample(
        self, batchsize: int, device: str
    ) -> Tuple[RNNTransition, torch.Tensor]:
        """Draw a batch; also returns a tensor (semantics per binding)."""
        ...

    def update_priority(self, priority: torch.Tensor): ...
    def get(self, idx: int) -> RNNTransition: ...
# TensorDict utils.
def tensor_dict_stack(vec: List[TensorDict], stack_dim: int) -> TensorDict:
    """Stack a sequence of TensorDicts along ``stack_dim`` into one TensorDict.

    NOTE(review): the original stub typed ``vec`` as a single ``TensorDict``;
    the name ``vec`` and the stacking operation indicate a list of
    TensorDicts (presumably ``std::vector<TensorDict>`` on the C++ side) --
    confirm against the binding.
    """
    ...
def tensor_dict_eq(d0: TensorDict, d1: TensorDict) -> bool:
    """Element-wise equality check between two TensorDicts."""
    ...
def tensor_dict_index(batch: TensorDict, i: int) -> TensorDict:
    """Select entry ``i`` from every tensor in ``batch``."""
    ...
def tensor_dict_narrow(
    batch: TensorDict, dim: int, i: int, len: int, squeeze: bool
) -> TensorDict:
    """Narrow every tensor in ``batch`` along ``dim`` (start ``i``,
    length ``len``), optionally squeezing the narrowed dimension.

    NOTE(review): ``len`` shadows the builtin, but renaming a stub
    parameter would desynchronize keyword-argument names from the
    underlying binding, so it is kept as-is.
    """
    ...
def tensor_dict_clone(d: TensorDict) -> TensorDict:
    """Deep-copy a TensorDict."""
    ...
def tensor_dict_zeros_like(d: TensorDict) -> TensorDict:
    """Return a TensorDict of zero tensors shaped like ``d``'s entries."""
    ...
class FFTransition:
    """Batch of feed-forward (non-recurrent) transitions from the buffer.

    Consistency fix: ``next_obs`` is typed ``TensorDict`` to match ``obs``
    (the original said ``torch.Tensor``; a next-observation should have the
    same dict-of-tensors type as the observation -- confirm against the
    C++ binding).
    """

    obs: TensorDict     # observation dict
    action: TensorDict  # action dict
    reward: torch.Tensor
    terminal: torch.Tensor
    bootstrap: torch.Tensor
    next_obs: TensorDict  # was torch.Tensor in the original stub
class FFPrioritizedReplay:
    """Prioritized replay buffer of ``FFTransition`` entries (C++ impl).

    Mirrors the interface of ``RNNPrioritizedReplay``.
    """

    def __init__(
        self, capacity: int, seed: int, alpha: float, beta: float, prefetch: int
    ) -> None:
        """alpha/beta look like the usual prioritized-replay exponents --
        presumably; confirm against the C++ binding."""
        ...

    def clear(self): ...
    def terminate(self): ...
    def size(self) -> int: ...
    def num_add(self) -> int: ...

    def sample(
        self, batchsize: int, device: str
    ) -> Tuple[FFTransition, torch.Tensor]:
        """Draw a batch; also returns a tensor (semantics per binding)."""
        ...

    def update_priority(self, priority: torch.Tensor): ...
    def get(self, idx: int) -> FFTransition: ...