Skip to content

Commit

Permalink
Merge pull request #15 from studio-YOLO/7-k-mean-optimization
Browse files Browse the repository at this point in the history
K-means++ optimization
  • Loading branch information
VinciGit00 authored Mar 7, 2024
2 parents f2b3ae4 + 8dd37a4 commit 0a230d3
Show file tree
Hide file tree
Showing 8 changed files with 294 additions and 32 deletions.
3 changes: 2 additions & 1 deletion demo/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
<script type="module" src="scripts/change_saturation_demo.js"></script>
<script type="module" src="scripts/change_brightness_demo.js"></script>
<script type="module" src="scripts/sepia_demo.js"></script>
<script type="module" src="scripts/resize_image_demo.js"></script>-->
<script type="module" src="scripts/change_opacity_demo.js"></script>
<script type="module" src="scripts/resize_image_demo.js"></script>-->
<script type="module" src="scripts/color_palette_wasm_demo.js"></script>
</body>

</html>
11 changes: 4 additions & 7 deletions demo/scripts/color_palette_wasm_demo.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,13 @@ container.classList.add("container")

//waiting image load
image.onload = () => {
//calculate color palette
editpix.getColorPaletteWasm(image, 13, 1)
console.log("Image loaded.");

editpix.getColorPaletteWasm(image, 15, 2)
.then(colorPalette => {
console.log(colorPalette)
displayPalette(colorPalette);
})




});
};


Expand Down
12 changes: 3 additions & 9 deletions lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,11 @@ edition = "2021"

[dependencies]
wasm-bindgen = "0.2"
getrandom = { version = "0.2", features = ["js"] }
rand = "*"

[lib]
crate-type = ["cdylib", "rlib"]

[dependencies.web-sys]
version = "0.3.68"
features = [
'Document',
'Element',
'HtmlElement',
'Node',
'Window',
]


3 changes: 0 additions & 3 deletions lib/src/.gitignore

This file was deleted.

98 changes: 93 additions & 5 deletions lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
extern crate wasm_bindgen;
use wasm_bindgen::prelude::*;
extern crate web_sys;
use rand::prelude::*;

#[inline]
fn euclidean_distance(color1: &[u8; 3], color2: &[u8; 3]) -> f64 {
f64::sqrt(
((color1[0] as i16 - color2[0] as i16) as f64).powi(2) +
((color1[1] as i16 - color2[1] as i16) as f64).powi(2) +
((color1[2] as i16 - color2[2] as i16) as f64).powi(2)
(color1[0] as f64 - color2[0] as f64).powi(2) +
(color1[1] as f64 - color2[1] as f64).powi(2) +
(color1[2] as f64 - color2[2] as f64).powi(2)
)
}

Expand Down Expand Up @@ -105,4 +105,92 @@ pub fn k_means(colors_r: Vec<u8>, color_number: usize, max_iterations: usize) ->
.collect();

serialized_vector
}
}

fn square_distance(color1: &[u8; 3], color2: &[u8; 3]) -> f64 {
(color1[0] as f64 - color2[0] as f64).powi(2) +
(color1[1] as f64 - color2[1] as f64).powi(2) +
(color1[2] as f64 - color2[2] as f64).powi(2)
}

fn initialize_centroids_pp(colors: &Vec<[u8; 3]>, color_number: usize) -> Vec<[u8; 3]> {
let mut rng = thread_rng();
let mut centroids: Vec<[u8; 3]> = Vec::new();
let first_centroid = colors[rng.gen_range(0..colors.len())];
centroids.push(first_centroid);

for _i in 1..color_number {
let distances = colors.iter().map(|x| {
let partial_distances = centroids.iter().map(|y| square_distance(x, y));
partial_distances.fold(f64::INFINITY, |a, b| a.min(b))
});
let total_weight = distances.clone().fold(0.0, |a, b| a + b);
let distances: Vec<_> = distances.collect();
let target = rng.gen::<f64>() * total_weight;
let mut cumulative = 0.0;
for i in 0..colors.len() {
cumulative += distances[i];
if cumulative >= target {
centroids.push(colors[i]);
break;
}
}
}

centroids
}

fn assign_to_centroids_pp(colors: &[[u8; 3]], centroids: Vec<[u8; 3]>) -> Vec<usize> {
let mut assignments: Vec<usize> = Vec::new();
for i in 0..colors.len() {
let mut min_distance = f64::INFINITY;
let mut closest_centroid = 0;
for j in 0..centroids.len() {
let distance = square_distance(colors.get(i).unwrap(), centroids.get(j).unwrap());
if distance < min_distance {
min_distance = distance;
closest_centroid = j;
}
}
assignments.push(closest_centroid);
}

assignments
}

