-
Notifications
You must be signed in to change notification settings - Fork 13
/
readurl.go
130 lines (112 loc) · 2.94 KB
/
readurl.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// Copyright 2022-2024 Sauce Labs Inc., all rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package forwarder
import (
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
)
// ReadURLString reads the resource identified by u through ReadURL and returns
// its contents converted to a string.
//
// It accepts the same URLs as ReadURL and passes rt through unchanged. If
// ReadURL fails, the empty string is returned together with that error.
func ReadURLString(u *url.URL, rt http.RoundTripper) (string, error) {
	raw, err := ReadURL(u, rt)
	if err != nil {
		return "", err
	}

	text := string(raw)
	return text, nil
}
// ReadURL reads the resource identified by u and returns its raw bytes.
//
// The scheme of u selects the source:
//   - "data": base64 payload carried in the URL itself (handled by readData),
//   - "file": a local file, or stdin when the path is "-" (handled by readFile),
//   - "http", "https": the body of a GET request sent through rt (handled by readHTTP).
//
// rt is only used for http and https URLs. A nil u or any other scheme
// results in an error.
func ReadURL(u *url.URL, rt http.RoundTripper) ([]byte, error) {
	// Guard against a nil URL so callers get an error instead of a panic
	// on the u.Scheme dereference below.
	if u == nil {
		return nil, errors.New("nil URL")
	}

	switch u.Scheme {
	case "data":
		return readData(u)
	case "file":
		return readFile(u)
	case "http", "https":
		return readHTTP(u, rt)
	default:
		// Keep this list in sync with the cases above.
		return nil, fmt.Errorf("unsupported scheme %q, supported schemes are: data, file, http and https", u.Scheme)
	}
}
// readData decodes the base64 payload carried in the opaque part of a data URL.
//
// A leading "//" is stripped from u.Opaque. If the remainder contains a comma,
// the text before the first comma must be exactly "base64" and the text after
// it is the payload; without a comma the whole remainder is the payload.
// The payload is decoded with base64.StdEncoding. On any error the returned
// slice is nil.
func readData(u *url.URL) ([]byte, error) {
	payload := strings.TrimPrefix(u.Opaque, "//")

	if comma := strings.Index(payload, ","); comma >= 0 {
		format, rest := payload[:comma], payload[comma+1:]
		if format != "base64" {
			return nil, errors.New("invalid data URI, the only supported format is: data:base64,<encoded data>")
		}
		payload = rest
	}

	decoded, err := base64.StdEncoding.DecodeString(payload)
	if err != nil {
		// Drop any partially decoded bytes; callers only get data on success.
		return nil, err
	}
	return decoded, nil
}
// readFile returns the contents of the local file named by a file URL.
//
// Only the path component may be set: a URL carrying a host, user info, query
// or fragment, or one with an empty path, is rejected with an error. The
// special path "-" reads from os.Stdin instead of opening a file. The source
// is handed to readAndClose, so it is closed once it has been read.
func readFile(u *url.URL) ([]byte, error) {
	// The cases are evaluated top to bottom, so the first offending
	// component determines the error message.
	switch {
	case u.Host != "":
		return nil, fmt.Errorf("invalid file URL %q, host is not allowed", u.String())
	case u.User != nil:
		return nil, fmt.Errorf("invalid file URL %q, user is not allowed", u.String())
	case u.RawQuery != "":
		return nil, fmt.Errorf("invalid file URL %q, query is not allowed", u.String())
	case u.Fragment != "":
		return nil, fmt.Errorf("invalid file URL %q, fragment is not allowed", u.String())
	case u.Path == "":
		return nil, fmt.Errorf("invalid file URL %q, path is empty", u.String())
	}

	src := os.Stdin
	if u.Path != "-" {
		file, err := os.Open(u.Path)
		if err != nil {
			return nil, err
		}
		src = file
	}
	return readAndClose(src)
}
// readAndClose reads r until EOF or the first read error and returns what
// io.ReadAll produced. r is always closed before returning; the error from
// Close is deliberately discarded.
func readAndClose(r io.ReadCloser) ([]byte, error) {
	defer func() {
		_ = r.Close()
	}()

	contents, err := io.ReadAll(r)
	return contents, err
}
// readHTTP issues a GET request for u using rt as the client transport and
// returns the response body.
//
// Any status other than 200 OK is reported as an error. The response body is
// closed before the function returns. No context is attached to the request.
func readHTTP(u *url.URL, rt http.RoundTripper) ([]byte, error) {
	request, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) //nolint:noctx // timeout is set in the transport
	if err != nil {
		return nil, err
	}

	client := &http.Client{Transport: rt}
	response, err := client.Do(request)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	if code := response.StatusCode; code != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code %d", code)
	}

	payload, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, err
	}
	return payload, nil
}
// ReadFileOrBase64 returns the bytes referred to by name.
//
// When name starts with "data:", the remainder is treated as the opaque part
// of a data URL and decoded by readData. Any other name is read as a file path
// with os.ReadFile.
func ReadFileOrBase64(name string) ([]byte, error) {
	const dataPrefix = "data:"

	if !strings.HasPrefix(name, dataPrefix) {
		return os.ReadFile(name)
	}

	dataURL := url.URL{
		Scheme: "data",
		Opaque: name[len(dataPrefix):],
	}
	return readData(&dataURL)
}