diff --git a/taosRestful/connection.go b/taosRestful/connection.go index e7f21c5..4970184 100644 --- a/taosRestful/connection.go +++ b/taosRestful/connection.go @@ -3,6 +3,7 @@ package taosRestful import ( "compress/gzip" "context" + "crypto/tls" "database/sql/driver" "encoding/base64" "errors" @@ -36,18 +37,24 @@ func newTaosConn(cfg *config) (*taosConn, error) { readBufferSize = 4 << 10 } tc := &taosConn{cfg: cfg, readBufferSize: readBufferSize} + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DisableCompression: cfg.disableCompression, + } + if cfg.skipVerify { + transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } tc.client = &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - DisableCompression: cfg.disableCompression, - }, + Transport: transport, } path := "/rest/sql" if len(cfg.dbName) != 0 { diff --git a/taosRestful/connector_test.go b/taosRestful/connector_test.go index f8bf6f7..eac38db 100644 --- a/taosRestful/connector_test.go +++ b/taosRestful/connector_test.go @@ -1,9 +1,22 @@ package taosRestful import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + crand "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "database/sql" + "encoding/pem" "fmt" + "log" + "math/big" "math/rand" + "net/http" + "net/http/httputil" + "net/url" "reflect" "strings" "testing" @@ -377,3 +390,143 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) { assert.Nil(t, *values[i].(*interface{})) } } + +func generateSelfSignedCert() (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P384(), crand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) + + serialNumber, err := crand.Int(crand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return tls.Certificate{}, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Your Company"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(crand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyPEM, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return tls.Certificate{}, err + } + + keyPEMBlock := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyPEM}) + + return tls.X509KeyPair(certPEM, keyPEMBlock) +} + +func startProxy() *http.Server { + // Generate self-signed certificate + cert, err := generateSelfSignedCert() + if err != nil { + log.Fatalf("Failed to generate self-signed certificate: %v", err) + } + + target := "http://127.0.0.1:6041" + proxyURL, err := url.Parse(target) + if err != nil { + log.Fatalf("Failed to parse target URL: %v", err) + } + + proxy := httputil.NewSingleHostReverseProxy(proxyURL) + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, e error) { + http.Error(w, "Proxy error", http.StatusBadGateway) + } + mux := http.NewServeMux() + mux.Handle("/", proxy) + + server := &http.Server{ + Addr: ":34443", + Handler: mux, + TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}}, + // Setup server timeouts for better handling of idle connections and slowloris attacks + WriteTimeout: 10 * time.Second, + ReadTimeout: 10 * time.Second, + IdleTimeout: 30 * time.Second, + } + + log.Println("Starting server on :34443") + go func() { + err = server.ListenAndServeTLS("", "") + if err != nil && err != http.ErrServerClosed { + log.Fatalf("Failed to start HTTPS server: %v", err) + } + }() + return server +} + +func TestSSL(t *testing.T) { + dataSourceNameWithSkipVerify := fmt.Sprintf("%s:%s@https(%s:%d)/?skipVerify=true", user, password, host, 34443) + server := startProxy() + defer server.Shutdown(context.Background()) + time.Sleep(1 * time.Second) + database := "restful_test_ssl" + db, err := sql.Open("taosRestful", dataSourceNameWithSkipVerify) + if err != nil { + t.Fatal(err) + } + defer db.Close() + err = db.Ping() + if err != nil { + t.Fatal(err) + } + defer func() { + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) + if err != nil { + t.Fatal(err) + } + }() + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) + if err != nil { + t.Fatal(err) + } + _, err = db.Exec(generateCreateTableSql(database, true)) + if err != nil { + t.Fatal(err) + } + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql)) + if err != nil { + t.Fatal(err) + } + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) + assert.NoError(t, err) + columns, err := rows.Columns() + assert.NoError(t, err) + t.Log(columns) + cTypes, err := rows.ColumnTypes() + assert.NoError(t, err) + t.Log(cTypes) + var tt types.RawMessage + dest := make([]interface{}, len(scanValues)+1) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } + dest[len(scanValues)] = &tt + for rows.Next() { + err := rows.Scan(dest...) + assert.NoError(t, err) + } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) +} diff --git a/taosRestful/dsn.go b/taosRestful/dsn.go index 3936231..612962a 100644 --- a/taosRestful/dsn.go +++ b/taosRestful/dsn.go @@ -30,6 +30,7 @@ type config struct { disableCompression bool readBufferSize int token string // cloud platform token + skipVerify bool } // NewConfig creates a new Config and sets default values. @@ -154,6 +155,11 @@ func parseDSNParams(cfg *config, params string) (err error) { } case "token": cfg.token = value + case "skipVerify": + cfg.skipVerify, err = strconv.ParseBool(value) + if err != nil { + return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} + } default: // lazy init if cfg.params == nil { diff --git a/taosRestful/dsn_test.go b/taosRestful/dsn_test.go index e71b813..2ec3fbd 100644 --- a/taosRestful/dsn_test.go +++ b/taosRestful/dsn_test.go @@ -10,15 +10,16 @@ import ( // @description: test parse dsn func TestParseDsn(t *testing.T) { tcs := []struct { - dsn string - errs string - user string - passwd string - net string - addr string - port int - dbName string - token string + dsn string + errs string + user string + passwd string + net string + addr string + port int + dbName string + token string + skipVerify bool }{{}, {dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, {dsn: "user:passwd@http(fqdn:6041)/dbname", user: "user", passwd: "passwd", net: "http", addr: "fqdn", port: 6041, dbName: "dbname"}, @@ -28,6 +29,7 @@ func TestParseDsn(t *testing.T) { {dsn: "user:passwd@https(:0)/", user: "user", passwd: "passwd", net: "https"}, {dsn: "user:passwd@https(:0)/?interpolateParams=false&test=1", user: "user", passwd: "passwd", net: "https"}, {dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token", user: "user", passwd: "passwd", net: "https", token: "token"}, + {dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", user: "user", passwd: "passwd", net: "https", token: "token", skipVerify: true}, } for i, tc := range tcs { name := fmt.Sprintf("%d - %s", i, tc.dsn) @@ -45,7 +47,9 @@ func TestParseDsn(t *testing.T) { cfg.passwd != tc.passwd || cfg.net != tc.net || cfg.addr != tc.addr || - cfg.port != tc.port { + cfg.port != tc.port || + cfg.token != tc.token || + cfg.skipVerify != tc.skipVerify { t.Fatal(cfg) } })