[bugfix] prevent predict hang when subthread or subproc exception (#44)
tiankongdeguiji authored Nov 28, 2024
1 parent 196ba31 commit eddfefd
Showing 3 changed files with 27 additions and 20 deletions.
tzrec/constant.py: 4 changes (3 additions & 1 deletion)
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import os
 from enum import Enum
 
 
@@ -22,3 +22,5 @@ class Mode(Enum):
 
 
 EASYREC_VERSION = "0.7.5"
+
+PREDICT_QUEUE_TIMEOUT = int(os.environ.get("PREDICT_QUEUE_TIMEOUT") or 600)
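
The `or 600` fallback covers both an unset and an empty PREDICT_QUEUE_TIMEOUT environment variable: `os.environ.get` returns None or "", both falsy, so the 600-second default applies. A plain `os.environ.get("PREDICT_QUEUE_TIMEOUT", 600)` would instead hand an empty string to `int()` and raise. A minimal illustration (the example values are made up):

    import os

    os.environ.pop("PREDICT_QUEUE_TIMEOUT", None)
    print(int(os.environ.get("PREDICT_QUEUE_TIMEOUT") or 600))  # 600 (unset)

    os.environ["PREDICT_QUEUE_TIMEOUT"] = ""
    print(int(os.environ.get("PREDICT_QUEUE_TIMEOUT") or 600))  # 600 (empty string is falsy)

    os.environ["PREDICT_QUEUE_TIMEOUT"] = "120"
    print(int(os.environ.get("PREDICT_QUEUE_TIMEOUT") or 600))  # 120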
tzrec/main.py: 14 changes (7 additions & 7 deletions)
@@ -56,7 +56,7 @@
     is_trt_predict,
     write_mapping_file_for_input_tile,
 )
-from tzrec.constant import Mode
+from tzrec.constant import PREDICT_QUEUE_TIMEOUT, Mode
 from tzrec.datasets.dataset import BaseDataset, BaseWriter, create_writer
 from tzrec.datasets.utils import Batch, RecordBatchTensor
 from tzrec.features.feature import (
@@ -1101,20 +1101,20 @@ def _write(
 
     def _write_loop(output_cols: List[str]) -> None:
         while True:
-            predictions, reserves = pred_queue.get()
+            predictions, reserves = pred_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
             if predictions is None:
                 break
             assert predictions is not None and reserves is not None
             _write(predictions, reserves, output_cols)
 
     def _forward_loop() -> None:
         while True:
-            batch = data_queue.get()
+            batch = data_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
             if batch is None:
                 break
             assert batch is not None
             pred = _forward(batch)
-            pred_queue.put(pred)
+            pred_queue.put(pred, timeout=PREDICT_QUEUE_TIMEOUT)
 
     forward_t_list = []
     write_t = None
@@ -1136,7 +1136,7 @@ def _forward_loop() -> None:
             write_t = Thread(target=_write_loop, args=(output_cols,))
             write_t.start()
         else:
-            data_queue.put(batch)
+            data_queue.put(batch, timeout=PREDICT_QUEUE_TIMEOUT)
 
         if is_local_rank_zero:
             plogger.log(i_step)
@@ -1147,10 +1147,10 @@ def _forward_loop() -> None:
             break
 
     for _ in range(predict_threads):
-        data_queue.put(None)
+        data_queue.put(None, timeout=PREDICT_QUEUE_TIMEOUT)
     for t in forward_t_list:
         t.join()
-    pred_queue.put((None, None))
+    pred_queue.put((None, None), timeout=PREDICT_QUEUE_TIMEOUT)
     assert write_t is not None
     write_t.join()
     writer.close()
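These timeouts are what turn a silent hang into a visible failure. `queue.Queue.get` and `put` block indefinitely by default, so when a forward thread dies from an exception, the writer thread waits forever for a sentinel that never arrives; with `timeout=PREDICT_QUEUE_TIMEOUT`, the blocked call raises `queue.Empty` (or `queue.Full` on the put side) and the job fails loudly. A minimal sketch of the "subthread" failure mode from the commit title; this is not TorchEasyRec code, and the names and 5-second timeout are illustrative:

    import queue
    import threading

    data_queue: queue.Queue = queue.Queue(maxsize=2)

    def producer() -> None:
        # Dies before enqueueing the None sentinel that tells the consumer to stop.
        raise RuntimeError("forward thread crashed")

    def consumer() -> None:
        while True:
            # Without timeout= this get() would block forever once the producer
            # dies; with it, queue.Empty is raised after 5 seconds and the
            # thread exits instead of hanging.
            item = data_queue.get(timeout=5)
            if item is None:
                break

    t_prod = threading.Thread(target=producer)
    t_cons = threading.Thread(target=consumer)
    t_prod.start()
    t_cons.start()
    t_prod.join()
    t_cons.join()  # returns once the consumer dies with queue.Empty, not never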
tzrec/tools/tdm/retrieval.py: 29 changes (17 additions & 12 deletions)
@@ -25,7 +25,7 @@
 from torch import distributed as dist
 from torch.distributed import ReduceOp
 
-from tzrec.constant import Mode
+from tzrec.constant import PREDICT_QUEUE_TIMEOUT, Mode
 from tzrec.datasets.data_parser import DataParser
 from tzrec.datasets.dataset import BaseWriter, create_writer
 from tzrec.datasets.sampler import TDMPredictSampler
@@ -83,10 +83,10 @@ def _tdm_predict_data_worker(
         sampler.init_sampler(n_cluster)
 
     while True:
-        record_batch_t, node_ids = in_queue.get()
+        record_batch_t, node_ids = in_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
 
         if record_batch_t is None:
-            out_queue.put((None, None, None))
+            out_queue.put((None, None, None), timeout=PREDICT_QUEUE_TIMEOUT)
             time.sleep(10)
             break
 
Expand All @@ -111,7 +111,10 @@ def _tdm_predict_data_worker(
output_data = data_parser.parse(updated_inputs)
batch = data_parser.to_batch(output_data, force_no_tile=True)

out_queue.put((batch, record_batch_t, updated_inputs[item_id_field]))
out_queue.put(
(batch, record_batch_t, updated_inputs[item_id_field]),
timeout=PREDICT_QUEUE_TIMEOUT,
)


def tdm_retrieval(
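
The same reasoning covers the "subproc" half of the commit title: _tdm_predict_data_worker runs in child processes, and a blocking get() on a multiprocessing queue in the parent hangs forever if a child crashes before producing anything. A sketch under the same assumptions (illustrative names and timeout, not TorchEasyRec code):

    import multiprocessing as mp

    def data_worker(out_queue):
        # Dies before putting either a batch or the (None, None, None) sentinel.
        raise RuntimeError("data worker crashed")

    if __name__ == "__main__":
        out_queue = mp.Queue(maxsize=2)
        p = mp.Process(target=data_worker, args=(out_queue,))
        p.start()
        # Raises queue.Empty after 10 seconds instead of hanging indefinitely.
        batch = out_queue.get(timeout=10)
        p.join()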
@@ -285,24 +288,26 @@ def _forward(
     def _forward_loop(data_queue: Queue, pred_queue: Queue, layer_id: int) -> None:
         stop_cnt = 0
         while True:
-            batch, record_batch_t, node_ids = data_queue.get()
+            batch, record_batch_t, node_ids = data_queue.get(
+                timeout=PREDICT_QUEUE_TIMEOUT
+            )
             if batch is None:
                 stop_cnt += 1
                 if stop_cnt == num_worker_per_level:
                     for _ in range(num_worker_per_level):
-                        pred_queue.put((None, None))
+                        pred_queue.put((None, None), timeout=PREDICT_QUEUE_TIMEOUT)
                     break
                 else:
                     continue
             assert batch is not None
             pred = _forward(batch, record_batch_t, node_ids, layer_id)
-            pred_queue.put(pred)
+            pred_queue.put(pred, timeout=PREDICT_QUEUE_TIMEOUT)
 
     def _write_loop(pred_queue: Queue, metric_queue: Queue) -> None:
         total = 0
         recall = 0
         while True:
-            record_batch_t, node_ids = pred_queue.get()
+            record_batch_t, node_ids = pred_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
             if record_batch_t is None:
                 break
 
@@ -326,7 +331,7 @@ def _write_loop(pred_queue: Queue, metric_queue: Queue) -> None:
             )
             total += cur_batch_size
             recall += np.sum(retrieval_result)
-        metric_queue.put((total, recall))
+        metric_queue.put((total, recall), timeout=PREDICT_QUEUE_TIMEOUT)
 
     in_queues = [Queue(maxsize=2) for _ in range(max_level - first_recall_layer + 1)]
     out_queues = [Queue(maxsize=2) for _ in range(max_level - first_recall_layer)]
@@ -369,7 +374,7 @@ def _write_loop(pred_queue: Queue, metric_queue: Queue) -> None:
     while True:
         try:
             batch = next(infer_iterator)
-            in_queues[0].put((batch.reserves, None))
+            in_queues[0].put((batch.reserves, None), timeout=PREDICT_QUEUE_TIMEOUT)
             if is_local_rank_zero:
                 plogger.log(i_step)
             if is_profiling:
@@ -379,15 +384,15 @@ def _write_loop(pred_queue: Queue, metric_queue: Queue) -> None:
             break
 
     for _ in range(num_worker_per_level):
-        in_queues[0].put((None, None))
+        in_queues[0].put((None, None), timeout=PREDICT_QUEUE_TIMEOUT)
     for p in data_p_list:
         p.join()
     for t in forward_t_list:
         t.join()
     write_t.join()
     writer.close()
 
-    total, recall = metric_queue.get()
+    total, recall = metric_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
     total_t = torch.tensor(total, device=device)
     recall_t = torch.tensor(recall, device=device)
     dist.all_reduce(total_t, op=ReduceOp.SUM)
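For context on the final hunk: each rank accumulates local (total, recall) counts, and all_reduce with ReduceOp.SUM combines them across ranks before the recall rate is computed. A standalone sketch, assuming a torchrun-style launch that provides the env:// rendezvous variables; the counts are made up:

    import torch
    from torch import distributed as dist
    from torch.distributed import ReduceOp

    dist.init_process_group(backend="gloo")
    device = torch.device("cpu")

    total, recall = 100, 87  # this rank's local counts (illustrative)
    total_t = torch.tensor(total, device=device)
    recall_t = torch.tensor(recall, device=device)
    dist.all_reduce(total_t, op=ReduceOp.SUM)   # global sample count
    dist.all_reduce(recall_t, op=ReduceOp.SUM)  # global recalled count

    if dist.get_rank() == 0:
        print(f"recall rate: {recall_t.item() / total_t.item():.4f}")
    dist.destroy_process_group()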
