From 6d4cd7b71c335fd74b97af17d41ee235bcd0382c Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 15 Nov 2024 12:30:07 +0100 Subject: [PATCH] downloader: before/after request hooks --- .golangci.yml | 2 +- downloader/download.go | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/.golangci.yml b/.golangci.yml index 187f520..325e8e4 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -30,7 +30,7 @@ linters-settings: # If lower than 0, disable the check. # Default: 40 # lower this after refactoring - statements: 111 + statements: 115 govet: enable-all: true diff --git a/downloader/download.go b/downloader/download.go index 430e26e..18ab497 100644 --- a/downloader/download.go +++ b/downloader/download.go @@ -88,6 +88,8 @@ type Downloader struct { ifModifiedSince bool lastModified bool compareContent bool + beforeRequest func(*http.Request) + afterRequest func(*http.Response) } // New creates a new downloader for the given URL. @@ -189,6 +191,20 @@ func (d *Downloader) LimitDownloadSize(size int64) *Downloader { return d } +// BeforeRequest sets a function to run before making the HTTP request. +// This can be used to add headers, show user feedback, etc. +func (d *Downloader) BeforeRequest(fn func(*http.Request)) *Downloader { + d.beforeRequest = fn + return d +} + +// AfterRequest sets a function to run after the HTTP request has been made. +// This can be used to check the response, save cookies, etc. +func (d *Downloader) AfterRequest(fn func(*http.Response)) *Downloader { + d.afterRequest = fn + return d +} + // getDestInfo returns the modification time and file mode of the destination file. func (d *Downloader) getDestInfo() (time.Time, fs.FileMode) { dstInfo, err := os.Stat(d.destPath) @@ -504,11 +520,19 @@ func (d *Downloader) Download(ctx context.Context, url string) (bool, error) { d.logger.Trace("If-None-Match: ", etag) } + if d.beforeRequest != nil { + d.beforeRequest(req) + } + resp, err := d.httpClient.Do(req) if err != nil { return false, fmt.Errorf("failed http request for %s: %w", url, err) } + if d.afterRequest != nil { + d.afterRequest(resp) + } + defer resp.Body.Close() switch resp.StatusCode {