diff --git a/src/pinnwand/defensive.py b/src/pinnwand/defensive.py index 20323dc..0b64d53 100644 --- a/src/pinnwand/defensive.py +++ b/src/pinnwand/defensive.py @@ -22,7 +22,9 @@ ] = {} -def ratelimit(request: HTTPServerRequest, area: str = "global") -> bool: +def should_be_ratelimited( + request: HTTPServerRequest, area: str = "global" +) -> bool: """Test if a requesting IP is ratelimited for a certain area. Areas are different functionalities of the website, for example 'view' or 'input' to differentiate between creating new pastes (low volume) or high volume @@ -55,13 +57,13 @@ def ratelimit(request: HTTPServerRequest, area: str = "global") -> bool: return False -def ratelimit_endpoint(area: str): +def ratelimit(area: str): """A ratelimiting decorator for tornado's request handlers.""" def wrapper(func): @wraps(func) def inner(request_handler: RequestHandler, *args, **kwargs): - if ratelimit(request_handler.request, area): + if should_be_ratelimited(request_handler.request, area): raise error.RatelimitError() return func(request_handler, *args, **kwargs) diff --git a/src/pinnwand/handler/api_curl.py b/src/pinnwand/handler/api_curl.py index 813c328..1fca553 100644 --- a/src/pinnwand/handler/api_curl.py +++ b/src/pinnwand/handler/api_curl.py @@ -29,7 +29,7 @@ def write_error(self, status_code: int, **kwargs: Any) -> None: else: super().write_error(status_code, **kwargs) - @defensive.ratelimit_endpoint(area="create") + @defensive.ratelimit(area="create") def post(self) -> None: configuration: Configuration = ConfigurationProvider.get_config() diff --git a/src/pinnwand/handler/api_deprecated.py b/src/pinnwand/handler/api_deprecated.py index 8fe98de..f2fae8c 100644 --- a/src/pinnwand/handler/api_deprecated.py +++ b/src/pinnwand/handler/api_deprecated.py @@ -75,7 +75,7 @@ async def post(self) -> None: class Show(Base): """Show a paste on the deprecated API.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, slug: str) -> None: # type: ignore with manager.DatabaseManager.get_session() as session: paste = ( @@ -119,9 +119,8 @@ def check_xsrf_cookie(self) -> None: async def get(self) -> None: raise tornado.web.HTTPError(405) - @defensive.ratelimit_endpoint(area="create") + @defensive.ratelimit(area="create") async def post(self) -> None: - configuration: Configuration = ConfigurationProvider.get_config() lexer = self.get_body_argument("lexer") @@ -175,9 +174,8 @@ def check_xsrf_cookie(self) -> None: """No XSRF cookies on the API.""" return - @defensive.ratelimit_endpoint(area="delete") + @defensive.ratelimit(area="delete") async def post(self) -> None: - with manager.DatabaseManager.get_session() as session: paste = ( session.query(models.Paste) @@ -207,7 +205,7 @@ async def post(self) -> None: class Lexer(Base): """List lexers through the deprecated API.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self) -> None: self.write(utility.list_languages()) @@ -215,7 +213,7 @@ async def get(self) -> None: class Expiry(Base): """List expiries through the deprecated API.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self) -> None: configuration: Configuration = ConfigurationProvider.get_config() diff --git a/src/pinnwand/handler/api_v1.py b/src/pinnwand/handler/api_v1.py index b1b46b7..2980587 100644 --- a/src/pinnwand/handler/api_v1.py +++ b/src/pinnwand/handler/api_v1.py @@ -19,15 +19,13 @@ def write_error(self, status_code: int, **kwargs: Any) -> None: class Lexer(Base): - - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self) -> None: self.write(utility.list_languages()) class Expiry(Base): - - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self) -> None: configuration: Configuration = ConfigurationProvider.get_config() @@ -46,9 +44,8 @@ def check_xsrf_cookie(self) -> None: async def get(self) -> None: raise tornado.web.HTTPError(405) - @defensive.ratelimit_endpoint(area="create") + @defensive.ratelimit(area="create") async def post(self) -> None: - try: data = tornado.escape.json_decode(self.request.body) except json.decoder.JSONDecodeError: @@ -127,8 +124,7 @@ async def post(self) -> None: class PasteDetail(Base): - - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, slug: str) -> None: with manager.DatabaseManager.get_session() as session: paste = ( diff --git a/src/pinnwand/handler/website.py b/src/pinnwand/handler/website.py index 387b51a..a5c1ed3 100644 --- a/src/pinnwand/handler/website.py +++ b/src/pinnwand/handler/website.py @@ -82,7 +82,7 @@ class Create(Base): """The index page shows the new paste page with a list of all available lexers from Pygments.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, lexers: str = "") -> None: """Render the new paste form, optionally have a lexer preselected from the URL.""" @@ -112,7 +112,7 @@ async def get(self, lexers: str = "") -> None: paste=None, ) - @defensive.ratelimit_endpoint(area="create") + @defensive.ratelimit(area="create") async def post(self) -> None: """This is a historical endpoint to create pastes, pastes are marked as old-web and will get a warning on top of them to remove any access to @@ -174,7 +174,7 @@ class CreateAction(Base): """The create action is the 'new' way to create pastes and supports multi file pastes.""" - @defensive.ratelimit_endpoint(area="create") + @defensive.ratelimit(area="create") def post(self) -> None: # type: ignore """POST handler for the 'web' side of things.""" @@ -260,7 +260,7 @@ class Repaste(Base): """Repaste is a specific case of the paste page. It only works for pre- existing pastes and will prefill the textarea and lexer.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, slug: str) -> None: # type: ignore """Render the new paste form, optionally have a lexer preselected from the URL.""" @@ -293,7 +293,7 @@ async def get(self, slug: str) -> None: # type: ignore class Show(Base): """Show a paste.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, slug: str) -> None: # type: ignore """Fetch paste from database by slug and render the paste.""" @@ -360,7 +360,7 @@ async def get(self, slug: str) -> None: # type: ignore class FileRaw(Base): """Show a file as plaintext.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, file_id: str) -> None: # type: ignore """Get a file from the database and show it in the plain.""" @@ -391,7 +391,7 @@ async def get(self, file_id: str) -> None: # type: ignore class FileHex(Base): """Show a file as hexadecimal.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, file_id: str) -> None: # type: ignore """Get a file from the database and show it in hex.""" @@ -422,7 +422,7 @@ async def get(self, file_id: str) -> None: # type: ignore class PasteDownload(Base): """Download an entire paste.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, paste_id: str) -> None: # type: ignore """Get all files from the database and download them as a zipfile.""" @@ -469,7 +469,7 @@ async def get(self, paste_id: str) -> None: # type: ignore class FileDownload(Base): """Download a file.""" - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self, file_id: str) -> None: # type: ignore """Get a file from the database and download it in the plain.""" @@ -511,7 +511,7 @@ async def get(self, file_id: str) -> None: # type: ignore class Remove(Base): """Remove a paste.""" - @defensive.ratelimit_endpoint(area="delete") + @defensive.ratelimit(area="delete") async def get(self, removal: str) -> None: # type: ignore """Look up if the user visiting this page has the removal id for a certain paste. If they do they're authorized to remove the paste.""" @@ -549,7 +549,7 @@ class RestructuredTextPage(Base): def initialize(self, file: str) -> None: self.file = file - @defensive.ratelimit_endpoint(area="read") + @defensive.ratelimit(area="read") async def get(self) -> None: try: with open(path.page / self.file) as f: @@ -573,7 +573,6 @@ def initialize(self, path: str) -> None: self.path = path async def get(self) -> None: - try: with open(self.path, "rb") as f: self.write(f.read())