From f0e3a08475e58316c2f6141590eba5899a4ef693 Mon Sep 17 00:00:00 2001 From: Miguel Palau Date: Thu, 25 Jul 2024 08:25:25 -0600 Subject: [PATCH] Expose dbsql.connOption type (#202) Fixes: https://github.com/databricks/databricks-sql-go/issues/201 That way we can programmatically add arguments to the `NewConnector` function instead of copy/pasting all of them across conditionals. Signed-off-by: Miguel Palau --- connector.go | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/connector.go b/connector.go index bab81fd..87c73c9 100644 --- a/connector.go +++ b/connector.go @@ -51,7 +51,6 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { }, CanUseMultipleCatalogs: &c.cfg.CanUseMultipleCatalogs, }) - if err != nil { return nil, dbsqlerrint.NewRequestError(ctx, fmt.Sprintf("error connecting: host=%s port=%d, httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath), err) } @@ -84,11 +83,11 @@ func (c *connector) Driver() driver.Driver { var _ driver.Connector = (*connector)(nil) -type connOption func(*config.Config) +type ConnOption func(*config.Config) // NewConnector creates a connection that can be used with `sql.OpenDB()`. // This is an easier way to set up the DB instead of having to construct a DSN string. -func NewConnector(options ...connOption) (driver.Connector, error) { +func NewConnector(options ...ConnOption) (driver.Connector, error) { // config with default options cfg := config.WithDefaults() cfg.DriverVersion = DriverVersion @@ -102,14 +101,14 @@ func NewConnector(options ...connOption) (driver.Connector, error) { return &connector{cfg: cfg, client: client}, nil } -func withUserConfig(ucfg config.UserConfig) connOption { +func withUserConfig(ucfg config.UserConfig) ConnOption { return func(c *config.Config) { c.UserConfig = ucfg } } // WithServerHostname sets up the server hostname. Mandatory. -func WithServerHostname(host string) connOption { +func WithServerHostname(host string) ConnOption { return func(c *config.Config) { protocol, hostname := parseHostName(host) if protocol != "" { @@ -143,7 +142,7 @@ func parseHostName(host string) (protocol, hostname string) { } // WithPort sets up the server port. Mandatory. -func WithPort(port int) connOption { +func WithPort(port int) ConnOption { return func(c *config.Config) { c.Port = port } @@ -153,7 +152,7 @@ func WithPort(port int) connOption { // By default retryWaitMin = 1 * time.Second // By default retryWaitMax = 30 * time.Second // By default retryMax = 4 -func WithRetries(retryMax int, retryWaitMin time.Duration, retryWaitMax time.Duration) connOption { +func WithRetries(retryMax int, retryWaitMin time.Duration, retryWaitMax time.Duration) ConnOption { return func(c *config.Config) { c.RetryWaitMax = retryWaitMax c.RetryWaitMin = retryWaitMin @@ -162,7 +161,7 @@ func WithRetries(retryMax int, retryWaitMin time.Duration, retryWaitMax time.Dur } // WithAccessToken sets up the Personal Access Token. Mandatory for now. -func WithAccessToken(token string) connOption { +func WithAccessToken(token string) ConnOption { return func(c *config.Config) { if token != "" { c.AccessToken = token @@ -175,7 +174,7 @@ func WithAccessToken(token string) connOption { } // WithHTTPPath sets up the endpoint to the warehouse. Mandatory. -func WithHTTPPath(path string) connOption { +func WithHTTPPath(path string) ConnOption { return func(c *config.Config) { if !strings.HasPrefix(path, "/") { path = "/" + path @@ -185,7 +184,7 @@ func WithHTTPPath(path string) connOption { } // WithMaxRows sets up the max rows fetched per request. Default is 10000 -func WithMaxRows(n int) connOption { +func WithMaxRows(n int) ConnOption { return func(c *config.Config) { if n != 0 { c.MaxRows = n @@ -194,7 +193,7 @@ func WithMaxRows(n int) connOption { } // WithTimeout adds timeout for the server query execution. Default is no timeout. -func WithTimeout(n time.Duration) connOption { +func WithTimeout(n time.Duration) ConnOption { return func(c *config.Config) { c.QueryTimeout = n } @@ -202,7 +201,7 @@ func WithTimeout(n time.Duration) connOption { // Sets the initial catalog name and schema name in the session. // Use -func WithInitialNamespace(catalog, schema string) connOption { +func WithInitialNamespace(catalog, schema string) ConnOption { return func(c *config.Config) { c.Catalog = catalog c.Schema = schema @@ -210,7 +209,7 @@ func WithInitialNamespace(catalog, schema string) connOption { } // Used to identify partners. Set as a string with format . -func WithUserAgentEntry(entry string) connOption { +func WithUserAgentEntry(entry string) ConnOption { return func(c *config.Config) { c.UserAgentEntry = entry } @@ -218,7 +217,7 @@ func WithUserAgentEntry(entry string) connOption { // Sessions params will be set upon opening the session by calling SET function. // If using connection pool, session params can avoid successive calls of "SET ..." -func WithSessionParams(params map[string]string) connOption { +func WithSessionParams(params map[string]string) ConnOption { return func(c *config.Config) { for k, v := range params { if strings.ToLower(k) == "timezone" { @@ -227,7 +226,6 @@ func WithSessionParams(params map[string]string) connOption { } else { c.Location = loc } - } } c.SessionParams = params @@ -249,35 +247,35 @@ func WithSkipTLSHostVerify() connOption { } // WithAuthenticator sets up the Authentication. Mandatory if access token is not provided. -func WithAuthenticator(authr auth.Authenticator) connOption { +func WithAuthenticator(authr auth.Authenticator) ConnOption { return func(c *config.Config) { c.Authenticator = authr } } // WithTransport sets up the transport configuration to be used by the httpclient. -func WithTransport(t http.RoundTripper) connOption { +func WithTransport(t http.RoundTripper) ConnOption { return func(c *config.Config) { c.Transport = t } } // WithCloudFetch sets up the use of cloud fetch for query execution. Default is false. -func WithCloudFetch(useCloudFetch bool) connOption { +func WithCloudFetch(useCloudFetch bool) ConnOption { return func(c *config.Config) { c.UseCloudFetch = useCloudFetch } } // WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10. -func WithMaxDownloadThreads(numThreads int) connOption { +func WithMaxDownloadThreads(numThreads int) ConnOption { return func(c *config.Config) { c.MaxDownloadThreads = numThreads } } // Setup of Oauth M2m authentication -func WithClientCredentials(clientID, clientSecret string) connOption { +func WithClientCredentials(clientID, clientSecret string) ConnOption { return func(c *config.Config) { if clientID != "" && clientSecret != "" { authr := m2m.NewAuthenticator(clientID, clientSecret, c.Host)