Skip to content

Commit

Permalink
Merge pull request #767 from sushanth0910/add-logs-metrics-dims
Browse files Browse the repository at this point in the history
add logs and metrics to find sts call success/failures on global/regional endpoints
  • Loading branch information
k8s-ci-robot authored Oct 21, 2024
2 parents d018f8c + f1349dd commit 29c47ac
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 16 deletions.
14 changes: 7 additions & 7 deletions pkg/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ type Metrics struct {
ConfigMapWatchFailures prometheus.Counter
Latency *prometheus.HistogramVec
EC2DescribeInstanceCallCount prometheus.Counter
StsConnectionFailure prometheus.Counter
StsConnectionFailure *prometheus.CounterVec
StsResponses *prometheus.CounterVec
DynamicFileFailures prometheus.Counter
StsThrottling prometheus.Counter
StsThrottling *prometheus.CounterVec
E2ELatency *prometheus.HistogramVec
DynamicFileEnabled prometheus.Gauge
DynamicFileOnly prometheus.Gauge
Expand All @@ -65,26 +65,26 @@ func createMetrics(reg prometheus.Registerer) Metrics {
Help: "Dynamic file failures",
},
),
StsConnectionFailure: factory.NewCounter(
StsConnectionFailure: factory.NewCounterVec(
prometheus.CounterOpts{
Namespace: Namespace,
Name: "sts_connection_failures_total",
Help: "Sts call could not succeed or timedout",
},
}, []string{"StsRegion"},
),
StsThrottling: factory.NewCounter(
StsThrottling: factory.NewCounterVec(
prometheus.CounterOpts{
Namespace: Namespace,
Name: "sts_throttling_total",
Help: "Sts call got throttled",
},
}, []string{"StsRegion"},
),
StsResponses: factory.NewCounterVec(
prometheus.CounterOpts{
Namespace: Namespace,
Name: "sts_responses_total",
Help: "Sts responses with error code label",
}, []string{"ResponseCode"},
}, []string{"ResponseCode", "StsRegion"},
),
Latency: factory.NewHistogramVec(
prometheus.HistogramOpts{
Expand Down
8 changes: 5 additions & 3 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request)
"accountid": identity.AccountID,
"userid": identity.UserID,
"session": identity.SessionName,
"stsendpoint": identity.STSEndpoint,
}).Info("STS response")

// look up the ARN in each of our mappings to fill in the username and groups
Expand All @@ -380,9 +381,10 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request)

// the token is valid and the role is mapped, return success!
log.WithFields(logrus.Fields{
"username": username,
"uid": uid,
"groups": groups,
"username": username,
"uid": uid,
"groups": groups,
"stsendpoint": identity.STSEndpoint,
}).Info("access granted")
metrics.Get().Latency.WithLabelValues(metrics.Success).Observe(duration(start))
w.WriteHeader(http.StatusOK)
Expand Down
37 changes: 31 additions & 6 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ type Identity struct {
// in conjunction with CloudTrail to determine the identity of the individual
// if the individual assumed an IAM role before making the request.
AccessKeyID string

// ASW STS endpoint used to authenticate (expected values is sts endpoint eg: sts.us-west-2.amazonaws.com)
STSEndpoint string
}

const (
Expand Down Expand Up @@ -503,6 +506,11 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
return nil, err
}

stsRegion, err := getStsRegion(parsedURL.Host)
if err != nil {
return nil, err
}

if parsedURL.Path != "/" {
return nil, FormatError{"unexpected path in pre-signed URL"}
}
Expand Down Expand Up @@ -567,12 +575,12 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {

response, err := v.client.Do(req)
if err != nil {
metrics.Get().StsConnectionFailure.Inc()
metrics.Get().StsConnectionFailure.WithLabelValues(stsRegion).Inc()
// special case to avoid printing the full URL if possible
if urlErr, ok := err.(*url.Error); ok {
return nil, NewSTSError(fmt.Sprintf("error during GET: %v", urlErr.Err))
return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", urlErr.Err, stsRegion))
}
return nil, NewSTSError(fmt.Sprintf("error during GET: %v", err))
return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", err, stsRegion))
}
defer response.Body.Close()

Expand All @@ -581,16 +589,16 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
return nil, NewSTSError(fmt.Sprintf("error reading HTTP result: %v", err))
}

metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode)).Inc()
metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode), stsRegion).Inc()
if response.StatusCode != 200 {
responseStr := string(responseBody[:])
// refer to https://docs.aws.amazon.com/STS/latest/APIReference/CommonErrors.html and log
// response body for STS Throttling is {"Error":{"Code":"Throttling","Message":"Rate exceeded","Type":"Sender"},"RequestId":"xxx"}
if strings.Contains(responseStr, "Throttling") {
metrics.Get().StsThrottling.Inc()
metrics.Get().StsThrottling.WithLabelValues(stsRegion).Inc()
return nil, NewSTSThrottling(responseStr)
}
return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, responseStr))
return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d) on %s endpoint. Body: %s", response.StatusCode, stsRegion, responseStr))
}

var callerIdentity getCallerIdentityWrapper
Expand All @@ -601,6 +609,7 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {

id := &Identity{
AccessKeyID: accessKeyID,
STSEndpoint: parsedURL.Host,
}
return getIdentityFromSTSResponse(id, callerIdentity)
}
Expand Down Expand Up @@ -660,3 +669,19 @@ func hasSignedClusterIDHeader(paramsLower *url.Values) bool {
}
return false
}

func getStsRegion(host string) (string, error) {
if host == "" {
return "", fmt.Errorf("host is empty")
}

parts := strings.Split(host, ".")
if len(parts) < 3 {
return "", fmt.Errorf("invalid host format: %v", host)
}

if host == "sts.amazonaws.com" {
return "global", nil
}
return parts[1], nil
}
25 changes: 25 additions & 0 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,3 +646,28 @@ func TestGetWithSTS(t *testing.T) {
})
}
}

func TestGetStsRegion(t *testing.T) {
tests := []struct {
host string
expected string
wantErr bool
}{
{"sts.amazonaws.com", "global", false}, // Global endpoint
{"sts.us-west-2.amazonaws.com", "us-west-2", false}, // Valid regional endpoint
{"sts.eu-central-1.amazonaws.com", "eu-central-1", false}, // Another valid regional endpoint
{"", "", true}, // Empty input (expect error)
{"sts", "", true}, // Malformed input (expect error)
{"sts.wrongformat", "", true}, // Malformed input (expect error)
}

for _, test := range tests {
result, err := getStsRegion(test.host)
if (err != nil) != test.wantErr {
t.Errorf("getStsRegion(%q) error = %v, wantErr %v", test.host, err, test.wantErr)
}
if result != test.expected {
t.Errorf("getStsRegion(%q) = %q; expected %q", test.host, result, test.expected)
}
}
}

0 comments on commit 29c47ac

Please sign in to comment.