diff --git a/src/result/__init__.py b/src/result/__init__.py index 93cdcc9..ec3106e 100644 --- a/src/result/__init__.py +++ b/src/result/__init__.py @@ -6,6 +6,7 @@ UnwrapError, as_async_result, as_result, + returns_result, is_ok, is_err, do, @@ -20,6 +21,7 @@ "UnwrapError", "as_async_result", "as_result", + "returns_result", "is_ok", "is_err", "do", diff --git a/src/result/result.py b/src/result/result.py index 8551239..6b24984 100644 --- a/src/result/result.py +++ b/src/result/result.py @@ -146,6 +146,12 @@ def unwrap_or_raise(self, e: object) -> T: """ return self._value + def unwrap_or_return(self) -> T: + """ + Return the value. + """ + return self._value + def map(self, op: Callable[[T], U]) -> Ok[U]: """ The contained result is `Ok`, so return `Ok` with original value mapped to @@ -358,6 +364,12 @@ def unwrap_or_raise(self, e: Type[TBE]) -> NoReturn: """ raise e(self._value) + def unwrap_or_return(self) -> NoReturn: + """ + The contained result is ``Err``, raise DoException with self. + """ + raise DoException(self) + def map(self, op: object) -> Err[E]: """ Return `Err` with the same value @@ -464,6 +476,24 @@ def result(self) -> Result[Any, Any]: return self._result +def returns_result() -> Callable[[Callable[P, Result[R, E]]], Callable[P, Result[R, E]]]: + """ + Make a decorator to turn a function into one that allows using unwrap_or_return. + """ + def decorator(f: Callable[P, Result[R, E]]) -> Callable[P, Result[R, E]]: + @functools.wraps(f) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[R, E]: + try: + return f(*args, **kwargs) + except DoException as e: + out: Err[E] = e.err # type: ignore + return out + + return wrapper + + return decorator + + def as_result( *exceptions: Type[TBE], ) -> Callable[[Callable[P, R]], Callable[P, Result[R, TBE]]]: diff --git a/tests/test_result.py b/tests/test_result.py index c99c241..11a0aaa 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -4,7 +4,7 @@ import pytest -from result import Err, Ok, OkErr, Result, UnwrapError, as_async_result, as_result +from result import Err, Ok, OkErr, Result, UnwrapError, as_async_result, as_result, returns_result def test_ok_factories() -> None: @@ -158,6 +158,22 @@ def test_unwrap_or_raise() -> None: assert exc_info.value.args == ('nay',) +def test_unwrap_or_return() -> None: + @returns_result() + def func(yay: bool) -> Result[str, str]: + if yay: + o = Ok('yay') + value = o.unwrap_or_return() + return Ok(value) + else: + n = Err('nay') + value = n.unwrap_or_return() + assert False + + assert func(True).ok() == 'yay' + assert func(False).err() == 'nay' + + def test_map() -> None: o = Ok('yay') n = Err('nay')