Skip to content

Commit

Permalink
update: implemented median cut option for wasm palette extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
f-aguzzi committed Mar 15, 2024
1 parent a70ba5c commit d208222
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 2 deletions.
2 changes: 1 addition & 1 deletion demo/scripts/color_palette_wasm_demo.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ container.classList.add("container")
image.onload = () => {
console.log("Image loaded.");

editpix.getColorPaletteWasm(image, 15, 2)
editpix.getColorPaletteWasm(image, 15, 2, "median cut")
.then(colorPalette => {
console.log(colorPalette)
displayPalette(colorPalette);
Expand Down
99 changes: 99 additions & 0 deletions lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,102 @@ pub fn k_means_pp(colors_r: Vec<u8>, color_number: usize, max_iterations: usize)
serialized_vector
}




// Sort an RGB pixel array by its channel with the highest variance
fn sort_pixels(pixels: &Vec<[u8; 3]>, channel: usize) -> Vec<[u8; 3]> {
let mut px = pixels.to_owned();
px.sort_by(|a, b| a[channel].cmp(&b[channel]));
px
}

// Find the color channel with the highest variance
fn find_max_channel(pixels: &Vec<[u8; 3]>) -> usize {
let mut min = [255, 255, 255];
let mut max = [0, 0, 0];
for (_i, pixel) in pixels.iter().enumerate() {
for j in 0..3 {
if pixel[j] < min[j] {
min[j] = pixel[j];
}
if pixel[j] > max[j] {
max[j] = pixel[j];
}
}
}
let mut range = [0, 0, 0];
for j in 0..3 {
range[j] = max[j] - min[j];
}
let max_channel: usize = (0..3).into_iter().max().unwrap();
max_channel
}

// Find the average color of an RGB pixel array
fn find_average_color(pixels: Vec<[u8; 3]>) -> [u8; 3] {
let mut sum: [f32; 3] = [0.0, 0.0, 0.0];
for pixel in &pixels {
for j in 0..3 {
sum[j] += pixel[j] as f32;
}
}
let avg: [u8; 3] = [(sum[0] / pixels.len() as f32) as u8, (sum[1] / pixels.len() as f32) as u8, (sum[2] / pixels.len() as f32) as u8];
avg
}

// Apply the median cut algorithm to an RGB pixel array and return a downsized color palette
#[wasm_bindgen]
pub fn median_cut(pixels_r: Vec<u8>, palette_size: usize) -> Vec<u8> {
// Turn the linear array into an array of RGB arrays
let pixels: Vec<[u8; 3]> = pixels_r
.chunks_exact(3) // Get chunks of 3 elements
.map(|chunk| {
let mut array: [u8; 3] = [0; 3];
array.copy_from_slice(chunk);
array
})
.collect();

// Initialize a queue of regions with all pixels
let mut queue = vec![pixels];
// Repeat the following loop until the queue reaches the correct size
while queue.len() < palette_size {
// Extract the region with the most pixels from the queue
let mut max_index = 0;
let mut max_size = 0;
for (i, region) in queue.iter().enumerate() {
if region.len() > max_size {
max_size = region.len();
max_index = i;
}
}
let region = queue.remove(max_index);
// Find the channel with the highest variance within the region
let channel = find_max_channel(&region);
// Sort the pixels in the region by that channel
let sorted: Vec<[u8; 3]> = sort_pixels(&region, channel).iter().cloned().collect();
// Find the average and bisect the region
let median = sorted.len() / 2;
let left = &sorted[..median];
let right = &sorted[median..];
// Add the two regions to the queue
queue.push(left.to_vec());
queue.push(right.to_vec());
}
// Compute the average color of each region and return the palette
let mut palette: Vec<[u8; 3]> = Vec::new();
for region in queue {
let color = find_average_color(region);
palette.push(color);
}

// Serialize the array of arrays to get a linear array
let serialized_vector: Vec<u8> = palette
.into_iter()
.flat_map(|array| array.into_iter())
.collect();

serialized_vector
}

21 changes: 21 additions & 0 deletions src/core/editpix_wasm.js
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,27 @@ export function k_means_pp(colors_r, color_number, max_iterations) {
}
}

/**
* @param {Uint8Array} pixels_r
* @param {number} palette_size
* @returns {Uint8Array}
*/
export function median_cut(pixels_r, palette_size) {
try {
const retptr = wasm.__wbindgen_add_to_stack_pointer(-16);
const ptr0 = passArray8ToWasm0(pixels_r, wasm.__wbindgen_malloc);
const len0 = WASM_VECTOR_LEN;
wasm.median_cut(retptr, ptr0, len0, palette_size);
var r0 = getInt32Memory0()[retptr / 4 + 0];
var r1 = getInt32Memory0()[retptr / 4 + 1];
var v2 = getArrayU8FromWasm0(r0, r1).slice();
wasm.__wbindgen_free(r0, r1 * 1, 1);
return v2;
} finally {
wasm.__wbindgen_add_to_stack_pointer(16);
}
}

function handleError(f, args) {
try {
return f.apply(this, args);
Expand Down
Binary file modified src/core/editpix_wasm_bg.wasm
Binary file not shown.
6 changes: 5 additions & 1 deletion src/editpix.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import convertToBW from "./core/black_and_white.js";
import kMeans from "./core/kmean.js";
import imageManager from "./image_manager.js";
import higherColorContrast from "./core/higher_contrast.js";
import init, { k_means, k_means_pp } from "./core/editpix_wasm.js"
import init, { k_means, k_means_pp, median_cut } from "./core/editpix_wasm.js"
import optimizeContrast from "./core/optimize_contrast.js";
import changeContrast from "./core/change_contrast.js";
import changeTemperature from "./core/change_temperature.js";
Expand Down Expand Up @@ -41,6 +41,10 @@ EditPix.prototype.getColorPaletteWasm = async (image, colorNumber = 5, quality =
init().then(() => {
resolve(utils.deserializeArray(k_means_pp(pixelArray, colorNumber, 100)));
})
} else if (algorithm === "median cut") {
init().then(() => {
resolve(utils.deserializeArray(median_cut(pixelArray, colorNumber)));
})
} else {
throw new Error("Non-existent algorithm.");
}
Expand Down

0 comments on commit d208222

Please sign in to comment.