From 3578bf7c6bdc90e68956e7730086230f8ab138d7 Mon Sep 17 00:00:00 2001 From: Marco Sirabella Date: Wed, 20 Mar 2024 17:26:30 -0700 Subject: [PATCH] MAINT: Allow opening PdfReader as context manager This mirrors PdfWriter and also hints at file pointer management, now that we sometimes keep files open. --- pypdf/_reader.py | 21 +++++++++++++++++++++ tests/test_reader.py | 4 ++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pypdf/_reader.py b/pypdf/_reader.py index 3c5b60da7..2d3c73d8e 100644 --- a/pypdf/_reader.py +++ b/pypdf/_reader.py @@ -32,6 +32,7 @@ import weakref from io import BytesIO, FileIO, UnsupportedOperation from pathlib import Path +from types import TracebackType from typing import ( Any, Callable, @@ -40,6 +41,7 @@ List, Optional, Tuple, + Type, Union, cast, ) @@ -100,6 +102,9 @@ class PdfReader(PdfDocCommon): password: Decrypt PDF file at initialization. If the password is None, the file will not be decrypted. Defaults to ``None``. + + Can also be instantiated as a contextmanager which will automatically close + the underlying file pointer if passed via filenames. """ def __init__( @@ -123,8 +128,10 @@ def __init__( __name__, ) + self._opened_automatically = False if isinstance(stream, (str, Path)): stream = FileIO(stream, "rb") + self._opened_automatically = True weakref.finalize(self, stream.close) self.read(stream) @@ -160,6 +167,20 @@ def close(self) -> None: """Close the underlying file handle""" self.stream.close() + def __enter__(self) -> "PdfReader": + """Use PdfReader as context manager""" + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Close the underlying stream if owned by the PdfReader""" + if self._opened_automatically: + self.close() + @property def root_object(self) -> DictionaryObject: """Provide access to "/Root". 
standardized with PdfWriter.""" diff --git a/tests/test_reader.py b/tests/test_reader.py index 83b61bc59..da093281d 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -853,8 +853,8 @@ def test_extract_text_hello_world(): def test_read_path(): path = Path(RESOURCE_ROOT, "crazyones.pdf") - reader = PdfReader(path) - assert len(reader.pages) == 1 + with PdfReader(path) as reader: + assert len(reader.pages) == 1 def test_read_not_binary_mode(caplog):