Skip to content

Commit

Permalink
Merge pull request #267 from taosdata/enh/xftan/TS-4816-3.1
Browse files Browse the repository at this point in the history
enh: support skip ssl verify
  • Loading branch information
huskar-t authored May 22, 2024
2 parents 1a113a6 + 523d7a2 commit d2ee5d0
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 21 deletions.
29 changes: 18 additions & 11 deletions taosRestful/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package taosRestful
import (
"compress/gzip"
"context"
"crypto/tls"
"database/sql/driver"
"encoding/base64"
"errors"
Expand Down Expand Up @@ -36,18 +37,24 @@ func newTaosConn(cfg *config) (*taosConn, error) {
readBufferSize = 4 << 10
}
tc := &taosConn{cfg: cfg, readBufferSize: readBufferSize}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: cfg.disableCompression,
}
if cfg.skipVerify {
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
tc.client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: cfg.disableCompression,
},
Transport: transport,
}
path := "/rest/sql"
if len(cfg.dbName) != 0 {
Expand Down
153 changes: 153 additions & 0 deletions taosRestful/connector_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
package taosRestful

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
crand "crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"database/sql"
"encoding/pem"
"fmt"
"log"
"math/big"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -377,3 +390,143 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) {
assert.Nil(t, *values[i].(*interface{}))
}
}

func generateSelfSignedCert() (tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P384(), crand.Reader)
if err != nil {
return tls.Certificate{}, err
}

notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour)

serialNumber, err := crand.Int(crand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return tls.Certificate{}, err
}

template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Your Company"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}

certDER, err := x509.CreateCertificate(crand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return tls.Certificate{}, err
}

certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
keyPEM, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return tls.Certificate{}, err
}

keyPEMBlock := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyPEM})

return tls.X509KeyPair(certPEM, keyPEMBlock)
}

func startProxy() *http.Server {
// Generate self-signed certificate
cert, err := generateSelfSignedCert()
if err != nil {
log.Fatalf("Failed to generate self-signed certificate: %v", err)
}

target := "http://127.0.0.1:6041"
proxyURL, err := url.Parse(target)
if err != nil {
log.Fatalf("Failed to parse target URL: %v", err)
}

proxy := httputil.NewSingleHostReverseProxy(proxyURL)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, e error) {
http.Error(w, "Proxy error", http.StatusBadGateway)
}
mux := http.NewServeMux()
mux.Handle("/", proxy)

server := &http.Server{
Addr: ":34443",
Handler: mux,
TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}},
// Setup server timeouts for better handling of idle connections and slowloris attacks
WriteTimeout: 10 * time.Second,
ReadTimeout: 10 * time.Second,
IdleTimeout: 30 * time.Second,
}

log.Println("Starting server on :34443")
go func() {
err = server.ListenAndServeTLS("", "")
if err != nil && err != http.ErrServerClosed {
log.Fatalf("Failed to start HTTPS server: %v", err)
}
}()
return server
}

func TestSSL(t *testing.T) {
dataSourceNameWithSkipVerify := fmt.Sprintf("%s:%s@https(%s:%d)/?skipVerify=true", user, password, host, 34443)
server := startProxy()
defer server.Shutdown(context.Background())
time.Sleep(1 * time.Second)
database := "restful_test_ssl"
db, err := sql.Open("taosRestful", dataSourceNameWithSkipVerify)
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Ping()
if err != nil {
t.Fatal(err)
}
defer func() {
_, err = db.Exec(fmt.Sprintf("drop database if exists %s", database))
if err != nil {
t.Fatal(err)
}
}()
_, err = db.Exec(fmt.Sprintf("create database if not exists %s", database))
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(generateCreateTableSql(database, true))
if err != nil {
t.Fatal(err)
}
colValues, scanValues, insertSql := generateValues()
_, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql))
if err != nil {
t.Fatal(err)
}
rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano)))
assert.NoError(t, err)
columns, err := rows.Columns()
assert.NoError(t, err)
t.Log(columns)
cTypes, err := rows.ColumnTypes()
assert.NoError(t, err)
t.Log(cTypes)
var tt types.RawMessage
dest := make([]interface{}, len(scanValues)+1)
for i := range scanValues {
dest[i] = reflect.ValueOf(&scanValues[i]).Interface()
}
dest[len(scanValues)] = &tt
for rows.Next() {
err := rows.Scan(dest...)
assert.NoError(t, err)
}
for i, v := range colValues {
assert.Equal(t, v, scanValues[i])
}
assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt)
}
6 changes: 6 additions & 0 deletions taosRestful/dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type config struct {
disableCompression bool
readBufferSize int
token string // cloud platform token
skipVerify bool
}

// NewConfig creates a new Config and sets default values.
Expand Down Expand Up @@ -154,6 +155,11 @@ func parseDSNParams(cfg *config, params string) (err error) {
}
case "token":
cfg.token = value
case "skipVerify":
cfg.skipVerify, err = strconv.ParseBool(value)
if err != nil {
return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value}
}
default:
// lazy init
if cfg.params == nil {
Expand Down
24 changes: 14 additions & 10 deletions taosRestful/dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ import (
// @description: test parse dsn
func TestParseDsn(t *testing.T) {
tcs := []struct {
dsn string
errs string
user string
passwd string
net string
addr string
port int
dbName string
token string
dsn string
errs string
user string
passwd string
net string
addr string
port int
dbName string
token string
skipVerify bool
}{{},
{dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"},
{dsn: "user:passwd@http(fqdn:6041)/dbname", user: "user", passwd: "passwd", net: "http", addr: "fqdn", port: 6041, dbName: "dbname"},
Expand All @@ -28,6 +29,7 @@ func TestParseDsn(t *testing.T) {
{dsn: "user:passwd@https(:0)/", user: "user", passwd: "passwd", net: "https"},
{dsn: "user:passwd@https(:0)/?interpolateParams=false&test=1", user: "user", passwd: "passwd", net: "https"},
{dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token", user: "user", passwd: "passwd", net: "https", token: "token"},
{dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", user: "user", passwd: "passwd", net: "https", token: "token", skipVerify: true},
}
for i, tc := range tcs {
name := fmt.Sprintf("%d - %s", i, tc.dsn)
Expand All @@ -45,7 +47,9 @@ func TestParseDsn(t *testing.T) {
cfg.passwd != tc.passwd ||
cfg.net != tc.net ||
cfg.addr != tc.addr ||
cfg.port != tc.port {
cfg.port != tc.port ||
cfg.token != tc.token ||
cfg.skipVerify != tc.skipVerify {
t.Fatal(cfg)
}
})
Expand Down

0 comments on commit d2ee5d0

Please sign in to comment.