From 2d595eaf9c97c48a5ce1b68003c56146efe94335 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Wed, 18 Oct 2023 10:58:48 -0700 Subject: [PATCH] check if prefetch_pipeline is None or False (#1450) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1450 The flag prefetch_pipeline can be None. So we can't only check if it's False. Reviewed By: YLGH Differential Revision: D50398167 fbshipit-source-id: f2a6f897e499bbd6a99d8cc7f98ebd65d590ca12 --- torchrec/distributed/planner/shard_estimators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index 729cff9f2..d21400c49 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -158,7 +158,7 @@ def estimate( else False ) # TODO: remove after deprecating fused_params in sharder - if prefetch_pipeline is False: + if not prefetch_pipeline: prefetch_pipeline = ( sharder.fused_params.get("prefetch_pipeline", False) if hasattr(sharder, "fused_params") and sharder.fused_params