diff --git a/src/result/result.py b/src/result/result.py index 30bf9a8..4943321 100644 --- a/src/result/result.py +++ b/src/result/result.py @@ -15,9 +15,11 @@ Iterator, Literal, NoReturn, + Optional, Type, TypeVar, Union, + overload, ) if sys.version_info >= (3, 10): @@ -138,7 +140,13 @@ def unwrap_or_else(self, op: object) -> T: """ return self._value - def unwrap_or_raise(self, e: object) -> T: + @overload + def unwrap_or_raise(self) -> T: + ... + @overload + def unwrap_or_raise(self, e: Type[TBE]) -> T: + ... + def unwrap_or_raise(self, e: Optional[Type[TBE]] = None) -> T: """ Return the value. """ @@ -350,11 +358,22 @@ def unwrap_or_else(self, op: Callable[[E], T]) -> T: """ return op(self._value) + + @overload + def unwrap_or_raise(self) -> NoReturn: + ... + @overload def unwrap_or_raise(self, e: Type[TBE]) -> NoReturn: + ... + def unwrap_or_raise(self, e: Optional[Type[TBE]] = None) -> NoReturn: """ The contained result is ``Err``, so raise the exception with the value. """ - raise e(self._value) + if e is not None: + raise e(self._value) + if isinstance(self._value, BaseException): + raise self._value + self.unwrap() def map(self, op: object) -> Err[E]: """ diff --git a/tests/test_result.py b/tests/test_result.py index c99c241..b8de404 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -157,6 +157,10 @@ def test_unwrap_or_raise() -> None: n.unwrap_or_raise(ValueError) assert exc_info.value.args == ('nay',) + n2 = Err(ValueError('nay')) + with pytest.raises(ValueError) as exc_info: + n2.unwrap_or_raise() + def test_map() -> None: o = Ok('yay')