#[wasm_bindgen]
pub fn k_means_pp(colors_r: Vec<u8>, color_number: usize, max_iterations: usize) -> Vec<u8> {
let colors: Vec<[u8; 3]> = colors_r
.chunks_exact(3) // Get chunks of 3 elements
.map(|chunk| {
let mut array: [u8; 3] = [0; 3];
array.copy_from_slice(chunk);
array
})
.collect();


let mut centroids = initialize_centroids_pp(&colors, color_number);

let mut iterations: usize = 0;
let mut previous_assignments;
let mut assignments: Vec<usize> = Vec::new();

loop {
previous_assignments = assignments;
assignments = assign_to_centroids_pp(&colors, centroids);
centroids = calculate_new_centroids(&colors, &assignments, color_number);
iterations += 1;
if iterations > max_iterations || assignments == previous_assignments {
break;
}
}

let serialized_vector: Vec<u8> = centroids
.into_iter()
.flat_map(|array| array.into_iter())
.collect();

serialized_vector
}

172 changes: 172 additions & 0 deletions src/core/editpix_wasm.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
let wasm;

const heap = new Array(128).fill(undefined);

heap.push(undefined, null, true, false);

function getObject(idx) { return heap[idx]; }

let heap_next = heap.length;

function dropObject(idx) {
if (idx < 132) return;
heap[idx] = heap_next;
heap_next = idx;
}

function takeObject(idx) {
const ret = getObject(idx);
dropObject(idx);
return ret;
}

const cachedTextDecoder = (typeof TextDecoder !== 'undefined' ? new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }) : { decode: () => { throw Error('TextDecoder not available') } } );

if (typeof TextDecoder !== 'undefined') { cachedTextDecoder.decode(); };
Expand All @@ -18,6 +38,15 @@ function getStringFromWasm0(ptr, len) {
return cachedTextDecoder.decode(getUint8Memory0().subarray(ptr, ptr + len));
}

function addHeapObject(obj) {
if (heap_next === heap.length) heap.push(heap.length + 1);
const idx = heap_next;
heap_next = heap[idx];

heap[idx] = obj;
return idx;
}

let WASM_VECTOR_LEN = 0;

function passArray8ToWasm0(arg, malloc) {
Expand Down Expand Up @@ -62,6 +91,36 @@ export function k_means(colors_r, color_number, max_iterations) {
}
}

/**
* @param {Uint8Array} colors_r
* @param {number} color_number
* @param {number} max_iterations
* @returns {Uint8Array}
*/
export function k_means_pp(colors_r, color_number, max_iterations) {
try {
const retptr = wasm.__wbindgen_add_to_stack_pointer(-16);
const ptr0 = passArray8ToWasm0(colors_r, wasm.__wbindgen_malloc);
const len0 = WASM_VECTOR_LEN;
wasm.k_means_pp(retptr, ptr0, len0, color_number, max_iterations);
var r0 = getInt32Memory0()[retptr / 4 + 0];
var r1 = getInt32Memory0()[retptr / 4 + 1];
var v2 = getArrayU8FromWasm0(r0, r1).slice();
wasm.__wbindgen_free(r0, r1 * 1, 1);
return v2;
} finally {
wasm.__wbindgen_add_to_stack_pointer(16);
}
}

function handleError(f, args) {
try {
return f.apply(this, args);
} catch (e) {
wasm.__wbindgen_exn_store(addHeapObject(e));
}
}

