From f8de99b751c6afb1a9e0e1115fa872468469c7fe Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sun, 26 Nov 2023 22:59:56 +0000 Subject: [PATCH] debug scatter --- .github/workflows/full_gpu_testing.yml | 3 ++- .github/workflows/full_testing.yml | 3 ++- .github/workflows/testing.yml | 2 +- torch_geometric/typing.py | 4 ++-- torch_geometric/utils/scatter.py | 10 ++++++++-- 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/.github/workflows/full_gpu_testing.yml b/.github/workflows/full_gpu_testing.yml index 6be1a15f7712..88520a2672c1 100644 --- a/.github/workflows/full_gpu_testing.yml +++ b/.github/workflows/full_gpu_testing.yml @@ -4,6 +4,7 @@ on: # yamllint disable-line rule:truthy workflow_dispatch: schedule: - cron: "0 6 * * *" # Everyday at 6:00am UTC/10:00pm PST + pull_request: jobs: @@ -29,7 +30,7 @@ jobs: pip install -e .[full,test] - name: Run tests - timeout-minutes: 20 + timeout-minutes: 200 run: | FULL_TEST=1 pytest shell: bash diff --git a/.github/workflows/full_testing.yml b/.github/workflows/full_testing.yml index f178f467d6eb..309c0d631d55 100644 --- a/.github/workflows/full_testing.yml +++ b/.github/workflows/full_testing.yml @@ -4,6 +4,7 @@ on: # yamllint disable-line rule:truthy workflow_dispatch: schedule: - cron: "0 6 * * *" # Everyday at 6:00am UTC/10:00pm PST + pull_request: jobs: @@ -50,7 +51,7 @@ jobs: pip install -e .[full,test] - name: Run tests - timeout-minutes: 20 + timeout-minutes: 200 run: | FULL_TEST=1 pytest --cov --cov-report=xml shell: bash diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 58fffd9341cd..5d64c086251e 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -43,7 +43,7 @@ jobs: - name: Run tests if: steps.changed-files-specific.outputs.only_changed != 'true' - timeout-minutes: 10 + timeout-minutes: 100 run: | pytest --cov --cov-report=xml --durations 10 diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 38572910ece3..8e469dac599e 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -41,7 +41,7 @@ WITH_SEGMM = False WITH_SAMPLED_OP = hasattr(pyg_lib.ops, 'sampled_add') WITH_SOFTMAX = hasattr(pyg_lib.ops, 'softmax_csr') - WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort') + WITH_INDEX_SORT = False WITH_METIS = hasattr(pyg_lib, 'partition') WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ('edge_time' in inspect.signature( pyg_lib.sampler.neighbor_sample).parameters) @@ -64,7 +64,7 @@ try: import torch_scatter # noqa - WITH_TORCH_SCATTER = True + WITH_TORCH_SCATTER = False except Exception as e: if not isinstance(e, ImportError): # pragma: no cover warnings.warn(f"An issue occurred while importing 'torch-scatter'. " diff --git a/torch_geometric/utils/scatter.py b/torch_geometric/utils/scatter.py index d9207c90a568..e06d0b516764 100644 --- a/torch_geometric/utils/scatter.py +++ b/torch_geometric/utils/scatter.py @@ -91,8 +91,7 @@ def scatter(src: Tensor, index: Tensor, dim: int = 0, f" package, but it was not found") index = broadcast(index, src, dim) - return src.new_zeros(size).scatter_reduce_( - dim, index, src, reduce=f'a{reduce}', include_self=False) + return _scatter_min_or_max(src, index, dim, size, reduce) return torch_scatter.scatter(src, index, dim, dim_size=dim_size, reduce=reduce) @@ -117,6 +116,13 @@ def scatter(src: Tensor, index: Tensor, dim: int = 0, raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'") + @torch._dynamo.optimize() + def _scatter_min_or_max(src: Tensor, index: Tensor, dim: int, size: int, + reduce: str): + return src.new_zeros(size).scatter_reduce_(dim, index, src, + reduce=f'a{reduce}', + include_self=False) + else: # pragma: no cover def scatter(src: Tensor, index: Tensor, dim: int = 0,