Skip to content

Commit

Permalink
refactor: several more consts usages (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
james-d-elliott authored Dec 21, 2023
1 parent bf342a7 commit fe50089
Show file tree
Hide file tree
Showing 81 changed files with 1,113 additions and 1,078 deletions.
16 changes: 9 additions & 7 deletions access_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,26 @@ import (
"encoding/json"
"fmt"
"net/http"

"authelia.com/provider/oauth2/internal/consts"
)

func (f *Fosite) WriteAccessError(ctx context.Context, rw http.ResponseWriter, req AccessRequester, err error) {
f.writeJsonError(ctx, rw, req, err)
}

func (f *Fosite) writeJsonError(ctx context.Context, rw http.ResponseWriter, requester AccessRequester, err error) {
rw.Header().Set("Content-Type", "application/json;charset=UTF-8")
rw.Header().Set("Cache-Control", "no-store")
rw.Header().Set("Pragma", "no-cache")
rw.Header().Set(consts.HeaderContentType, consts.ContentTypeApplicationJSON)
rw.Header().Set(consts.HeaderCacheControl, consts.CacheControlNoStore)
rw.Header().Set(consts.HeaderPragma, consts.PragmaNoCache)

rfcerr := ErrorToRFC6749Error(err).WithLegacyFormat(f.Config.GetUseLegacyErrorFormat(ctx)).WithExposeDebug(f.Config.GetSendDebugMessagesToClients(ctx))
rfc := ErrorToRFC6749Error(err).WithLegacyFormat(f.Config.GetUseLegacyErrorFormat(ctx)).WithExposeDebug(f.Config.GetSendDebugMessagesToClients(ctx))

if requester != nil {
rfcerr = rfcerr.WithLocalizer(f.Config.GetMessageCatalog(ctx), getLangFromRequester(requester))
rfc = rfc.WithLocalizer(f.Config.GetMessageCatalog(ctx), getLangFromRequester(requester))
}

js, err := json.Marshal(rfcerr)
js, err := json.Marshal(rfc)
if err != nil {
if f.Config.GetSendDebugMessagesToClients(ctx) {
errorMessage := EscapeJSONString(err.Error())
Expand All @@ -36,7 +38,7 @@ func (f *Fosite) writeJsonError(ctx context.Context, rw http.ResponseWriter, req
return
}

rw.WriteHeader(rfcerr.CodeField)
rw.WriteHeader(rfc.CodeField)
// ignoring the error because the connection is broken when it happens
_, _ = rw.Write(js)
}
2 changes: 1 addition & 1 deletion access_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ import (
"net/http"
"strings"

"authelia.com/provider/oauth2/internal/consts"
"github.com/pkg/errors"

"authelia.com/provider/oauth2/i18n"
"authelia.com/provider/oauth2/internal/consts"
"authelia.com/provider/oauth2/internal/errorsx"
)

Expand Down
65 changes: 33 additions & 32 deletions access_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

. "authelia.com/provider/oauth2"
"authelia.com/provider/oauth2/internal"
"authelia.com/provider/oauth2/internal/consts"
)

func TestNewAccessRequest(t *testing.T) {
Expand Down Expand Up @@ -54,7 +55,7 @@ func TestNewAccessRequest(t *testing.T) {
header: http.Header{},
method: "POST",
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
mock: func() {},
expectErr: ErrInvalidRequest,
Expand All @@ -63,19 +64,19 @@ func TestNewAccessRequest(t *testing.T) {
header: http.Header{},
method: "POST",
form: url.Values{
"grant_type": {"foo"},
"client_id": {""},
consts.FormParameterGrantType: {"foo"},
consts.FormParameterClientID: {""},
},
expectErr: ErrInvalidRequest,
mock: func() {},
},
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
method: "POST",
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
expectErr: ErrInvalidClient,
mock: func() {
Expand All @@ -85,22 +86,22 @@ func TestNewAccessRequest(t *testing.T) {
},
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
method: "GET",
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
expectErr: ErrInvalidRequest,
mock: func() {},
},
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
method: "POST",
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
expectErr: ErrInvalidClient,
mock: func() {
Expand All @@ -110,11 +111,11 @@ func TestNewAccessRequest(t *testing.T) {
},
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
method: "POST",
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
expectErr: ErrInvalidClient,
mock: func() {
Expand All @@ -127,11 +128,11 @@ func TestNewAccessRequest(t *testing.T) {
},
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
method: "POST",
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
expectErr: ErrServerError,
mock: func() {
Expand All @@ -145,11 +146,11 @@ func TestNewAccessRequest(t *testing.T) {
},
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
method: "POST",
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
Expand All @@ -168,11 +169,11 @@ func TestNewAccessRequest(t *testing.T) {
},
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
method: "POST",
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
Expand All @@ -195,10 +196,11 @@ func TestNewAccessRequest(t *testing.T) {
Form: c.form,
Method: c.method,
}

c.mock()
ctx := NewContext()
config.TokenEndpointHandlers = c.handlers
ar, err := provider.NewAccessRequest(ctx, r, new(DefaultSession))

ar, err := provider.NewAccessRequest(context.TODO(), r, new(DefaultSession))

if c.expectErr != nil {
assert.EqualError(t, err, c.expectErr.Error())
Expand Down Expand Up @@ -245,7 +247,7 @@ func TestNewAccessRequestWithoutClientAuth(t *testing.T) {
// No registered handlers -> error
{
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0)
Expand All @@ -257,10 +259,10 @@ func TestNewAccessRequestWithoutClientAuth(t *testing.T) {
// Handler can skip client auth and ignores missing client.
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
// despite error from storage, we should success, because client auth is not required
Expand All @@ -279,7 +281,7 @@ func TestNewAccessRequestWithoutClientAuth(t *testing.T) {
// Should pass if no auth is set in the header and can skip!
{
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
Expand All @@ -296,10 +298,10 @@ func TestNewAccessRequestWithoutClientAuth(t *testing.T) {
// Should also pass if client auth is set!
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
store.EXPECT().GetClient(gomock.Any(), "foo").Return(anotherClient, nil).Times(1)
Expand Down Expand Up @@ -371,10 +373,10 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
}{
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
Expand All @@ -389,10 +391,10 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
},
{
header: http.Header{
"Authorization": {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
Expand All @@ -414,7 +416,7 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
{
header: http.Header{},
form: url.Values{
"grant_type": {"foo"},
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0)
Expand All @@ -433,9 +435,8 @@ func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
Method: c.method,
}
c.mock()
ctx := NewContext()
config.TokenEndpointHandlers = c.handlers
ar, err := provider.NewAccessRequest(ctx, r, new(DefaultSession))
ar, err := provider.NewAccessRequest(context.TODO(), r, new(DefaultSession))

if c.expectErr != nil {
assert.EqualError(t, err, c.expectErr.Error())
Expand Down
8 changes: 5 additions & 3 deletions access_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@ import (
"context"
"encoding/json"
"net/http"

"authelia.com/provider/oauth2/internal/consts"
)

func (f *Fosite) WriteAccessResponse(ctx context.Context, rw http.ResponseWriter, requester AccessRequester, responder AccessResponder) {
rw.Header().Set("Cache-Control", "no-store")
rw.Header().Set("Pragma", "no-cache")
rw.Header().Set(consts.HeaderCacheControl, consts.CacheControlNoStore)
rw.Header().Set(consts.HeaderPragma, consts.PragmaNoCache)

js, err := json.Marshal(responder.ToMap())
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}

rw.Header().Set("Content-Type", "application/json;charset=UTF-8")
rw.Header().Set(consts.HeaderContentType, consts.ContentTypeApplicationJSON)

rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(js)
Expand Down
7 changes: 4 additions & 3 deletions access_write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

. "authelia.com/provider/oauth2"
. "authelia.com/provider/oauth2/internal"
"authelia.com/provider/oauth2/internal/consts"
)

func TestWriteAccessResponse(t *testing.T) {
Expand All @@ -30,7 +31,7 @@ func TestWriteAccessResponse(t *testing.T) {
resp.EXPECT().ToMap().Return(map[string]any{})

provider.WriteAccessResponse(context.Background(), rw, ar, resp)
assert.Equal(t, "application/json;charset=UTF-8", header.Get("Content-Type"))
assert.Equal(t, "no-store", header.Get("Cache-Control"))
assert.Equal(t, "no-cache", header.Get("Pragma"))
assert.Equal(t, consts.ContentTypeApplicationJSON, header.Get(consts.HeaderContentType))
assert.Equal(t, consts.CacheControlNoStore, header.Get(consts.HeaderCacheControl))
assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma))
}
Loading

0 comments on commit fe50089

Please sign in to comment.