async function __wbg_load(module, imports) {
if (typeof Response === 'function' && module instanceof Response) {
if (typeof WebAssembly.instantiateStreaming === 'function') {
Expand Down Expand Up @@ -96,9 +155,122 @@ async function __wbg_load(module, imports) {
function __wbg_get_imports() {
const imports = {};
imports.wbg = {};
imports.wbg.__wbg_crypto_d05b68a3572bb8ca = function(arg0) {
const ret = getObject(arg0).crypto;
return addHeapObject(ret);
};
imports.wbg.__wbindgen_is_object = function(arg0) {
const val = getObject(arg0);
const ret = typeof(val) === 'object' && val !== null;
return ret;
};
imports.wbg.__wbg_process_b02b3570280d0366 = function(arg0) {
const ret = getObject(arg0).process;
return addHeapObject(ret);
};
imports.wbg.__wbg_versions_c1cb42213cedf0f5 = function(arg0) {
const ret = getObject(arg0).versions;
return addHeapObject(ret);
};
imports.wbg.__wbg_node_43b1089f407e4ec2 = function(arg0) {
const ret = getObject(arg0).node;
return addHeapObject(ret);
};
imports.wbg.__wbindgen_is_string = function(arg0) {
const ret = typeof(getObject(arg0)) === 'string';
return ret;
};
imports.wbg.__wbindgen_object_drop_ref = function(arg0) {
takeObject(arg0);
};
imports.wbg.__wbg_msCrypto_10fc94afee92bd76 = function(arg0) {
const ret = getObject(arg0).msCrypto;
return addHeapObject(ret);
};
imports.wbg.__wbg_require_9a7e0f667ead4995 = function() { return handleError(function () {
const ret = module.require;
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbindgen_is_function = function(arg0) {
const ret = typeof(getObject(arg0)) === 'function';
return ret;
};
imports.wbg.__wbindgen_string_new = function(arg0, arg1) {
const ret = getStringFromWasm0(arg0, arg1);
return addHeapObject(ret);
};
imports.wbg.__wbg_randomFillSync_b70ccbdf4926a99d = function() { return handleError(function (arg0, arg1) {
getObject(arg0).randomFillSync(takeObject(arg1));
}, arguments) };
imports.wbg.__wbg_getRandomValues_7e42b4fb8779dc6d = function() { return handleError(function (arg0, arg1) {
getObject(arg0).getRandomValues(getObject(arg1));
}, arguments) };
imports.wbg.__wbg_newnoargs_e258087cd0daa0ea = function(arg0, arg1) {
const ret = new Function(getStringFromWasm0(arg0, arg1));
return addHeapObject(ret);
};
imports.wbg.__wbg_call_27c0f87801dedf93 = function() { return handleError(function (arg0, arg1) {
const ret = getObject(arg0).call(getObject(arg1));
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbindgen_object_clone_ref = function(arg0) {
const ret = getObject(arg0);
return addHeapObject(ret);
};
imports.wbg.__wbg_self_ce0dbfc45cf2f5be = function() { return handleError(function () {
const ret = self.self;
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbg_window_c6fb939a7f436783 = function() { return handleError(function () {
const ret = window.window;
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbg_globalThis_d1e6af4856ba331b = function() { return handleError(function () {
const ret = globalThis.globalThis;
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbg_global_207b558942527489 = function() { return handleError(function () {
const ret = global.global;
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbindgen_is_undefined = function(arg0) {
const ret = getObject(arg0) === undefined;
return ret;
};
imports.wbg.__wbg_call_b3ca7c6051f9bec1 = function() { return handleError(function (arg0, arg1, arg2) {
const ret = getObject(arg0).call(getObject(arg1), getObject(arg2));
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbg_buffer_12d079cc21e14bdb = function(arg0) {
const ret = getObject(arg0).buffer;
return addHeapObject(ret);
};
imports.wbg.__wbg_newwithbyteoffsetandlength_aa4a17c33a06e5cb = function(arg0, arg1, arg2) {
const ret = new Uint8Array(getObject(arg0), arg1 >>> 0, arg2 >>> 0);
return addHeapObject(ret);
};
imports.wbg.__wbg_new_63b92bc8671ed464 = function(arg0) {
const ret = new Uint8Array(getObject(arg0));
return addHeapObject(ret);
};
imports.wbg.__wbg_set_a47bac70306a19a7 = function(arg0, arg1, arg2) {
getObject(arg0).set(getObject(arg1), arg2 >>> 0);
};
imports.wbg.__wbg_newwithlength_e9b4878cebadb3d3 = function(arg0) {
const ret = new Uint8Array(arg0 >>> 0);
return addHeapObject(ret);
};
imports.wbg.__wbg_subarray_a1f73cd4b5b42fe1 = function(arg0, arg1, arg2) {
const ret = getObject(arg0).subarray(arg1 >>> 0, arg2 >>> 0);
return addHeapObject(ret);
};
imports.wbg.__wbindgen_throw = function(arg0, arg1) {
throw new Error(getStringFromWasm0(arg0, arg1));
};
imports.wbg.__wbindgen_memory = function() {
const ret = wasm.memory;
return addHeapObject(ret);
};

return imports;
}
Expand Down
Binary file modified src/core/editpix_wasm_bg.wasm
Binary file not shown.
Loading

0 comments on commit 0a230d3

Please sign in to comment.