diff --git a/cmd/start.go b/cmd/start.go index 3e1bb15..19bc922 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -17,6 +17,9 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" rpchttp "github.com/tendermint/tendermint/rpc/client/http" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" "github.com/sentinel-official/dvpn-node/context" "github.com/sentinel-official/dvpn-node/lite" @@ -34,8 +37,9 @@ func StartCmd() *cobra.Command { Short: "Start VPN node", RunE: func(cmd *cobra.Command, _ []string) error { var ( - home = viper.GetString(flags.FlagHome) - path = filepath.Join(home, types.ConfigFileName) + home = viper.GetString(flags.FlagHome) + configPath = filepath.Join(home, types.ConfigFileName) + databasePath = filepath.Join(home, types.DatabaseFileName) ) log, err := utils.PrepareLogger() @@ -44,9 +48,9 @@ func StartCmd() *cobra.Command { } v := viper.New() - v.SetConfigFile(path) + v.SetConfigFile(configPath) - log.Info("Reading the configuration file", "path", path) + log.Info("Reading the configuration file", "path", configPath) cfg, err := types.ReadInConfig(v) if err != nil { return err @@ -158,6 +162,22 @@ func StartCmd() *cobra.Command { return err } + log.Info("Opening the database", "path", databasePath) + database, err := gorm.Open( + sqlite.Open(databasePath), + &gorm.Config{ + Logger: logger.Discard, + }, + ) + if err != nil { + return err + } + + log.Info("Migrating database models...") + if err := database.AutoMigrate(&types.Session{}); err != nil { + return err + } + var ( ctx = context.NewContext() router = mux.NewRouter() @@ -172,7 +192,7 @@ func StartCmd() *cobra.Command { WithConfig(cfg). WithClient(client). WithLocation(location). - WithSessions(types.NewSessions()). + WithDatabase(database). WithBandwidth(bandwidth) n := node.NewNode(ctx) diff --git a/context/context.go b/context/context.go index 19fe893..c6d2177 100644 --- a/context/context.go +++ b/context/context.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/mux" hubtypes "github.com/sentinel-official/hub/types" tmlog "github.com/tendermint/tendermint/libs/log" + "gorm.io/gorm" "github.com/sentinel-official/dvpn-node/lite" "github.com/sentinel-official/dvpn-node/types" @@ -18,9 +19,9 @@ type Context struct { bandwidth *hubtypes.Bandwidth client *lite.Client config *types.Config + database *gorm.DB location *types.GeoIPLocation router *mux.Router - sessions *types.Sessions } func NewContext() *Context { @@ -34,7 +35,7 @@ func (c *Context) WithLocation(v *types.GeoIPLocation) *Context { c.location = v func (c *Context) WithLogger(v tmlog.Logger) *Context { c.logger = v; return c } func (c *Context) WithRouter(v *mux.Router) *Context { c.router = v; return c } func (c *Context) WithService(v types.Service) *Context { c.service = v; return c } -func (c *Context) WithSessions(v *types.Sessions) *Context { c.sessions = v; return c } +func (c *Context) WithDatabase(v *gorm.DB) *Context { c.database = v; return c } func (c *Context) Address() hubtypes.NodeAddress { return c.Operator().Bytes() } func (c *Context) Bandwidth() *hubtypes.Bandwidth { return c.bandwidth } @@ -50,7 +51,7 @@ func (c *Context) Operator() sdk.AccAddress { return c.client.FromAdd func (c *Context) RemoteURL() string { return c.Config().Node.RemoteURL } func (c *Context) Router() *mux.Router { return c.router } func (c *Context) Service() types.Service { return c.service } -func (c *Context) Sessions() *types.Sessions { return c.sessions } +func (c *Context) Database() *gorm.DB { return c.database } func (c *Context) IntervalUpdateSessions() time.Duration { return c.Config().Node.IntervalUpdateSessions diff --git a/context/service.go b/context/service.go index 7dd32a8..12fb0bc 100644 --- a/context/service.go +++ b/context/service.go @@ -2,8 +2,6 @@ package context import ( "encoding/base64" - - sdk "github.com/cosmos/cosmos-sdk/types" ) func (c *Context) RemovePeer(key string) error { @@ -22,23 +20,3 @@ func (c *Context) RemovePeer(key string) error { return nil } - -func (c *Context) RemoveSession(key string, address sdk.AccAddress) error { - c.Log().Info("Removing session from list", "key", key, "address", address) - - c.Sessions().DeleteByKey(key) - c.Sessions().DeleteByAddress(address) - - return nil -} - -func (c *Context) RemovePeerAndSession(key string, address sdk.AccAddress) error { - if err := c.RemovePeer(key); err != nil { - return err - } - if err := c.RemoveSession(key, address); err != nil { - return err - } - - return nil -} diff --git a/context/tx.go b/context/tx.go index 9cc954f..42b10d4 100644 --- a/context/tx.go +++ b/context/tx.go @@ -1,8 +1,6 @@ package context import ( - "time" - sdk "github.com/cosmos/cosmos-sdk/types" hubtypes "github.com/sentinel-official/hub/types" nodetypes "github.com/sentinel-official/hub/x/node/types" @@ -69,14 +67,14 @@ func (c *Context) UpdateNodeStatus() error { func (c *Context) UpdateSessions(items ...types.Session) error { c.Log().Info("Updating sessions...") - var messages []sdk.Msg + messages := make([]sdk.Msg, 0, len(items)) for _, item := range items { messages = append(messages, sessiontypes.NewMsgUpdateRequest( c.Address(), sessiontypes.Proof{ Id: item.ID, - Duration: time.Since(item.ConnectedAt), + Duration: item.UpdatedAt.Sub(item.CreatedAt), Bandwidth: hubtypes.NewBandwidthFromInt64(item.Download, item.Upload), }, nil, diff --git a/go.mod b/go.mod index ea4d338..f9c36e6 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,8 @@ require ( github.com/tendermint/tendermint v0.34.11 golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e google.golang.org/grpc v1.38.0 + gorm.io/driver/sqlite v1.1.4 + gorm.io/gorm v1.21.11 ) replace ( diff --git a/go.sum b/go.sum index 5960b4d..5bac532 100644 --- a/go.sum +++ b/go.sum @@ -367,6 +367,11 @@ github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANyt github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI= +github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmhodges/levigo v1.0.0 h1:q5EC36kV79HWeTBWsod3mG11EgStG3qArTKcvlksN1U= github.com/jmhodges/levigo v1.0.0/go.mod h1:Q6Qx+uH3RAqyK4rFQroq9RL7mdkABMcfhEI+nNuzMJQ= @@ -413,6 +418,8 @@ github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNx github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-sqlite3 v1.14.5 h1:1IdxlwTNazvbKJQSxoJ5/9ECbEeaTTyeU7sEAZ5KKTQ= +github.com/mattn/go-sqlite3 v1.14.5/go.mod h1:WVKg1VTActs4Qso6iwGbiFih2UIHo0ENGwNd0Lj+XmI= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= @@ -1085,6 +1092,11 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.1.4 h1:PDzwYE+sI6De2+mxAneV9Xs11+ZyKV6oxD3wDGkaNvM= +gorm.io/driver/sqlite v1.1.4/go.mod h1:mJCeTFr7+crvS+TRnWc5Z3UvwxUN1BGBLMrf5LA9DYw= +gorm.io/gorm v1.20.7/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= +gorm.io/gorm v1.21.11 h1:CxkXW6Cc+VIBlL8yJEHq+Co4RYXdSLiMKNvgoZPjLK4= +gorm.io/gorm v1.21.11/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/node/jobs.go b/node/jobs.go index 950efcd..4b7cf52 100644 --- a/node/jobs.go +++ b/node/jobs.go @@ -16,15 +16,21 @@ func (n *Node) jobSetSessions() error { for ; ; <-t.C { peers, err := n.Service().Peers() if err != nil { - n.Log().Error("Failed to get connected peers", "error", err) return err } - n.Log().Info("Connected peers", "count", len(peers)) for i := 0; i < len(peers); i++ { - item := n.Sessions().GetByKey(peers[i].Key) - if item.Empty() { - n.Log().Error("Unknown connected peer", "peer", peers[i]) + var item types.Session + n.Database().Model( + &types.Session{}, + ).Where( + &types.Session{ + Key: peers[i].Key, + }, + ).First(&item) + + if item.ID == 0 { + n.Log().Info("Unknown connected peer", "key", peers[i].Key) if err := n.RemovePeer(peers[i].Key); err != nil { return err } @@ -32,14 +38,26 @@ func (n *Node) jobSetSessions() error { continue } - item.Upload = peers[i].Upload - item.Download = peers[i].Download - n.Sessions().Update(item) - - consumed := sdk.NewInt(item.Upload + item.Download) - if consumed.GT(item.Available) { - n.Log().Info("Peer quota exceeded", "id", item.ID, - "available", item.Available, "consumed", consumed) + n.Database().Model( + &types.Session{}, + ).Where( + &types.Session{ + ID: item.ID, + }, + ).Updates( + &types.Session{ + Upload: peers[i].Upload, + Download: peers[i].Download, + }, + ) + + var ( + available = sdk.NewInt(item.Available) + consumed = sdk.NewInt(peers[i].Upload + peers[i].Download) + ) + + if consumed.GT(available) { + n.Log().Info("Peer quota exceeded", "key", peers[i].Key) if err := n.RemovePeer(item.Key); err != nil { return err } @@ -65,11 +83,9 @@ func (n *Node) jobUpdateSessions() error { t := time.NewTicker(n.IntervalUpdateSessions()) for ; ; <-t.C { var items []types.Session - n.Sessions().Iterate(func(v types.Session) bool { - items = append(items, v) - return false - }) - n.Log().Info("Iterated sessions", "count", len(items)) + n.Database().Model( + &types.Session{}, + ).Find(&items) for i := len(items) - 1; i >= 0; i-- { session, err := n.Client().QuerySession(items[i].ID) @@ -82,39 +98,72 @@ func (n *Node) jobUpdateSessions() error { return err } - remove, skip := func() (bool, bool) { - var ( - nochange = items[i].Download == session.Bandwidth.Upload.Int64() - ) + var ( + removePeer = false + removeSession = false + skipUpdate = false + ) + + if items[i].Download == session.Bandwidth.Upload.Int64() { + skipUpdate = true + if items[i].CreatedAt.Before(session.StatusAt) { + removePeer = true + } - switch { - case nochange && items[i].ConnectedAt.Before(session.StatusAt): - n.Log().Info("Stale peer connection", "id", items[i].ID) - return true, true - case !subscription.Status.Equal(hubtypes.StatusActive): - n.Log().Info("Invalid subscription status", "id", items[i].ID, "nochange", nochange) - return true, nochange || subscription.Status.Equal(hubtypes.StatusInactive) - case !session.Status.Equal(hubtypes.StatusActive): - n.Log().Info("Invalid session status", "id", items[i].ID, "nochange", nochange) - return true, nochange || session.Status.Equal(hubtypes.StatusInactive) - default: - return false, false + n.Log().Info("Stale peer connection", "id", items[i].ID) + } + if !subscription.Status.Equal(hubtypes.StatusActive) { + removePeer = true + if subscription.Status.Equal(hubtypes.StatusInactive) { + removeSession, skipUpdate = true, true + } + + n.Log().Info("Invalid subscription status", "id", items[i].ID) + } + if !session.Status.Equal(hubtypes.StatusActive) { + removePeer = true + if session.Status.Equal(hubtypes.StatusInactive) { + removeSession, skipUpdate = true, true } - }() - if remove { - if err := n.RemovePeerAndSession(items[i].Key, items[i].Address); err != nil { + n.Log().Info("Invalid session status", "id", items[i].ID) + } + + if removePeer { + if err := n.RemovePeer(items[i].Key); err != nil { return err } } - if skip { + + if removeSession { + n.Database().Model( + &types.Session{}, + ).Where( + &types.Session{ + ID: items[i].ID, + }, + ).Update( + "address", "", + ) + } + + if skipUpdate { items = append(items[:i], items[i+1:]...) } } + n.Database().Model( + &types.Session{}, + ).Where( + "address = ?", "", + ).Delete( + &types.Session{}, + ) + if len(items) == 0 { continue } + if err := n.UpdateSessions(items...); err != nil { return err } diff --git a/rest/session/handlers.go b/rest/session/handlers.go index e356605..8af2c0a 100644 --- a/rest/session/handlers.go +++ b/rest/session/handlers.go @@ -6,7 +6,6 @@ import ( "net" "net/http" "strconv" - "time" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/gorilla/mux" @@ -53,130 +52,208 @@ func handlerAddSession(ctx *context.Context) http.HandlerFunc { return } if account == nil { - utils.WriteErrorToResponse(w, http.StatusNotFound, 2, "account does not exist") + err := fmt.Errorf("account %s does not exist", address) + utils.WriteErrorToResponse(w, http.StatusNotFound, 2, err.Error()) return } if account.GetPubKey() == nil { - utils.WriteErrorToResponse(w, http.StatusNotFound, 2, "public key does not exist") + err := fmt.Errorf("public key for %s does not exist", address) + utils.WriteErrorToResponse(w, http.StatusNotFound, 2, err.Error()) return } if ok := account.GetPubKey().VerifySignature(sdk.Uint64ToBigEndian(id), signature); !ok { - utils.WriteErrorToResponse(w, http.StatusBadRequest, 2, "failed to verify the signature") + err := fmt.Errorf("failed to verify the signature %s", signature) + utils.WriteErrorToResponse(w, http.StatusBadRequest, 2, err.Error()) return } - item := ctx.Sessions().GetByAddress(address) - if item.Empty() { - item = ctx.Sessions().GetByKey(body.Key) - } - - if !item.Empty() { - session, err := ctx.Client().QuerySession(item.ID) - if err != nil { - utils.WriteErrorToResponse(w, http.StatusInternalServerError, 3, err.Error()) - return - } - if session == nil { - utils.WriteErrorToResponse(w, http.StatusNotFound, 3, "session does not exist") - return - } - if session.Status.Equal(hubtypes.StatusActive) { - utils.WriteErrorToResponse(w, http.StatusBadRequest, 3, fmt.Sprintf("invalid session status %s", session.Status)) - return - } - - if err := ctx.RemovePeerAndSession(item.Key, item.Address); err != nil { - utils.WriteErrorToResponse(w, http.StatusInternalServerError, 3, err.Error()) - return - } - - if session.Status.Equal(hubtypes.StatusInactivePending) { - go func() { - _ = ctx.UpdateSessions(item) - }() - } - } - session, err := ctx.Client().QuerySession(id) if err != nil { - utils.WriteErrorToResponse(w, http.StatusInternalServerError, 4, err.Error()) + utils.WriteErrorToResponse(w, http.StatusInternalServerError, 3, err.Error()) return } if session == nil { - utils.WriteErrorToResponse(w, http.StatusNotFound, 4, "session does not exist") + err := fmt.Errorf("session %d does not exist", id) + utils.WriteErrorToResponse(w, http.StatusNotFound, 3, err.Error()) return } if !session.Status.Equal(hubtypes.StatusActive) { - utils.WriteErrorToResponse(w, http.StatusBadRequest, 4, fmt.Sprintf("invalid session status %s", session.Status)) + err := fmt.Errorf("invalid status %s for session %d", session.Status, session.Id) + utils.WriteErrorToResponse(w, http.StatusBadRequest, 3, err.Error()) return } if session.Address != address.String() { - utils.WriteErrorToResponse(w, http.StatusBadRequest, 4, "account address mismatch") + err := fmt.Errorf("account address mismatch; expected %s, got %s", address, session.Address) + utils.WriteErrorToResponse(w, http.StatusBadRequest, 3, err.Error()) return } subscription, err := ctx.Client().QuerySubscription(session.Subscription) if err != nil { - utils.WriteErrorToResponse(w, http.StatusInternalServerError, 5, err.Error()) + utils.WriteErrorToResponse(w, http.StatusInternalServerError, 4, err.Error()) return } if subscription == nil { - utils.WriteErrorToResponse(w, http.StatusNotFound, 5, "subscription does not exist") + err := fmt.Errorf("subscription %d does not exist", session.Subscription) + utils.WriteErrorToResponse(w, http.StatusNotFound, 4, err.Error()) return } if !subscription.Status.Equal(hubtypes.Active) { - utils.WriteErrorToResponse(w, http.StatusBadRequest, 5, fmt.Sprintf("invalid subscription status %s", subscription.Status)) + err := fmt.Errorf("invalid status %s for subscription %d", subscription.Status, subscription.Id) + utils.WriteErrorToResponse(w, http.StatusBadRequest, 4, err.Error()) return } if subscription.Plan == 0 { if subscription.Node != ctx.Address().String() { - utils.WriteErrorToResponse(w, http.StatusBadRequest, 6, "node address mismatch") + err := fmt.Errorf("node address mismatch; expected %s, got %s", ctx.Address(), subscription.Node) + utils.WriteErrorToResponse(w, http.StatusBadRequest, 5, err.Error()) return } } else { - ok, err := ctx.Client().HasNodeForPlan(id, ctx.Address()) + ok, err := ctx.Client().HasNodeForPlan(subscription.Plan, ctx.Address()) if err != nil { - utils.WriteErrorToResponse(w, http.StatusInternalServerError, 6, err.Error()) + utils.WriteErrorToResponse(w, http.StatusInternalServerError, 5, err.Error()) return } if !ok { - utils.WriteErrorToResponse(w, http.StatusBadRequest, 6, "node does not exist for plan") + err := fmt.Errorf("node %s does not exist for plan %d", ctx.Address(), id) + utils.WriteErrorToResponse(w, http.StatusBadRequest, 5, err.Error()) return } } + var item types.Session + ctx.Database().Model( + &types.Session{}, + ).Where( + &types.Session{ + ID: id, + }, + ).First(&item) + + if item.ID != 0 { + err := fmt.Errorf("peer for session %d already exist", id) + utils.WriteErrorToResponse(w, http.StatusBadRequest, 6, err.Error()) + return + } + + item = types.Session{} + ctx.Database().Model( + &types.Session{}, + ).Where( + &types.Session{ + Key: body.Key, + }, + ).First(&item) + + if item.ID != 0 { + err := fmt.Errorf("key %s for service already exist", body.Key) + utils.WriteErrorToResponse(w, http.StatusBadRequest, 6, err.Error()) + return + } + + var items []types.Session + ctx.Database().Model( + &types.Session{}, + ).Where( + &types.Session{ + Subscription: subscription.Id, + Address: address.String(), + }, + ).Find(&items) + + for i := 0; i < len(items); i++ { + session, err := ctx.Client().QuerySession(items[i].ID) + if err != nil { + utils.WriteErrorToResponse(w, http.StatusInternalServerError, 7, err.Error()) + return + } + if session == nil { + err := fmt.Errorf("session %d does not exist", items[i].ID) + utils.WriteErrorToResponse(w, http.StatusNotFound, 7, err.Error()) + return + } + if session.Status.Equal(hubtypes.StatusActive) { + err := fmt.Errorf("invalid status %s for session %d", session.Status, session.Id) + utils.WriteErrorToResponse(w, http.StatusBadRequest, 7, err.Error()) + return + } + + if err := ctx.RemovePeer(items[i].Key); err != nil { + utils.WriteErrorToResponse(w, http.StatusInternalServerError, 8, err.Error()) + return + } + + if session.Status.Equal(hubtypes.StatusInactive) { + ctx.Database().Model( + &types.Session{}, + ).Where( + &types.Session{ + ID: items[i].ID, + }, + ).Update( + "address", "", + ) + } + } + quota, err := ctx.Client().QueryQuota(subscription.Id, address) if err != nil { - utils.WriteErrorToResponse(w, http.StatusInternalServerError, 7, err.Error()) + utils.WriteErrorToResponse(w, http.StatusInternalServerError, 9, err.Error()) return } if quota == nil { - utils.WriteErrorToResponse(w, http.StatusNotFound, 7, "quota does not exist") + err := fmt.Errorf("quota for address %s does not exist", address) + utils.WriteErrorToResponse(w, http.StatusNotFound, 9, err.Error()) return } + + items = []types.Session{} + ctx.Database().Model( + &types.Session{}, + ).Where( + &types.Session{ + Subscription: subscription.Id, + Address: address.String(), + }, + ).Find(&items) + + for i := 0; i < len(items); i++ { + consumed := items[i].Download + items[i].Upload + quota.Consumed = quota.Consumed.Add( + hubtypes.NewBandwidthFromInt64( + consumed, 0, + ).CeilTo( + hubtypes.Gigabyte.Quo(subscription.Price.Amount), + ).Sum(), + ) + } + if quota.Consumed.GTE(quota.Allocated) { - utils.WriteErrorToResponse(w, http.StatusBadRequest, 7, "quota exceeded") + err := fmt.Errorf("quota exceeded; allocated %s, consumed %s", quota.Allocated, quota.Consumed) + utils.WriteErrorToResponse(w, http.StatusBadRequest, 10, err.Error()) return } result, err := ctx.Service().AddPeer(key) if err != nil { - utils.WriteErrorToResponse(w, http.StatusInternalServerError, 8, err.Error()) + utils.WriteErrorToResponse(w, http.StatusInternalServerError, 11, err.Error()) return } ctx.Log().Info("Added a new peer", "key", body.Key, "count", ctx.Service().PeersCount()) - ctx.Sessions().Set( - types.Session{ - ID: id, - Key: body.Key, - Address: address, - Available: quota.Allocated.Sub(quota.Consumed), - ConnectedAt: time.Now(), + ctx.Database().Model( + &types.Session{}, + ).Create( + &types.Session{ + ID: id, + Subscription: subscription.Id, + Key: body.Key, + Address: address.String(), + Available: quota.Allocated.Sub(quota.Consumed).Int64(), }, ) - ctx.Log().Info("Added a new session", "id", id, "address", address, "count", ctx.Sessions().Len()) result = append(result, net.ParseIP(ctx.Location().IP).To4()...) result = append(result, ctx.Service().Info()...) diff --git a/types/keys.go b/types/keys.go index 3b6140e..5c6856e 100644 --- a/types/keys.go +++ b/types/keys.go @@ -6,11 +6,12 @@ import ( ) const ( - ConfigFileName = "config.toml" - FlagForce = "force" - KeyringName = "sentinel" - DefaultIPv4CIDR = "10.8.0.2/24" - DefaultIPv6CIDR = "fd86:ea04:1115::2/120" + ConfigFileName = "config.toml" + DatabaseFileName = "data.db" + FlagForce = "force" + KeyringName = "sentinel" + DefaultIPv4CIDR = "10.8.0.2/24" + DefaultIPv6CIDR = "fd86:ea04:1115::2/120" ) var ( diff --git a/types/session.go b/types/session.go index 5389486..7f9f0f8 100644 --- a/types/session.go +++ b/types/session.go @@ -1,172 +1,30 @@ package types import ( - "fmt" - "reflect" - "sync" - "time" - sdk "github.com/cosmos/cosmos-sdk/types" + "gorm.io/gorm" ) -func withTypePrefix(v interface{}) string { - t := reflect.TypeOf(v).String() - switch v := v.(type) { - case string: - return t + v - case fmt.Stringer: - return t + v.String() - default: - return "" - } -} - type Session struct { - ID uint64 `json:"id,omitempty"` - Key string `json:"key,omitempty"` - Address sdk.AccAddress `json:"address,omitempty"` - Available sdk.Int `json:"available,omitempty"` - Download int64 `json:"download,omitempty"` - Upload int64 `json:"upload,omitempty"` - ConnectedAt time.Time `json:"connected_at,omitempty"` -} - -func (s Session) Empty() bool { - return s.ID == 0 -} - -type Sessions struct { - sync.RWMutex - m map[string]interface{} + *gorm.Model + ID uint64 `gorm:"primaryKey"` + Subscription uint64 + Key string `gorm:"unique"` + Address string + Available int64 + Download int64 + Upload int64 } -func NewSessions() *Sessions { - return &Sessions{ - m: make(map[string]interface{}), +func (s *Session) GetAddress() sdk.AccAddress { + if s.Address == "" { + return nil } -} - -func (s *Sessions) unsafeIsNil(v Session) bool { - return s.unsafeGetForKey(v.Key).Empty() && - s.unsafeGetForAddress(v.Address).Empty() -} - -func (s *Sessions) unsafeSet(v Session) { - s.m[withTypePrefix(v.Key)] = v - s.m[withTypePrefix(v.Address)] = v.Key -} - -func (s *Sessions) unsafeDelete(v Session) { - delete(s.m, withTypePrefix(v.Key)) - delete(s.m, withTypePrefix(v.Address)) -} - -func (s *Sessions) unsafeGetForKey(k string) (x Session) { - v, ok := s.m[withTypePrefix(k)] - if !ok { - return x - } - - return v.(Session) -} - -func (s *Sessions) unsafeGetForAddress(k sdk.AccAddress) (x Session) { - v, ok := s.m[withTypePrefix(k)] - if !ok { - return x - } - - v, ok = s.m[withTypePrefix(v.(string))] - if !ok { - return x - } - - return v.(Session) -} - -func (s *Sessions) Set(v Session) { - s.Lock() - defer s.Unlock() - - s.unsafeSet(v) -} - -func (s *Sessions) Update(v Session) { - s.Lock() - defer s.Unlock() - if s.unsafeIsNil(v) { - return + address, err := sdk.AccAddressFromBech32(s.Address) + if err != nil { + panic(err) } - s.unsafeSet(v) -} - -func (s *Sessions) GetByKey(k string) Session { - s.RLock() - defer s.RUnlock() - - return s.unsafeGetForKey(k) -} - -func (s *Sessions) GetByAddress(k sdk.AccAddress) Session { - s.RLock() - defer s.RUnlock() - - return s.unsafeGetForAddress(k) -} - -func (s *Sessions) DeleteByKey(k string) { - s.Lock() - defer s.Unlock() - - if k == "" { - return - } - - v := s.unsafeGetForKey(k) - if v.Empty() { - return - } - - s.unsafeDelete(v) -} - -func (s *Sessions) DeleteByAddress(k sdk.AccAddress) { - s.Lock() - defer s.Unlock() - - if k == nil || k.Empty() { - return - } - - v := s.unsafeGetForAddress(k) - if v.Empty() { - return - } - - s.unsafeDelete(v) -} - -func (s *Sessions) Len() int { - s.RLock() - defer s.RUnlock() - - return len(s.m) / 2 -} - -func (s *Sessions) Iterate(fn func(v Session) bool) { - s.RLock() - defer s.RUnlock() - - for _, v := range s.m { - v, ok := v.(Session) - if !ok { - continue - } - - if fn(v) { - break - } - } + return address }