Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Co-authored-by: TrellixVulnTeam <[email protected]>
Co-authored-by: otaj <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
4 people authored Oct 5, 2022
1 parent d8ff64f commit 9d757d6
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 63 deletions.
3 changes: 2 additions & 1 deletion pl_bolts/datasets/cifar10_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch import Tensor

from pl_bolts.datasets import LightDataset
from pl_bolts.datasets.utils import safe_extract_tarfile
from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down Expand Up @@ -118,7 +119,7 @@ def _unpickle(self, path_folder: str, file_name: str) -> Tuple[Tensor, Tensor]:
def _extract_archive_save_torch(self, download_path):
# extract achieve
with tarfile.open(os.path.join(download_path, self.FILE_NAME), "r:gz") as tar:
tar.extractall(path=download_path)
safe_extract_tarfile(tar, path=download_path)
# this is internal path in the archive
path_content = os.path.join(download_path, "cifar-10-batches-py")

Expand Down
63 changes: 1 addition & 62 deletions pl_bolts/datasets/imagenet_dataset.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
import gzip
import hashlib
import os
import shutil
import sys
import tarfile
import tempfile
import zipfile
from contextlib import contextmanager

import numpy as np
import torch

from pl_bolts.datasets.utils import extract_archive
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

PY3 = sys.version_info[0] == 3

if _TORCHVISION_AVAILABLE:
from torchvision.datasets import ImageNet
from torchvision.datasets.imagenet import load_meta_file
Expand Down Expand Up @@ -247,59 +242,3 @@ def get_tmp_dir():
META_FILE = "meta.bin"

torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))


@under_review()
def extract_archive(from_path, to_path=None, remove_finished=False):
if to_path is None:
to_path = os.path.dirname(from_path)

PY3 = sys.version_info[0] == 3

if _is_tar(from_path):
with tarfile.open(from_path, "r") as tar:
tar.extractall(path=to_path)
elif _is_targz(from_path):
with tarfile.open(from_path, "r:gz") as tar:
tar.extractall(path=to_path)
elif _is_tarxz(from_path) and PY3:
# .tar.xz archive only supported in Python 3.x
with tarfile.open(from_path, "r:xz") as tar:
tar.extractall(path=to_path)
elif _is_gzip(from_path):
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
out_f.write(zip_f.read())
elif _is_zip(from_path):
with zipfile.ZipFile(from_path, "r") as z:
z.extractall(to_path)
else:
raise ValueError(f"Extraction of {from_path} not supported")

if remove_finished:
os.remove(from_path)


@under_review()
def _is_targz(filename):
return filename.endswith(".tar.gz")


@under_review()
def _is_tarxz(filename):
return filename.endswith(".tar.xz")


@under_review()
def _is_gzip(filename):
return filename.endswith(".gz") and not filename.endswith(".tar.gz")


@under_review()
def _is_tar(filename):
return filename.endswith(".tar")


@under_review()
def _is_zip(filename):
return filename.endswith(".zip")
66 changes: 66 additions & 0 deletions pl_bolts/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import gzip
import os
import tarfile
import zipfile
from typing import List, Optional

import torch
from torch.utils.data.dataset import random_split

Expand Down Expand Up @@ -55,3 +61,63 @@ def to_tensor(arrays: TArrays) -> torch.Tensor:
Tensor of the integers
"""
return torch.tensor(arrays)


def is_within_directory(directory: str, target: str) -> bool:
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)

prefix = os.path.commonprefix([abs_directory, abs_target])

return prefix == abs_directory


def safe_extract_tarfile(
tar: tarfile.TarFile,
path: str = ".",
members: Optional[List[tarfile.TarInfo]] = None,
*,
numeric_owner: bool = False,
) -> None:
for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise RuntimeError(f"Attempted Path Traversal in Tar File {tar.name} with member: {member.name}")

tar.extractall(path, members, numeric_owner=numeric_owner)


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
if to_path is None:
to_path = os.path.dirname(from_path)

extracted = False
for fn in (_extract_tar, _extract_gzip, _extract_zip):
try:
fn(from_path, to_path)
extracted = True
break
except (tarfile.TarError, zipfile.BadZipfile, OSError):
continue

if not extracted:
raise ValueError(f"Extraction of {from_path} not supported")

if remove_finished:
os.remove(from_path)


def _extract_tar(from_path: str, to_path: str) -> None:
with tarfile.open(from_path, "r:*") as tar:
safe_extract_tarfile(tar, path=to_path)


def _extract_gzip(from_path: str, to_path: str) -> None:
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
out_f.write(zip_f.read())


def _extract_zip(from_path: str, to_path: str) -> None:
with zipfile.ZipFile(from_path, "r") as z:
z.extractall(to_path)

0 comments on commit 9d757d6

Please sign in to comment.