Skip to content

Commit

Permalink
refactor: move check domains implementation to pkg
Browse files Browse the repository at this point in the history
Signed-off-by: Felix Gateru <[email protected]>
  • Loading branch information
felixgateru committed Dec 18, 2024
1 parent d60a1bd commit 8b4a49d
Show file tree
Hide file tree
Showing 29 changed files with 428 additions and 562 deletions.
53 changes: 6 additions & 47 deletions channels/middleware/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,16 @@ var (
var _ channels.Service = (*authorizationMiddleware)(nil)

type authorizationMiddleware struct {
svc channels.Service
repo channels.Repository
authz smqauthz.Authorization
opp svcutil.OperationPerm
extOpp svcutil.ExternalOperationPerm
domains authz.DomainCheck
svc channels.Service
repo channels.Repository
authz smqauthz.Authorization
opp svcutil.OperationPerm
extOpp svcutil.ExternalOperationPerm
rmMW.RoleManagerAuthorizationMiddleware
}

// AuthorizationMiddleware adds authorization to the channels service.
func AuthorizationMiddleware(svc channels.Service, repo channels.Repository, authz smqauthz.Authorization, domains authz.DomainCheck, channelsOpPerm, rolesOpPerm map[svcutil.Operation]svcutil.Permission, extOpPerm map[svcutil.ExternalOperation]svcutil.Permission) (channels.Service, error) {
func AuthorizationMiddleware(svc channels.Service, repo channels.Repository, authz smqauthz.Authorization, channelsOpPerm, rolesOpPerm map[svcutil.Operation]svcutil.Permission, extOpPerm map[svcutil.ExternalOperation]svcutil.Permission) (channels.Service, error) {
opp := channels.NewOperationPerm()
if err := opp.AddOperationPermissionMap(channelsOpPerm); err != nil {
return nil, err
Expand Down Expand Up @@ -78,14 +77,10 @@ func AuthorizationMiddleware(svc channels.Service, repo channels.Repository, aut
RoleManagerAuthorizationMiddleware: ram,
opp: opp,
extOpp: extOpp,
domains: domains,
}, nil
}

func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return []channels.Channel{}, err
}
if err := am.extAuthorize(ctx, channels.DomainOpCreateChannel, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand Down Expand Up @@ -113,9 +108,6 @@ func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session a
}

func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return channels.Channel{}, err
}
if err := am.authorize(ctx, channels.OpViewChannel, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -129,26 +121,17 @@ func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session auth
}

func (am *authorizationMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.PageMetadata) (channels.Page, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return channels.Page{}, err
}
if err := am.checkSuperAdmin(ctx, session.UserID); err != nil {
session.SuperAdmin = true
}
return am.svc.ListChannels(ctx, session, pm)
}

func (am *authorizationMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return channels.Page{}, err
}
return am.svc.ListChannelsByClient(ctx, session, clientID, pm)
}

func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return channels.Channel{}, err
}
if err := am.authorize(ctx, channels.OpUpdateChannel, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -162,9 +145,6 @@ func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session au
}

func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return channels.Channel{}, err
}
if err := am.authorize(ctx, channels.OpUpdateChannelTags, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -178,9 +158,6 @@ func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, sessio
}

func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return channels.Channel{}, err
}
if err := am.authorize(ctx, channels.OpEnableChannel, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -194,9 +171,6 @@ func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session au
}

func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return channels.Channel{}, err
}
if err := am.authorize(ctx, channels.OpDisableChannel, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -210,9 +184,6 @@ func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session a
}

func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return err
}
if err := am.authorize(ctx, channels.OpDeleteChannel, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -226,9 +197,6 @@ func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session au
}

