Skip to content

Commit

Permalink
Merge pull request #8497 from brs96/fix-n2v
Browse files Browse the repository at this point in the history
Fix N2V progress tracking and float loss of precision
  • Loading branch information
brs96 authored Dec 8, 2023
2 parents 39eb883 + 25261cb commit cb8fbe2
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ public class Node2VecModel {
private final ProgressTracker progressTracker;
private final long randomSeed;

private static final double EPSILON = 1e-10;

Node2VecModel(
LongUnaryOperator toOriginalId,
long nodeCount,
Expand Down Expand Up @@ -144,8 +146,7 @@ Result train() {
var positiveSampleProducer = new PositiveSampleProducer(
walks.iterator(partition.startNode(), partition.nodeCount()),
randomWalkProbabilities.positiveSamplingProbabilities(),
windowSize,
progressTracker
windowSize
);

return new TrainingTask(
Expand All @@ -155,7 +156,8 @@ Result train() {
negativeSamples,
learningRate,
negativeSamplingRate,
embeddingDimension
embeddingDimension,
progressTracker
);
}
);
Expand All @@ -166,7 +168,7 @@ Result train() {
.run();

double loss = tasks.stream().mapToDouble(TrainingTask::lossSum).sum();
progressTracker.logInfo(formatWithLocale("Maximum likelihood objective is %.4f", loss));
progressTracker.logInfo(formatWithLocale("Loss %.4f", loss));
lossPerIteration.add(loss);

progressTracker.endSubTask();
Expand Down Expand Up @@ -217,6 +219,8 @@ private static final class TrainingTask implements Runnable {
private final int negativeSamplingRate;
private final float learningRate;

private final ProgressTracker progressTracker;

private double lossSum;

private TrainingTask(
Expand All @@ -226,7 +230,8 @@ private TrainingTask(
NegativeSampleProducer negativeSampleProducer,
float learningRate,
int negativeSamplingRate,
int embeddingDimensions
int embeddingDimensions,
ProgressTracker progressTracker
) {
this.centerEmbeddings = centerEmbeddings;
this.contextEmbeddings = contextEmbeddings;
Expand All @@ -237,6 +242,7 @@ private TrainingTask(

this.centerGradientBuffer = new FloatVector(embeddingDimensions);
this.contextGradientBuffer = new FloatVector(embeddingDimensions);
this.progressTracker = progressTracker;
}

@Override
Expand All @@ -250,6 +256,7 @@ public void run() {
for (var i = 0; i < negativeSamplingRate; i++) {
trainSample(buffer[0], negativeSampleProducer.next(), false);
}
progressTracker.logProgress();
}
}

Expand All @@ -261,13 +268,14 @@ private void trainSample(long center, long context, boolean positive) {
// L_neg = -log sigmoid(-center * context); gradient: sigmoid(center * context)
float affinity = centerEmbedding.innerProduct(contextEmbedding);

float positiveSigmoid = (float) Sigmoid.sigmoid(affinity);
float negativeSigmoid = 1 - positiveSigmoid;

// When |affinity| > 40, Sigmoid.sigmoid(affinity) saturates to exactly 1.0 — double
// precision cannot represent the tiny remainder. Subtract EPSILON so that
// negativeSigmoid can never be 0, which would make Math.log(negativeSigmoid) -Infinity.
double positiveSigmoid = Sigmoid.sigmoid(affinity) - EPSILON;
double negativeSigmoid = 1 - positiveSigmoid;

lossSum -= positive ? Math.log(positiveSigmoid) : Math.log(negativeSigmoid);

float gradient = positive ? -negativeSigmoid : positiveSigmoid;
float gradient = positive ? (float) -negativeSigmoid : (float) positiveSigmoid;
// we are doing gradient descent, so we go in the negative direction of the gradient here
float scaledGradient = -gradient * learningRate;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package org.neo4j.gds.embeddings.node2vec;

import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

import java.util.Iterator;
import java.util.concurrent.ThreadLocalRandom;
Expand All @@ -35,7 +34,6 @@ public class PositiveSampleProducer {
private final HugeDoubleArray samplingProbabilities;
private final int prefixWindowSize;
private final int postfixWindowSize;
private final ProgressTracker progressTracker;
private long[] currentWalk;
private int centerWordIndex;
private long currentCenterWord;
Expand All @@ -46,11 +44,9 @@ public class PositiveSampleProducer {
PositiveSampleProducer(
Iterator<long[]> walks,
HugeDoubleArray samplingProbabilities,
int windowSize,
ProgressTracker progressTracker
int windowSize
) {
this.walks = walks;
this.progressTracker = progressTracker;
this.samplingProbabilities = samplingProbabilities;

prefixWindowSize = ceilDiv(windowSize - 1, 2);
Expand All @@ -76,7 +72,6 @@ private boolean nextWalk() {
return false;
}
long[] walk = walks.next();
progressTracker.logProgress();
int filteredWalkLength = filter(walk);

while (filteredWalkLength < 2 && walks.hasNext()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -60,8 +59,7 @@ void doesNotCauseStackOverflow() {
var sampleProducer = new PositiveSampleProducer(
walks.iterator(0, nbrOfWalks),
HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
10,
ProgressTracker.NULL_TRACKER
10
);

var counter = 0L;
Expand All @@ -88,8 +86,7 @@ void doesNotCauseStackOverflowDueToBadLuck() {
var sampleProducer = new PositiveSampleProducer(
walks.iterator(0, nbrOfWalks),
probabilities,
10,
ProgressTracker.NULL_TRACKER
10
);
// does not overflow the stack = passes test

Expand All @@ -112,8 +109,7 @@ void doesNotAttemptToFetchOutsideBatch() {
var sampleProducer = new PositiveSampleProducer(
walks.iterator(0, nbrOfWalks / 2),
HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
10,
ProgressTracker.NULL_TRACKER
10
);

var counter = 0L;
Expand All @@ -137,8 +133,7 @@ void shouldProducePairsWith(
PositiveSampleProducer producer = new PositiveSampleProducer(
walks.iterator(0, walks.size()),
centerNodeProbabilities,
windowSize,
ProgressTracker.NULL_TRACKER
windowSize
);
while (producer.next(buffer)) {
actualPairs.add(Pair.of(buffer[0], buffer[1]));
Expand All @@ -160,8 +155,7 @@ void shouldProducePairsWithBounds() {
PositiveSampleProducer producer = new PositiveSampleProducer(
walks.iterator(0, 2),
centerNodeProbabilities,
3,
ProgressTracker.NULL_TRACKER
3
);
while (producer.next(buffer)) {
actualPairs.add(Pair.of(buffer[0], buffer[1]));
Expand Down Expand Up @@ -206,8 +200,7 @@ void shouldRemoveDownsampledWordFromWalk() {
PositiveSampleProducer producer = new PositiveSampleProducer(
walks.iterator(0, walks.size()),
centerNodeProbabilities,
3,
ProgressTracker.NULL_TRACKER
3
);

while (producer.next(buffer)) {
Expand Down

0 comments on commit cb8fbe2

Please sign in to comment.