Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement authntication #10

Merged
merged 4 commits into from
Jan 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,28 @@ You can also execute a query with Session, Transaction and Aliases
session := "de131c80-84c0-417f-abdf-29ad781a7d04" //use UUID generator
data, err := gremlin.Query(`g.V().has("name", userName).valueMap()`).Bindings(gremlin.Bind{"userName": "john"}).Session(session).ManageTransaction(true).SetProcessor("session").Aliases(aliases).Exec()
```

Authentication
===
For authentication, you can set environment variables `GREMLIN_USER` and `GREMLIN_PASS` and create a `Client`, passing functional parameter `OptAuthEnv`

```go
auth := gremlin.OptAuthEnv()
client, err := gremlin.NewClient("ws://remote.example.com:443/gremlin", auth)
data, err = client.ExecQuery(`g.V()`)
if err != nil {
panic(err)
}
doStuffWith(data)
```

If you don't like environment variables you can authenticate passing username and password string in the following way:
```go
auth := gremlin.OptAuthUserPass("myusername", "mypass")
client, err := gremlin.NewClient("ws://remote.example.com:443/gremlin", auth)
data, err = client.ExecQuery(`g.V()`)
if err != nil {
panic(err)
}
doStuffWith(data)
```
169 changes: 169 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,183 @@
package gremlin

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gorilla/websocket"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
)

// Clients include the necessary info to connect to the server and the underlying socket
type Client struct {
Remote *url.URL
Ws *websocket.Conn
Auth []OptAuth
}

func NewClient(urlStr string, options ...OptAuth) (*Client, error) {
r, err := url.Parse(urlStr)
if err != nil {
return nil, err
}
dialer := websocket.Dialer{}
ws, _, err := dialer.Dial(urlStr, http.Header{})
if err != nil {
return nil, err
}
return &Client{Remote: r, Ws: ws, Auth: options}, nil
}

// Client executes the provided request
func (c *Client) ExecQuery(query string) ([]byte, error) {
req := Query(query)
return c.Exec(req)
}

func (c *Client) Exec(req *Request) ([]byte, error) {
requestMessage, err := GraphSONSerializer(req)
if err != nil {
return nil, err
}
fmt.Println(string(requestMessage))
// Open a TCP connection
if err = c.Ws.WriteMessage(websocket.BinaryMessage, requestMessage); err != nil {
print("error", err)
return nil, err
}
return c.ReadResponse()
}

func (c *Client) ReadResponse() (data []byte, err error) {
// Data buffer
var message []byte
var dataItems []json.RawMessage
inBatchMode := false
// Receive data
for {
if _, message, err = c.Ws.ReadMessage(); err != nil {
return
}
var res *Response
if err = json.Unmarshal(message, &res); err != nil {
return
}
var items []json.RawMessage
switch res.Status.Code {
case StatusNoContent:
return

case StatusAuthenticate:
return c.Authenticate(res.RequestId)
case StatusPartialContent:
inBatchMode = true
if err = json.Unmarshal(res.Result.Data, &items); err != nil {
return
}
dataItems = append(dataItems, items...)

case StatusSuccess:
if inBatchMode {
if err = json.Unmarshal(res.Result.Data, &items); err != nil {
return
}
dataItems = append(dataItems, items...)
data, err = json.Marshal(dataItems)
} else {
data = res.Result.Data
}
return

default:
fmt.Println(res)
if msg, exists := ErrorMsg[res.Status.Code]; exists {
err = errors.New(msg)
} else {
err = errors.New("An unknown error occured")
}
return
}
}
return
}

// AuthInfo includes all info related with SASL authentication with the Gremlin server
// ChallengeId is the requestID in the 407 status (AUTHENTICATE) response given by the server.
// We have to send an authentication request with that same RequestID in order to solve the challenge.
type AuthInfo struct {
ChallengeId string
User string
Pass string
}

type OptAuth func(*AuthInfo) error

// Constructor for different authentication possibilities
func NewAuthInfo(options ...OptAuth) (*AuthInfo, error) {
auth := &AuthInfo{}
for _, op := range options {
err := op(auth)
if err != nil {
return nil, err
}
}
return auth, nil
}

// Sets authentication info from environment variables GREMLIN_USER and GREMLIN_PASS
func OptAuthEnv() OptAuth {
return func(auth *AuthInfo) error {
user, ok := os.LookupEnv("GREMLIN_USER")
if !ok {
return errors.New("Variable GREMLIN_USER is not set")
}
pass, ok := os.LookupEnv("GREMLIN_PASS")
if !ok {
return errors.New("Variable GREMLIN_PASS is not set")
}
auth.User = user
auth.Pass = pass
return nil
}
}

// Sets authentication information from username and password
func OptAuthUserPass(user, pass string) OptAuth {
return func(auth *AuthInfo) error {
auth.User = user
auth.Pass = pass
return nil
}
}

// Authenticates the connection
func (c *Client) Authenticate(requestId string) ([]byte, error) {
auth, err := NewAuthInfo(c.Auth...)
if err != nil {
return nil, err
}
var sasl []byte
sasl = append(sasl, 0)
sasl = append(sasl, []byte(auth.User)...)
sasl = append(sasl, 0)
sasl = append(sasl, []byte(auth.Pass)...)
saslEnc := base64.StdEncoding.EncodeToString(sasl)
args := &RequestArgs{Sasl: saslEnc}
authReq := &Request{
RequestId: requestId,
Processor: "trasversal",
Op: "authentication",
Args: args,
}
return c.Exec(authReq)
}

var servers []*url.URL

func NewCluster(s ...string) (err error) {
Expand Down
33 changes: 32 additions & 1 deletion request.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package gremlin

import (
"encoding/json"
_ "fmt"
"github.com/satori/go.uuid"
)

Expand All @@ -17,12 +19,41 @@ type RequestArgs struct {
Bindings Bind `json:"bindings,omitempty"`
Language string `json:"language,omitempty"`
Rebindings Bind `json:"rebindings,omitempty"`
Sasl []byte `json:"sasl,omitempty"`
Sasl string `json:"sasl,omitempty"`
BatchSize int `json:"batchSize,omitempty"`
ManageTransaction bool `json:"manageTransaction,omitempty"`
Aliases map[string]string `json:"aliases,omitempty"`
}

// Formats the requests in the appropriate way
type FormattedReq struct {
Op string `json:"op"`
RequestId interface{} `json:"requestId"`
Args *RequestArgs `json:"args"`
Processor string `json:"processor"`
}

func GraphSONSerializer(req *Request) ([]byte, error) {
form := NewFormattedReq(req)
msg, err := json.Marshal(form)
if err != nil {
return nil, err
}
mimeType := []byte("application/vnd.gremlin-v2.0+json")
var mimeLen = []byte{0x21}
res := append(mimeLen, mimeType...)
res = append(res, msg...)
return res, nil

}

func NewFormattedReq(req *Request) FormattedReq {
rId := map[string]string{"@type": "g:UUID", "@value": req.RequestId}
sr := FormattedReq{RequestId: rId, Processor: req.Processor, Op: req.Op, Args: req.Args}

return sr
}

type Bind map[string]interface{}

func Query(query string) *Request {
Expand Down
93 changes: 4 additions & 89 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ package gremlin

import (
"encoding/json"
"errors"
"github.com/gorilla/websocket"
"net/http"
"time"
"fmt"
)

type Response struct {
Expand All @@ -25,89 +22,7 @@ type ResponseResult struct {
Meta map[string]interface{} `json:"meta"`
}

func ReadResponse(ws *websocket.Conn) (data []byte, err error) {
// Data buffer
var message []byte
var dataItems []json.RawMessage
inBatchMode := false
// Receive data
for {
if _, message, err = ws.ReadMessage(); err != nil {
return
}
var res *Response
if err = json.Unmarshal(message, &res); err != nil {
return
}
var items []json.RawMessage
switch res.Status.Code {
case StatusNoContent:
return

case StatusPartialContent:
inBatchMode = true
if err = json.Unmarshal(res.Result.Data, &items); err != nil {
return
}
dataItems = append(dataItems, items...)

case StatusSuccess:
if inBatchMode {
if err = json.Unmarshal(res.Result.Data, &items); err != nil {
return
}
dataItems = append(dataItems, items...)
data, err = json.Marshal(dataItems)
} else {
data = res.Result.Data
}
return

default:
if msg, exists := ErrorMsg[res.Status.Code]; exists {
err = errors.New(msg)
} else {
err = errors.New("An unknown error occured")
}
return
}
}
return
}

func (req *Request) Exec() (data []byte, err error) {
// Prepare the Data
message, err := json.Marshal(req)
if err != nil {
return
}
// Prepare the request message
var requestMessage []byte
mimeType := []byte("application/json")
mimeTypeLen := byte(len(mimeType))
requestMessage = append(requestMessage, mimeTypeLen)
requestMessage = append(requestMessage, mimeType...)
requestMessage = append(requestMessage, message...)
// Open a TCP connection
conn, server, err := CreateConnection()
if err != nil {
return
}
// Open a new socket connection
ws, _, err := websocket.NewClient(conn, server, http.Header{}, 0, len(requestMessage))
if err != nil {
return
}
defer ws.Close()
if err = ws.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
return
}
if err = ws.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
return
}
if err = ws.WriteMessage(websocket.BinaryMessage, requestMessage); err != nil {
return
}

return ReadResponse(ws)
// Implementation of the stringer interface. Useful for exploration
func (r Response) String() string {
return fmt.Sprintf("Response \nRequestId: %v, \nStatus: {%#v}, \nResult: {%#v}\n", r.RequestId, r.Status, r.Result)
}