Skip to content

Commit

Permalink
Merge pull request #8666 from yuvalr1neo/gds-session-procedure-collec…
Browse files Browse the repository at this point in the history
…tion

Gds session procedure collection
  • Loading branch information
jjaderberg authored Feb 13, 2024
2 parents 84f0321 + 6dbd674 commit 0b4e14d
Show file tree
Hide file tree
Showing 5 changed files with 811 additions and 4 deletions.
1 change: 1 addition & 0 deletions proc/sysinfo/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ group = 'org.neo4j.gds'

dependencies {
annotationProcessor project(':annotations')
annotationProcessor project(':procedure-collector')
annotationProcessor group: 'org.immutables', name: 'value', version: ver.'immutables'

compileOnly project(':annotations')
Expand Down
9 changes: 5 additions & 4 deletions procedure-collector/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ dependencies {

implementation project(':annotations')
implementation project(':executor')
implementation group: 'org.immutables', name: 'value-annotations', version: ver.'immutables'
implementation group: 'com.google.auto', name: 'auto-common', version: ver.'auto-common'
implementation group: 'com.squareup', name: 'javapoet', version: ver.'javapoet'
implementation group: 'org.jetbrains', name: 'annotations', version: ver.'jetbrains-annotations'
implementation group: 'org.neo4j', name: 'neo4j-procedure-api', version: ver.neo4j
implementation group: 'org.immutables', name: 'value-annotations', version: ver.'immutables'
implementation group: 'com.google.auto', name: 'auto-common', version: ver.'auto-common'
implementation group: 'com.squareup', name: 'javapoet', version: ver.'javapoet'
implementation group: 'org.jetbrains', name: 'annotations', version: ver.'jetbrains-annotations'
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.proc;

import com.google.auto.common.BasicAnnotationProcessor;
import com.google.auto.service.AutoService;

import javax.annotation.processing.Processor;
import javax.annotation.processing.RoundEnvironment;
import javax.lang.model.SourceVersion;
import javax.lang.model.element.TypeElement;
import javax.tools.Diagnostic;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;

import static javax.tools.StandardLocation.CLASS_OUTPUT;

/**
* An annotation processor that creates files to enable service loading for procedures, user functions and aggregations.
*
* Only things listed in the session allow list will be written (and thus loaded).
*/
@AutoService(Processor.class)
public class SessionProcedureCollectorProcessor extends BasicAnnotationProcessor {

private final Set<TypeElement> proceduresWithAnnotations = new HashSet<>();
private final Set<TypeElement> functionsWithAnnotations = new HashSet<>();
private final Set<TypeElement> aggregationsWithAnnotations = new HashSet<>();

@Override
public SourceVersion getSupportedSourceVersion() {
return SourceVersion.RELEASE_17;
}

@Override
protected Iterable<? extends Step> steps() {
return List.of(new SessionProcedureCollectorStep(
proceduresWithAnnotations,
functionsWithAnnotations,
aggregationsWithAnnotations
));
}

@Override
protected void postRound(RoundEnvironment roundEnv) {
if (roundEnv.processingOver()) {
tryWriteElements();
}
}

private void tryWriteElements() {
try {
writeElements();
proceduresWithAnnotations.clear();
} catch (IOException e) {
logError(e,
String.format(
Locale.ENGLISH,
"Failed to write procedures for service loading. First: %s",
proceduresWithAnnotations
)
);
}
}

private void writeElements() throws IOException {
if (!proceduresWithAnnotations.isEmpty()) {
writeElementsOfType(SessionProcedureCollectorStep.PROCEDURE, proceduresWithAnnotations);
}
if (!functionsWithAnnotations.isEmpty()) {
writeElementsOfType(SessionProcedureCollectorStep.USER_FUNCTION, functionsWithAnnotations);
}
if (!aggregationsWithAnnotations.isEmpty()) {
writeElementsOfType(SessionProcedureCollectorStep.USER_AGGREGATION, aggregationsWithAnnotations);
}
}

private void writeElementsOfType(String typeName, Iterable<TypeElement> elements) throws IOException {
// we fake being a service so that we get properly merged in the shadow jar
var path = "META-INF/services/" + typeName;
var file = processingEnv.getFiler().createResource(CLASS_OUTPUT, "", path);

try (var writer = new PrintWriter(
new BufferedOutputStream(file.openOutputStream()),
true,
StandardCharsets.UTF_8
)) {
for (var element : elements) {
writer.println(element.getQualifiedName());
}
}
}

private void logError(Exception e, String message) {
processingEnv.getMessager().printMessage(
Diagnostic.Kind.ERROR,
String.format(
Locale.ENGLISH,
message,
e.getMessage()
)
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.proc;

import com.google.auto.common.BasicAnnotationProcessor;
import com.google.auto.common.MoreElements;
import com.google.common.collect.ImmutableSetMultimap;
import org.neo4j.procedure.Procedure;
import org.neo4j.procedure.UserAggregationFunction;
import org.neo4j.procedure.UserFunction;

import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class SessionProcedureCollectorStep implements BasicAnnotationProcessor.Step {
static final String PROCEDURE = Procedure.class.getCanonicalName();
static final String USER_FUNCTION = UserFunction.class.getCanonicalName();
static final String USER_AGGREGATION = UserAggregationFunction.class.getCanonicalName();

private final Set<TypeElement> procedures;
private final Set<TypeElement> functions;
private final Set<TypeElement> aggregations;

SessionProcedureCollectorStep(
Set<TypeElement> outProcedures,
Set<TypeElement> outFunctions,
Set<TypeElement> outAggregations
) {
this.procedures = outProcedures;
this.functions = outFunctions;
this.aggregations = outAggregations;
}

@Override
public Set<String> annotations() {
return Set.of(PROCEDURE, USER_FUNCTION, USER_AGGREGATION);
}

@Override
public Set<? extends Element> process(ImmutableSetMultimap<String, Element> elementsByAnnotation) {
var allProcedures = elementsByAnnotation.get(PROCEDURE);
for (var procedure : allProcedures) {
if(isInPackage(procedure)) {
var annotation = Arrays.stream(procedure.getAnnotationsByType(Procedure.class)).toList().get(0);
var elementName = annotation.value().isEmpty() ? annotation.name() : annotation.value();
if (isAllowed(elementName)) {
procedures.add(MoreElements.asType(procedure.getEnclosingElement()));
}
}
}

var allFunctions = elementsByAnnotation.get(USER_FUNCTION);
for (var function : allFunctions) {
if(isInPackage(function)) {
var annotation = Arrays.stream(function.getAnnotationsByType(UserFunction.class)).toList().get(0);
if (isAllowed(annotation.value().isEmpty() ? annotation.name() : annotation.value())) {
functions.add(MoreElements.asType(function.getEnclosingElement()));
}
}
}

var allAggregations = elementsByAnnotation.get(USER_AGGREGATION);
for (var aggregation : allAggregations) {
if(isInPackage(aggregation)) {
var annotation = Arrays.stream(aggregation.getAnnotationsByType(UserAggregationFunction.class)).toList().get(0);
if (isAllowed(annotation.value().isEmpty() ? annotation.name() : annotation.value())) {
aggregations.add(MoreElements.asType(aggregation.getEnclosingElement()));
}
}
}

return new HashSet<>();
}

private boolean isInPackage(Element element) {
var thePackage = MoreElements.getPackage(element);
var packageName = thePackage.getQualifiedName().toString();
return packageName.startsWith("org.neo4j.gds.") || packageName.equals("org.neo4j.gds");
}

private boolean isAllowed(String elementName) {
// TODO: Let's do allow list validation in integration test for sessions, not while processing annotations
return true;
}

}
Loading

0 comments on commit 0b4e14d

Please sign in to comment.