Skip to content

Commit

Permalink
KNOX-3023 - Include groups in a header in ConfigurableDispatch (#903)
Browse files Browse the repository at this point in the history
  • Loading branch information
moresandeep authored May 1, 2024
1 parent e1a7468 commit b6ff0ac
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,7 @@ public interface SpiGatewayMessages {

@Message( level = MessageLevel.DEBUG, text = "Malformed dispatch URL: {0}" )
void malformedDispatchUrl(String url);

@Message( level = MessageLevel.ERROR, text = "No valid principal found" )
void noPrincipalFound();
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,29 @@
import org.apache.knox.gateway.audit.api.ActionOutcome;
import org.apache.knox.gateway.config.Configure;
import org.apache.knox.gateway.config.Default;
import org.apache.knox.gateway.security.SubjectUtils;
import org.apache.knox.gateway.util.StringUtils;

import javax.security.auth.Subject;
import javax.servlet.http.HttpServletRequest;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;

import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.Arrays;
import java.util.Optional;
import java.util.HashSet;
import java.util.HashMap;
import java.util.Optional;
import java.util.List;
import java.util.Collection;
import java.util.Locale;
import java.util.ArrayList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
Expand All @@ -50,6 +59,21 @@ public class ConfigurableDispatch extends DefaultDispatch {
private Set<String> responseExcludeSetCookieHeaderDirectives = super.getOutboundResponseExcludedSetCookieHeaderDirectives();
private Boolean removeUrlEncoding = false;

private boolean shouldIncludePrincipalAndGroups;
private String actorIdHeaderName = DEFAULT_AUTH_ACTOR_ID_HEADER_NAME;
private String actorGroupsHeaderPrefix = DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX;
private String groupFilterPattern = DEFAULT_GROUP_FILTER_PATTERN;

static final String DEFAULT_AUTH_ACTOR_ID_HEADER_NAME = "X-Knox-Actor-ID";
static final String DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX = "X-Knox-Actor-Groups";
static final String DEFAULT_GROUP_FILTER_PATTERN = ".*";
static final String DEFAULT_ARE_USERS_GROUPS_HEADER_INCLUDED = "false";

protected static final int MAX_HEADER_LENGTH = 1000;
protected static final String ACTOR_GROUPS_HEADER_FORMAT = "%s-%d";
protected Pattern groupPattern = Pattern.compile(DEFAULT_GROUP_FILTER_PATTERN);


private Set<String> convertCommaDelimitedHeadersToSet(String headers) {
return headers == null ? Collections.emptySet(): new HashSet<>(Arrays.asList(headers.split("\\s*,\\s*")));
}
Expand Down Expand Up @@ -123,6 +147,27 @@ protected void setRemoveUrlEncoding(@Default("false") String removeUrlEncoding)
this.removeUrlEncoding = Boolean.parseBoolean(removeUrlEncoding);
}

@Configure
public void setShouldIncludePrincipalAndGroups(@Default(DEFAULT_ARE_USERS_GROUPS_HEADER_INCLUDED) boolean shouldIncludePrincipalAndGroups) {
this.shouldIncludePrincipalAndGroups = shouldIncludePrincipalAndGroups;
}

@Configure
public void setActorIdHeaderName(@Default(DEFAULT_AUTH_ACTOR_ID_HEADER_NAME) String actorIdHeaderName) {
this.actorIdHeaderName = actorIdHeaderName;
}

@Configure
public void setActorGroupsHeaderPrefix(@Default(DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX) String actorGroupsHeaderPrefix) {
this.actorGroupsHeaderPrefix = actorGroupsHeaderPrefix;
}

@Configure
public void setGroupFilterPattern(@Default(DEFAULT_GROUP_FILTER_PATTERN) String groupFilterPattern) {
this.groupFilterPattern = groupFilterPattern;
groupPattern = Pattern.compile(this.groupFilterPattern);
}

@Override
public void copyRequestHeaderFields(HttpUriRequest outboundRequest,
HttpServletRequest inboundRequest) {
Expand All @@ -133,6 +178,61 @@ public void copyRequestHeaderFields(HttpUriRequest outboundRequest,
if(MapUtils.isNotEmpty(extraHeaders)){
extraHeaders.forEach(outboundRequest::addHeader);
}

/* If we need to add user and groups to outbound request */
if(shouldIncludePrincipalAndGroups) {
Map<String, String> groups = addPrincipalAndGroups();
if(MapUtils.isNotEmpty(groups)){
groups.forEach(outboundRequest::addHeader);
}
}
}

private Map<String, String> addPrincipalAndGroups() {
final Map<String, String> headers = new ConcurrentHashMap();
final Subject subject = SubjectUtils.getCurrentSubject();

final String primaryPrincipalName = subject == null ? null : SubjectUtils.getPrimaryPrincipalName(subject);
if (primaryPrincipalName == null) {
LOG.noPrincipalFound();
headers.put(actorIdHeaderName, "");
} else {
headers.put(actorIdHeaderName, primaryPrincipalName);
}

// Populate actor groups headers
final Set<String> matchingGroupNames = subject == null ? Collections.emptySet()
: SubjectUtils.getGroupPrincipals(subject).stream().filter(group -> groupPattern.matcher(group.getName()).matches()).map(group -> group.getName())
.collect(Collectors.toSet());
if (!matchingGroupNames.isEmpty()) {
final List<String> groupStrings = getGroupStrings(matchingGroupNames);
for (int i = 0; i < groupStrings.size(); i++) {
headers.put(String.format(Locale.ROOT, ACTOR_GROUPS_HEADER_FORMAT, actorGroupsHeaderPrefix, i + 1), groupStrings.get(i));
}
}
return headers;
}

private List<String> getGroupStrings(final Collection<String> groupNames) {
if (groupNames.isEmpty()) {
return Collections.emptyList();
}
List<String> groupStrings = new ArrayList<>();
StringBuilder sb = new StringBuilder();
for (String groupName : groupNames) {
if (sb.length() + groupName.length() > MAX_HEADER_LENGTH) {
groupStrings.add(sb.toString());
sb = new StringBuilder();
}
if (sb.length() > 0) {
sb.append(',');
}
sb.append(groupName);
}
if (sb.length() > 0) {
groupStrings.add(sb.toString());
}
return groupStrings;
}

@Override
Expand Down Expand Up @@ -180,4 +280,5 @@ public URI getDispatchUrl(HttpServletRequest request) {

return super.getDispatchUrl(request);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.knox.gateway.dispatch;

import static org.apache.knox.gateway.dispatch.AbstractGatewayDispatch.REQUEST_ID_HEADER_NAME;
import static org.apache.knox.gateway.dispatch.ConfigurableDispatch.DEFAULT_AUTH_ACTOR_ID_HEADER_NAME;
import static org.apache.knox.gateway.dispatch.DefaultDispatch.SET_COOKIE;
import static org.apache.knox.gateway.dispatch.DefaultDispatch.WWW_AUTHENTICATE;
import static org.hamcrest.CoreMatchers.containsString;
Expand All @@ -26,12 +27,15 @@
import static org.junit.Assert.assertThat;

import java.net.URI;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

import javax.security.auth.Subject;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

Expand All @@ -41,6 +45,8 @@
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.message.BasicHeader;
import org.apache.knox.gateway.security.GroupPrincipal;
import org.apache.knox.gateway.security.PrimaryPrincipal;
import org.apache.knox.test.TestUtils;
import org.apache.knox.test.mock.MockHttpServletResponse;
import org.apache.logging.log4j.CloseableThreadContext;
Expand Down Expand Up @@ -316,7 +322,7 @@ public void testRequestAppendHeadersConfig() {
assertThat(outboundRequestHeaders[3].getName(), is("c"));
}

@Test( timeout = TestUtils.SHORT_TIMEOUT )
@Test( timeout = TestUtils.LONG_TIMEOUT )
public void testRequestExcludeAndAppendHeadersConfig() {
ConfigurableDispatch dispatch = new ConfigurableDispatch();
dispatch.setRequestAppendHeaders("a : b ; c : d");
Expand Down Expand Up @@ -724,4 +730,47 @@ public void testXRequestIDHeaderExcludeListNoReqHeader() {
assertThat(outboundResponse.getHeader(REQUEST_ID_HEADER_NAME), nullValue());
}

/**
* Make sure X-Knox-Actor-ID and X-Knox-Actor-Groups-1 headers
* are added for authenticated users.
*/
@Test
public void testGroupHeaders() throws PrivilegedActionException {
Subject subject = new Subject();
subject.getPrincipals().add(new PrimaryPrincipal("knoxui"));
subject.getPrincipals().add(new GroupPrincipal("knox"));
subject.getPrincipals().add(new GroupPrincipal("admin"));

ConfigurableDispatch dispatch = new ConfigurableDispatch();
final String headerReqID = "1234567890ABCD";
dispatch.setShouldIncludePrincipalAndGroups(true);

Map<String, String> headers = new HashMap<>();
headers.put(REQUEST_ID_HEADER_NAME, headerReqID);
headers.put(HttpHeaders.ACCEPT, "abc");
headers.put("TEST", "test");

HttpServletRequest inboundRequest = EasyMock.createNiceMock(HttpServletRequest.class);
EasyMock.expect(inboundRequest.getHeaderNames()).andReturn(Collections.enumeration(headers.keySet())).anyTimes();
Capture<String> capturedArgument = Capture.newInstance();
EasyMock.expect(inboundRequest.getHeader(EasyMock.capture(capturedArgument)))
.andAnswer(() -> headers.get(capturedArgument.getValue())).anyTimes();
EasyMock.replay(inboundRequest);

HttpUriRequest outboundRequest = new HttpGet();

Subject.doAs(subject, new PrivilegedExceptionAction<Object>() {

@Override
public Object run() throws Exception {
dispatch.copyRequestHeaderFields(outboundRequest, inboundRequest);
return null;
}
});

Header[] outboundRequestHeaders = outboundRequest.getAllHeaders();
assertThat(outboundRequestHeaders.length, is(5));
assertThat(outboundRequest.getHeaders(DEFAULT_AUTH_ACTOR_ID_HEADER_NAME)[0].getValue(), is("knoxui"));
}

}

0 comments on commit b6ff0ac

Please sign in to comment.