Passing a dict in a datapipe/dataset causes a memory leak #1183

Open
qmpzzpmq opened this issue Jun 14, 2023 · 4 comments

Comments

@qmpzzpmq

qmpzzpmq commented Jun 14, 2023

🐛 Describe the bug

Passing a dict through a datapipe or dataset causes a memory leak:

from copy import deepcopy
import gc

from memory_profiler import profile
import torch
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2


def build_dp1(num_batch):
    # Variant 1: each item is a nested dict with "clean"/"noisy" sub-dicts
    # (not used in the profiled runs below, kept for comparison).
    item_list = list()
    for idx in range(num_batch):
        item = {
            "id": idx,
            "clean": {
                "path": str(idx),
                "id": idx,
            },
            "noisy":{
                "path": str(idx),
                "id": idx,
            },
        }
        item_list.append(item)
    return IterableWrapper(item_list)

def build_dp2(num_batch):
    # Variant 2: each item is a flat dict (used by datapipe() below).
    item_list = list()
    for idx in range(num_batch):
        item = {
            "id": idx,
            "clean_path": str(idx),
            "clean_id": idx,
            "noisy_path": str(idx),
            "noisy_id": idx,
        }
        item_list.append(item)
    return IterableWrapper(item_list)

def add_audio1(item):
    # Adds audio tensors to the nested item in place (for build_dp1; unused below).
    item["clean"]["audio"] = torch.randn([5000, 10])
    item["noisy"]["audio"] = torch.randn([5000, 10])
    return item

def add_audio2(item):
    # Deep-copies the nested item before adding audio tensors (unused below).
    new_item = deepcopy(item)
    new_item["clean"]["audio"] = torch.randn([5000, 10])
    new_item["noisy"]["audio"] = torch.randn([5000, 10])
    return new_item

def add_audio3(item):
    # Adds audio tensors to the flat item in place (used by datapipe() below).
    item["clean_audio"] = torch.randn([5000, 10])
    item["noisy_audio"] = torch.randn([5000, 10])
    return item

class MyDataset1(torch.utils.data.Dataset):
    # Map-style dataset: each sample is a flat dict of strings, ints, and tensors.
    def __init__(self, datalen):
        super().__init__()
        self.datalen = datalen

    def __getitem__(self, index):
        item = {
            "id": index,
            "clean_path": str(index),
            "clean_id": index,
            "clean_audio": torch.randn([5000, 10]),
            "noisy_path": str(index),
            "noisy_id": index,
            "noisy_audio": torch.randn([5000, 10]),
        }
        return item

    def __len__(self):
        return self.datalen

class MyDataset2(torch.utils.data.Dataset):
    # Map-style dataset: each sample is a tuple of tensors (no dict).
    def __init__(self, datalen):
        super().__init__()
        self.datalen = datalen

    def __getitem__(self, index):
        return torch.randn([5000, 10]), torch.randn([5000, 10])

    def __len__(self):
        return self.datalen

@profile
def datapipe(num_batch):
    # Build the flat-dict pipe, wrap it in DataLoader2, and discard each batch.
    dp = build_dp2(num_batch).map(add_audio3)
    dl = DataLoader2(dp)
    for i, batch in enumerate(dl):
        pass
    pass
    del dp, dl

@profile
def dataset1(num_batch):
    # Iterate the dict-yielding dataset with the standard DataLoader.
    ds = MyDataset1(num_batch)
    dl = DataLoader(ds)
    for i, batch in enumerate(dl):
        pass
    pass
    del ds, dl

@profile
def dataset2(num_batch):
    # Iterate the tuple-yielding dataset with the standard DataLoader.
    ds = MyDataset2(num_batch)
    dl = DataLoader(ds)
    for i, batch in enumerate(dl):
        pass
    pass
    del ds, dl

# First pass: 1000 samples per run.
num_batch = 1000

gc.collect()
datapipe(num_batch)
gc.collect()
dataset1(num_batch)
gc.collect()
dataset2(num_batch)
gc.collect()


# Second pass: 5000 samples per run.
num_batch = 5000

gc.collect()
datapipe(num_batch)
gc.collect()
dataset1(num_batch)
gc.collect()
dataset2(num_batch)
gc.collect()

Output:

Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    88    328.1 MiB    328.1 MiB           1   @profile
    89                                         def datapipe(num_batch):
    90    328.4 MiB      0.3 MiB           1       dp = build_dp2(num_batch).map(add_audio3)
    91    330.6 MiB      2.2 MiB           1       dl = DataLoader2(dp)
    92    714.3 MiB    383.6 MiB        1001       for i, batch in enumerate(dl):
    93    714.3 MiB      0.0 MiB        1000           pass
    94    714.3 MiB      0.0 MiB           1       pass
    95    714.3 MiB      0.0 MiB           1       del dp, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    97    714.4 MiB    714.4 MiB           1   @profile
    98                                         def dataset1(num_batch):
    99    714.4 MiB      0.0 MiB           1       ds = MyDataset1(num_batch)
   100    714.4 MiB      0.0 MiB           1       dl = DataLoader(ds)
   101    716.9 MiB      2.5 MiB        1001       for i, batch in enumerate(dl):
   102    716.9 MiB      0.0 MiB        1000           pass
   103    716.9 MiB      0.0 MiB           1       pass
   104    716.9 MiB      0.0 MiB           1       del ds, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   106    716.9 MiB    716.9 MiB           1   @profile
   107                                         def dataset2(num_batch):
   108    716.9 MiB      0.0 MiB           1       ds = MyDataset2(num_batch)
   109    716.9 MiB      0.0 MiB           1       dl = DataLoader(ds)
   110    716.9 MiB      0.0 MiB        1001       for i, batch in enumerate(dl):
   111    716.9 MiB      0.0 MiB        1000           pass
   112    716.9 MiB      0.0 MiB           1       pass
   113    716.9 MiB      0.0 MiB           1       del ds, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    88    716.9 MiB    716.9 MiB           1   @profile
    89                                         def datapipe(num_batch):
    90    717.0 MiB      0.0 MiB           1       dp = build_dp2(num_batch).map(add_audio3)
    91    721.6 MiB      4.6 MiB           1       dl = DataLoader2(dp)
    92   2254.1 MiB   1532.6 MiB        5001       for i, batch in enumerate(dl):
    93   2254.1 MiB      0.0 MiB        5000           pass
    94   2254.1 MiB      0.0 MiB           1       pass
    95   2252.1 MiB     -2.0 MiB           1       del dp, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    97   2251.5 MiB   2251.5 MiB           1   @profile
    98                                         def dataset1(num_batch):
    99   2251.5 MiB      0.0 MiB           1       ds = MyDataset1(num_batch)
   100   2251.5 MiB      0.0 MiB           1       dl = DataLoader(ds)
   101   2251.5 MiB -7642068.4 MiB        5001       for i, batch in enumerate(dl):
   102   2251.5 MiB -7640538.2 MiB        5000           pass
   103    721.3 MiB  -1530.2 MiB           1       pass
   104    721.3 MiB      0.0 MiB           1       del ds, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   106    721.3 MiB    721.3 MiB           1   @profile
   107                                         def dataset2(num_batch):
   108    721.3 MiB      0.0 MiB           1       ds = MyDataset2(num_batch)
   109    721.3 MiB      0.0 MiB           1       dl = DataLoader(ds)
   110    721.3 MiB      0.0 MiB        5001       for i, batch in enumerate(dl):
   111    721.3 MiB      0.0 MiB        5000           pass
   112    721.3 MiB      0.0 MiB           1       pass
   113    721.3 MiB      0.0 MiB           1       del ds, dl

Versions

torch version: 2.0.0
torchdata version: 0.6.0

It is clear that passing a dict of tensors leaks memory, while passing a list or tuple of tensors does not.
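
For reference, here is a minimal sketch (the class name is hypothetical) of rewriting MyDataset1 to return a positional tuple instead of a dict, matching the MyDataset2 pattern that shows no memory growth in the profiles above:

import torch

class MyDatasetTuple(torch.utils.data.Dataset):
    # Hypothetical rewrite of MyDataset1: same payload per sample, but
    # returned as a tuple rather than a dict.
    def __init__(self, datalen):
        super().__init__()
        self.datalen = datalen

    def __getitem__(self, index):
        return (
            index,                    # id
            str(index),               # clean_path
            torch.randn([5000, 10]),  # clean_audio
            str(index),               # noisy_path
            torch.randn([5000, 10]),  # noisy_audio
        )

    def __len__(self):
        return self.datalen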

I use dicts of tensors in my model training, and the training failed multiple times, all because of this memory leak. I also tried TensorDict (https://pytorch.org/rl/tensordict/), but it cannot contain strings, and I need strings while items pass through my datapipes (one of the datapipes encodes str to tensor).
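
As a minimal sketch of the str-to-tensor encoding mentioned above (helper names are hypothetical, assuming UTF-8 path strings), which would let string fields live in a tensor-only container:

import torch

def encode_str(s):
    # Hypothetical helper: UTF-8 bytes of the string as a uint8 tensor.
    return torch.tensor(list(s.encode("utf-8")), dtype=torch.uint8)

def decode_str(t):
    # Inverse: uint8 tensor back to a Python string.
    return bytes(t.tolist()).decode("utf-8")

assert decode_str(encode_str("/path/to/0.wav")) == "/path/to/0.wav"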

@andrew-bydlon

I'm also using dictionaries and seeing a memory leak. I'm highlighting a different issue, but I'm seeing a small increase in usage over time as well:

#1185

@qmpzzpmq
Author

@andrew-bydlon I found a temporary workaround for the issue: I split my original dict into two dicts, one containing only the tensors and the other containing everything else.
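
A minimal sketch of that workaround (the helper name is hypothetical), applied as an extra map over the flat-dict pipe from the repro above:

import torch

def split_item(item):
    # Hypothetical helper: separate tensor values from non-tensor metadata
    # so no single dict mixes tensors with strings and ints.
    tensors = {k: v for k, v in item.items() if isinstance(v, torch.Tensor)}
    meta = {k: v for k, v in item.items() if not isinstance(v, torch.Tensor)}
    return tensors, meta

dp = build_dp2(num_batch).map(add_audio3).map(split_item)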

@sakimarquis

Same problem with a Python dict and TensorDict.

@hzphzp

hzphzp commented Oct 9, 2024

same!
