From f0231022a5498368092d66e485031f889ec5b8e3 Mon Sep 17 00:00:00 2001 From: neurolabusc Date: Sun, 28 Apr 2024 15:54:31 -0400 Subject: [PATCH] Use tfjs to compute max --- brainchop.js | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/brainchop.js b/brainchop.js index 2c27852..8a47d17 100644 --- a/brainchop.js +++ b/brainchop.js @@ -741,12 +741,6 @@ async function minMaxNormalizeVolumeData(volumeData) { return normalizedSlices_3d } -async function findArrayMax(array) { - return array.reduce((e1, e2) => { - return e1 > e2 ? e1 : e2 - }) -} - async function inferenceFullVolumeSeqCovLayer( model, slices_3d, @@ -1733,8 +1727,7 @@ async function inferenceFullVolumeSeqCovLayerPhase2( const Inference_t = ((performance.now() - startTime) / 1000).toFixed(4) console.log(' find array max ') - const curBatchMaxLabel = await findArrayMax(Array.from(outputTensor.dataSync())) - + const curBatchMaxLabel = await outputTensor.max().dataSync()[0] if (maxLabelPredicted < curBatchMaxLabel) { maxLabelPredicted = curBatchMaxLabel } @@ -2207,10 +2200,7 @@ async function inferenceFullVolumePhase2( // outputDataBeforArgmx = Array.from(prediction_argmax.dataSync()) tf.dispose(curTensor[i]) // allPredictions.push({"id": allBatches[j].id, "coordinates": allBatches[j].coordinates, "data": Array.from(prediction_argmax.dataSync()) }) - console.log(' find array max ') - // ???? await - const curBatchMaxLabel = await findArrayMax(Array.from(prediction_argmax.dataSync())) - + const curBatchMaxLabel = await prediction_argmax.max().dataSync()[0] if (maxLabelPredicted < curBatchMaxLabel) { maxLabelPredicted = curBatchMaxLabel } @@ -2597,8 +2587,7 @@ async function inferenceFullVolumePhase1( tf.dispose(curTensor[i]) console.log(' Pre-model find array max ') - const curBatchMaxLabel = await findArrayMax(Array.from(prediction_argmax.dataSync())) - + const curBatchMaxLabel = await prediction_argmax.max().dataSync()[0] if (maxLabelPredicted < curBatchMaxLabel) { maxLabelPredicted = curBatchMaxLabel }