Skip to content

Commit

Permalink
Improve JWT parse / decode performance (#620)
Browse files Browse the repository at this point in the history
* Optimise parsing of token for well-defined JWT format

* Update error message in test to match new code

* Fixing checkstyle issues

* Added missing test case for no parts

* Return new JWTDecodeException

Return a new JWTDecodeException from private utility method `wrongNumberOfParts`, instead of throwing, since we throw from `splitToken()`.

* Add JMH support to build script

* Add benchmark for decoder and cleanup build file

* Optimise JWT deserialisation by re-using threadsafe Jackson objects

* Disable lint checks on JMH source set that is for testing

* Remove extra line break

---------

Co-authored-by: Jim Anderson <[email protected]>
Co-authored-by: Jim Anderson <[email protected]>
  • Loading branch information
3 people authored Jan 31, 2023
1 parent 12ae664 commit 9024318
Show file tree
Hide file tree
Showing 13 changed files with 181 additions and 103 deletions.
46 changes: 45 additions & 1 deletion lib/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,28 @@ plugins {
id 'checkstyle'
}

sourceSets {
jmh {

}
}

configurations {
jmhImplementation {
extendsFrom implementation
}
}

checkstyle {
toolVersion '10.0'
checkstyleTest.enabled = false //We are disabling lint checks for tests
}
//We are disabling lint checks for tests
tasks.named("checkstyleTest").configure({
enabled = false
})
tasks.named("checkstyleJmh").configure({
enabled = false
})

logger.lifecycle("Using version ${version} for ${group}.${name}")

Expand Down Expand Up @@ -61,6 +79,10 @@ dependencies {
testImplementation 'net.jodah:concurrentunit:0.4.6'
testImplementation 'org.hamcrest:hamcrest:2.2'
testImplementation 'org.mockito:mockito-core:4.4.0'

jmhImplementation sourceSets.main.output
jmhImplementation 'org.openjdk.jmh:jmh-core:1.35'
jmhAnnotationProcessor 'org.openjdk.jmh:jmh-generator-annprocess:1.35'
}

jacoco {
Expand Down Expand Up @@ -143,3 +165,25 @@ task exportVersion() {
new File(rootDir, "version.txt").text = "$version"
}
}

// you can pass any arguments JMH accepts via Gradle args.
// Example: ./gradlew runJMH --args="-lrf"
tasks.register('runJMH', JavaExec) {
description 'Run JMH benchmarks.'
group 'verification'

main 'org.openjdk.jmh.Main'
classpath sourceSets.jmh.runtimeClasspath

args project.hasProperty("args") ? project.property("args").split() : ""
}
tasks.register('jmhHelp', JavaExec) {
description 'Prints the available command line options for JMH.'
group 'help'

main 'org.openjdk.jmh.Main'
classpath sourceSets.jmh.runtimeClasspath

args '-h'
}

20 changes: 20 additions & 0 deletions lib/src/jmh/java/com/auth0/jwt/benchmark/JWTDecoderBenchmark.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.auth0.jwt.benchmark;

import com.auth0.jwt.JWT;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.infra.Blackhole;

/**
* This class is a JMH benchmark for decoding JWTs.
*/
public class JWTDecoderBenchmark {
private static final String TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ";

@Benchmark
@BenchmarkMode(Mode.Throughput)
public void throughputDecodeTime(Blackhole blackhole) {
blackhole.consume(JWT.decode(TOKEN));
}
}
13 changes: 6 additions & 7 deletions lib/src/main/java/com/auth0/jwt/impl/BasicHeader.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.Header;
import com.fasterxml.jackson.core.ObjectCodec;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectReader;

import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static com.auth0.jwt.impl.JsonNodeClaim.extractClaim;
Expand All @@ -23,22 +22,22 @@ class BasicHeader implements Header, Serializable {
private final String contentType;
private final String keyId;
private final Map<String, JsonNode> tree;
private final ObjectReader objectReader;
private final ObjectCodec objectCodec;

BasicHeader(
String algorithm,
String type,
String contentType,
String keyId,
Map<String, JsonNode> tree,
ObjectReader objectReader
ObjectCodec objectCodec
) {
this.algorithm = algorithm;
this.type = type;
this.contentType = contentType;
this.keyId = keyId;
this.tree = Collections.unmodifiableMap(tree == null ? new HashMap<>() : tree);
this.objectReader = objectReader;
this.tree = tree == null ? Collections.emptyMap() : Collections.unmodifiableMap(tree);
this.objectCodec = objectCodec;
}

Map<String, JsonNode> getTree() {
Expand Down Expand Up @@ -67,6 +66,6 @@ public String getKeyId() {

@Override
public Claim getHeaderClaim(String name) {
return extractClaim(name, tree, objectReader);
return extractClaim(name, tree, objectCodec);
}
}
20 changes: 6 additions & 14 deletions lib/src/main/java/com/auth0/jwt/impl/HeaderDeserializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import com.auth0.jwt.HeaderParams;
import com.auth0.jwt.exceptions.JWTDecodeException;
import com.auth0.jwt.interfaces.Header;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectReader;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;

import java.io.IOException;
Expand All @@ -19,22 +19,14 @@
*
* @see JWTParser
*/
class HeaderDeserializer extends StdDeserializer<BasicHeader> {
class HeaderDeserializer extends StdDeserializer<Header> {

private final ObjectReader objectReader;

HeaderDeserializer(ObjectReader objectReader) {
this(null, objectReader);
}

private HeaderDeserializer(Class<?> vc, ObjectReader objectReader) {
super(vc);

this.objectReader = objectReader;
HeaderDeserializer() {
super(Header.class);
}

@Override
public BasicHeader deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
public Header deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
Map<String, JsonNode> tree = p.getCodec().readValue(p, new TypeReference<Map<String, JsonNode>>() {
});
if (tree == null) {
Expand All @@ -45,7 +37,7 @@ public BasicHeader deserialize(JsonParser p, DeserializationContext ctxt) throws
String type = getString(tree, HeaderParams.TYPE);
String contentType = getString(tree, HeaderParams.CONTENT_TYPE);
String keyId = getString(tree, HeaderParams.KEY_ID);
return new BasicHeader(algorithm, type, contentType, keyId, tree, objectReader);
return new BasicHeader(algorithm, type, contentType, keyId, tree, p.getCodec());
}

String getString(Map<String, JsonNode> tree, String claimName) {
Expand Down
22 changes: 17 additions & 5 deletions lib/src/main/java/com/auth0/jwt/impl/JWTParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
* {@link HeaderSerializer} and {@link PayloadSerializer}.
*/
public class JWTParser implements JWTPartsParser {
private static final ObjectMapper DEFAULT_OBJECT_MAPPER = createDefaultObjectMapper();
private static final ObjectReader DEFAULT_PAYLOAD_READER = DEFAULT_OBJECT_MAPPER.readerFor(Payload.class);
private static final ObjectReader DEFAULT_HEADER_READER = DEFAULT_OBJECT_MAPPER.readerFor(Header.class);

private final ObjectReader payloadReader;
private final ObjectReader headerReader;

public JWTParser() {
this(getDefaultObjectMapper());
this.payloadReader = DEFAULT_PAYLOAD_READER;
this.headerReader = DEFAULT_HEADER_READER;
}

JWTParser(ObjectMapper mapper) {
addDeserializers(mapper);

this.payloadReader = mapper.readerFor(Payload.class);
this.headerReader = mapper.readerFor(Header.class);
}
Expand Down Expand Up @@ -55,18 +61,24 @@ public Header parseHeader(String json) throws JWTDecodeException {
}
}

private void addDeserializers(ObjectMapper mapper) {
static void addDeserializers(ObjectMapper mapper) {
SimpleModule module = new SimpleModule();
ObjectReader reader = mapper.reader();
module.addDeserializer(Payload.class, new PayloadDeserializer(reader));
module.addDeserializer(Header.class, new HeaderDeserializer(reader));
module.addDeserializer(Payload.class, new PayloadDeserializer());
module.addDeserializer(Header.class, new HeaderDeserializer());
mapper.registerModule(module);
}

static ObjectMapper getDefaultObjectMapper() {
return DEFAULT_OBJECT_MAPPER;
}

private static ObjectMapper createDefaultObjectMapper() {
ObjectMapper mapper = new ObjectMapper();
mapper.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS);
mapper.setSerializationInclusion(JsonInclude.Include.NON_EMPTY);

addDeserializers(mapper);

return mapper;
}

Expand Down
36 changes: 19 additions & 17 deletions lib/src/main/java/com/auth0/jwt/impl/JsonNodeClaim.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import com.auth0.jwt.interfaces.Claim;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.ObjectCodec;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectReader;

import java.io.IOException;
import java.lang.reflect.Array;
Expand All @@ -21,12 +21,12 @@
*/
class JsonNodeClaim implements Claim {

private final ObjectReader objectReader;
private final ObjectCodec codec;
private final JsonNode data;

private JsonNodeClaim(JsonNode node, ObjectReader objectReader) {
private JsonNodeClaim(JsonNode node, ObjectCodec codec) {
this.data = node;
this.objectReader = objectReader;
this.codec = codec;
}

@Override
Expand Down Expand Up @@ -82,7 +82,7 @@ public <T> T[] asArray(Class<T> clazz) throws JWTDecodeException {
T[] arr = (T[]) Array.newInstance(clazz, data.size());
for (int i = 0; i < data.size(); i++) {
try {
arr[i] = objectReader.treeToValue(data.get(i), clazz);
arr[i] = codec.treeToValue(data.get(i), clazz);
} catch (JsonProcessingException e) {
throw new JWTDecodeException("Couldn't map the Claim's array contents to " + clazz.getSimpleName(), e);
}
Expand All @@ -99,7 +99,7 @@ public <T> List<T> asList(Class<T> clazz) throws JWTDecodeException {
List<T> list = new ArrayList<>();
for (int i = 0; i < data.size(); i++) {
try {
list.add(objectReader.treeToValue(data.get(i), clazz));
list.add(codec.treeToValue(data.get(i), clazz));
} catch (JsonProcessingException e) {
throw new JWTDecodeException("Couldn't map the Claim's array contents to " + clazz.getSimpleName(), e);
}
Expand All @@ -113,11 +113,11 @@ public Map<String, Object> asMap() throws JWTDecodeException {
return null;
}

try {
TypeReference<Map<String, Object>> mapType = new TypeReference<Map<String, Object>>() {
};
JsonParser thisParser = objectReader.treeAsTokens(data);
return thisParser.readValueAs(mapType);
TypeReference<Map<String, Object>> mapType = new TypeReference<Map<String, Object>>() {
};

try (JsonParser parser = codec.treeAsTokens(data)) {
return parser.readValueAs(mapType);
} catch (IOException e) {
throw new JWTDecodeException("Couldn't map the Claim value to Map", e);
}
Expand All @@ -129,8 +129,8 @@ public <T> T as(Class<T> clazz) throws JWTDecodeException {
if (isMissing() || isNull()) {
return null;
}
return objectReader.treeAsTokens(data).readValueAs(clazz);
} catch (IOException e) {
return codec.treeToValue(data, clazz);
} catch (JsonProcessingException e) {
throw new JWTDecodeException("Couldn't map the Claim value to " + clazz.getSimpleName(), e);
}
}
Expand Down Expand Up @@ -160,21 +160,23 @@ public String toString() {
*
* @param claimName the Claim to search for.
* @param tree the JsonNode tree to search the Claim in.
* @param objectCodec the object codec in use for deserialization
* @return a valid non-null Claim.
*/
static Claim extractClaim(String claimName, Map<String, JsonNode> tree, ObjectReader objectReader) {
static Claim extractClaim(String claimName, Map<String, JsonNode> tree, ObjectCodec objectCodec) {
JsonNode node = tree.get(claimName);
return claimFromNode(node, objectReader);
return claimFromNode(node, objectCodec);
}

/**
* Helper method to create a Claim representation from the given JsonNode.
*
* @param node the JsonNode to convert into a Claim.
* @param objectCodec the object codec in use for deserialization
* @return a valid Claim instance. If the node is null or missing, a NullClaim will be returned.
*/
static Claim claimFromNode(JsonNode node, ObjectReader objectReader) {
return new JsonNodeClaim(node, objectReader);
static Claim claimFromNode(JsonNode node, ObjectCodec objectCodec) {
return new JsonNodeClaim(node, objectCodec);
}

}
22 changes: 8 additions & 14 deletions lib/src/main/java/com/auth0/jwt/impl/PayloadDeserializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.auth0.jwt.interfaces.Payload;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.ObjectCodec;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonNode;
Expand All @@ -24,16 +25,8 @@
*/
class PayloadDeserializer extends StdDeserializer<Payload> {

private final ObjectReader objectReader;

PayloadDeserializer(ObjectReader reader) {
this(null, reader);
}

private PayloadDeserializer(Class<?> vc, ObjectReader reader) {
super(vc);

this.objectReader = reader;
PayloadDeserializer() {
super(Payload.class);
}

@Override
Expand All @@ -46,16 +39,17 @@ public Payload deserialize(JsonParser p, DeserializationContext ctxt) throws IOE

String issuer = getString(tree, RegisteredClaims.ISSUER);
String subject = getString(tree, RegisteredClaims.SUBJECT);
List<String> audience = getStringOrArray(tree, RegisteredClaims.AUDIENCE);
List<String> audience = getStringOrArray(p.getCodec(), tree, RegisteredClaims.AUDIENCE);
Instant expiresAt = getInstantFromSeconds(tree, RegisteredClaims.EXPIRES_AT);
Instant notBefore = getInstantFromSeconds(tree, RegisteredClaims.NOT_BEFORE);
Instant issuedAt = getInstantFromSeconds(tree, RegisteredClaims.ISSUED_AT);
String jwtId = getString(tree, RegisteredClaims.JWT_ID);

return new PayloadImpl(issuer, subject, audience, expiresAt, notBefore, issuedAt, jwtId, tree, objectReader);
return new PayloadImpl(issuer, subject, audience, expiresAt, notBefore, issuedAt, jwtId, tree, p.getCodec());
}

List<String> getStringOrArray(Map<String, JsonNode> tree, String claimName) throws JWTDecodeException {
List<String> getStringOrArray(ObjectCodec codec, Map<String, JsonNode> tree, String claimName)
throws JWTDecodeException {
JsonNode node = tree.get(claimName);
if (node == null || node.isNull() || !(node.isArray() || node.isTextual())) {
return null;
Expand All @@ -67,7 +61,7 @@ List<String> getStringOrArray(Map<String, JsonNode> tree, String claimName) thro
List<String> list = new ArrayList<>(node.size());
for (int i = 0; i < node.size(); i++) {
try {
list.add(objectReader.treeToValue(node.get(i), String.class));
list.add(codec.treeToValue(node.get(i), String.class));
} catch (JsonProcessingException e) {
throw new JWTDecodeException("Couldn't map the Claim's array contents to String", e);
}
Expand Down
Loading

0 comments on commit 9024318

Please sign in to comment.