Skip to content

Commit

Permalink
[Fix] Fix stream context to avoid init error
Browse files Browse the repository at this point in the history
  • Loading branch information
fanqiNO1 committed Sep 27, 2023
1 parent b488cdc commit b1d2354
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions mmengine/model/wrappers/pipeline_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ class MMPipelineParallel(nn.Module):
in_queues: Dict[str, Queue] = {}
out_queues: Dict[str, Queue] = {}
events: List[List[Event]] = []
stream_contexts: List[torch.cuda.StreamContext] = []
# a list of cuda stream contexts
# but PyTorch version < 1.9.0 does not support stream context
# we use Any to avoid error
stream_contexts: List[Any] = []
hook_visited_times: Dict[str, int] = {}

def __init__(self,
Expand Down Expand Up @@ -556,8 +559,10 @@ def _init_events(self) -> List[List[Event]]:
events.append([Event() for _ in range(self.num_pipelines)])
return events

def _init_stream_contexts(self) -> List[torch.cuda.StreamContext]:
def _init_stream_contexts(self) -> List[Any]:
"""Init the stream contexts for execution."""
# PyTorch version < 1.9.0 does not support stream context
# we use Any to avoid error
curr_part_id = -1
inited_streams = {}
stream_contexts = []
Expand Down

0 comments on commit b1d2354

Please sign in to comment.