func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return err
}
for _, chID := range chIDs {
if err := am.authorize(ctx, channels.OpConnectClient, authz.PolicyReq{
Domain: session.DomainID,
Expand Down Expand Up @@ -256,9 +224,6 @@ func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Se
}

func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return err
}
for _, chID := range chIDs {
if err := am.authorize(ctx, channels.OpDisconnectClient, authz.PolicyReq{
Domain: session.DomainID,
Expand Down Expand Up @@ -286,9 +251,6 @@ func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn
}

func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return err
}
if err := am.authorize(ctx, channels.OpSetParentGroup, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -312,9 +274,6 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
}

func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return err
}
if err := am.authorize(ctx, channels.OpSetParentGroup, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand Down
47 changes: 6 additions & 41 deletions clients/middleware/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,16 @@ var (
var _ clients.Service = (*authorizationMiddleware)(nil)

type authorizationMiddleware struct {
svc clients.Service
repo clients.Repository
authz authz.Authorization
opp svcutil.OperationPerm
extOpp svcutil.ExternalOperationPerm
domains authz.DomainCheck
svc clients.Service
repo clients.Repository
authz authz.Authorization
opp svcutil.OperationPerm
extOpp svcutil.ExternalOperationPerm
rmMW.RoleManagerAuthorizationMiddleware
}

// AuthorizationMiddleware adds authorization to the clients service.
func AuthorizationMiddleware(entityType string, svc clients.Service, authz authz.Authorization, repo clients.Repository, domains authz.DomainCheck, clientsOpPerm, rolesOpPerm map[svcutil.Operation]svcutil.Permission, extOpPerm map[svcutil.ExternalOperation]svcutil.Permission) (clients.Service, error) {
func AuthorizationMiddleware(entityType string, svc clients.Service, authz authz.Authorization, repo clients.Repository, clientsOpPerm, rolesOpPerm map[svcutil.Operation]svcutil.Permission, extOpPerm map[svcutil.ExternalOperation]svcutil.Permission) (clients.Service, error) {
opp := clients.NewOperationPerm()
if err := opp.AddOperationPermissionMap(clientsOpPerm); err != nil {
return nil, err
Expand All @@ -69,15 +68,11 @@ func AuthorizationMiddleware(entityType string, svc clients.Service, authz authz
repo: repo,
opp: opp,
extOpp: extOpp,
domains: domains,
RoleManagerAuthorizationMiddleware: ram,
}, nil
}

func (am *authorizationMiddleware) CreateClients(ctx context.Context, session authn.Session, client ...clients.Client) ([]clients.Client, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return []clients.Client{}, err
}
if err := am.extAuthorize(ctx, clients.DomainOpCreateClient, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -92,9 +87,6 @@ func (am *authorizationMiddleware) CreateClients(ctx context.Context, session au
}

func (am *authorizationMiddleware) View(ctx context.Context, session authn.Session, id string) (clients.Client, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return clients.Client{}, err
}
if err := am.authorize(ctx, clients.OpViewClient, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -111,17 +103,11 @@ func (am *authorizationMiddleware) ListClients(ctx context.Context, session auth
if err := am.checkSuperAdmin(ctx, session.UserID); err != nil {
session.SuperAdmin = true
}
if err := am.domains.CheckDomain(ctx, session); err != nil {
return clients.ClientsPage{}, err
}

return am.svc.ListClients(ctx, session, reqUserID, pm)
}

func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return clients.Client{}, err
}
if err := am.authorize(ctx, clients.OpUpdateClient, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -136,9 +122,6 @@ func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Ses
}

func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return clients.Client{}, err
}
if err := am.authorize(ctx, clients.OpUpdateClientTags, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -153,9 +136,6 @@ func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn
}

func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session authn.Session, id, key string) (clients.Client, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return clients.Client{}, err
}
if err := am.authorize(ctx, clients.OpUpdateClientSecret, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -169,9 +149,6 @@ func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session aut
}

func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Session, id string) (clients.Client, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return clients.Client{}, err
}
if err := am.authorize(ctx, clients.OpEnableClient, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -186,9 +163,6 @@ func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Ses
}

func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Session, id string) (clients.Client, error) {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return clients.Client{}, err
}
if err := am.authorize(ctx, clients.OpDisableClient, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -202,9 +176,6 @@ func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Se
}

func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Session, id string) error {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return err
}
if err := am.authorize(ctx, clients.OpDeleteClient, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -219,9 +190,6 @@ func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Ses
}

func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return err
}
if err := am.authorize(ctx, clients.OpSetParentGroup, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand All @@ -245,9 +213,6 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
}

func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
if err := am.domains.CheckDomain(ctx, session); err != nil {
return err
}
if err := am.authorize(ctx, clients.OpRemoveParentGroup, authz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Expand Down
30 changes: 23 additions & 7 deletions cmd/bootstrap/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
"github.com/absmach/supermq/pkg/events"
"github.com/absmach/supermq/pkg/events/store"
"github.com/absmach/supermq/pkg/grpcclient"
Expand All @@ -48,12 +49,13 @@ import (
)

const (
svcName = "bootstrap"
envPrefixDB = "SMQ_BOOTSTRAP_DB_"
envPrefixHTTP = "SMQ_BOOTSTRAP_HTTP_"
envPrefixAuth = "SMQ_AUTH_GRPC_"
defDB = "bootstrap"
defSvcHTTPPort = "9013"
svcName = "bootstrap"
envPrefixDB = "SMQ_BOOTSTRAP_DB_"
envPrefixHTTP = "SMQ_BOOTSTRAP_HTTP_"
envPrefixAuth = "SMQ_AUTH_GRPC_"
envPrefixDomains = "SMQ_DOMAINS_GRPC_"
defDB = "bootstrap"
defSvcHTTPPort = "9013"

stream = "events.supermq.clients"
streamID = "supermq.bootstrap"
Expand Down Expand Up @@ -148,7 +150,21 @@ func main() {
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
defer authnClient.Close()

authz, authzClient, err := authsvcAuthz.NewAuthorization(ctx, grpcCfg)
domsGrpcCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err))
exitCode = 1
return
}
domainsAuthz, _, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer domainsHandler.Close()

authz, authzClient, err := authsvcAuthz.NewAuthorization(ctx, grpcCfg, domainsAuthz)
if err != nil {
logger.Error(err.Error())
exitCode = 1
Expand Down
Loading

0 comments on commit 8b4a49d

Please sign in to comment.