Skip to content

Commit

Permalink
feat: added combine 2 styles feature
Browse files Browse the repository at this point in the history
  • Loading branch information
sidhant-sriv committed Jul 1, 2024
1 parent a793a38 commit 1b9cf7c
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 5 deletions.
47 changes: 42 additions & 5 deletions client.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ class StyleTransfer {
}

async loadModels() {
//Change models accordingly
this.styleNet = await tf.loadGraphModel('file://./saved_model_style_inception_js/model.json');
this.transformNet = await tf.loadGraphModel('file://./saved_model_transformer_js/model.json');
}
Expand Down Expand Up @@ -44,6 +43,34 @@ class StyleTransfer {
stylized.dispose();
}

async combineStyles(contentPath, style1Path, style2Path, outputPath, styleRatio = 0.5) {
await this.loadModels();

const content = await this.loadImage(contentPath);
const style1 = await this.loadImage(style1Path);
const style2 = await this.loadImage(style2Path);

const stylized = await tf.tidy(() => {
const contentTensor = content.toFloat().div(tf.scalar(255)).expandDims();
const style1Tensor = style1.toFloat().div(tf.scalar(255)).expandDims();
const style2Tensor = style2.toFloat().div(tf.scalar(255)).expandDims();

const bottleneck1 = this.styleNet.predict(style1Tensor);
const bottleneck2 = this.styleNet.predict(style2Tensor);

const combinedBottleneck = tf.tidy(() => {
const scaledBottleneck1 = bottleneck1.mul(tf.scalar(1 - styleRatio));
const scaledBottleneck2 = bottleneck2.mul(tf.scalar(styleRatio));
return scaledBottleneck1.add(scaledBottleneck2);
});

return this.transformNet.predict([contentTensor, combinedBottleneck]).squeeze();
});

await this.saveImage(stylized, outputPath);
stylized.dispose();
}

async loadImage(imagePath) {
const imageBuffer = fs.readFileSync(imagePath);
const tfImage = tf.node.decodeImage(imageBuffer);
Expand All @@ -59,14 +86,24 @@ class StyleTransfer {

async function main() {
const styleTransfer = new StyleTransfer();
//images to be stylized

// Single style transfer
const contentPath = './skull.jpg';
const stylePath = './flowers.jpg';
const stylePath = './paint.jpg';
const outputPath = './stylized_image.jpg';
const styleRatio = 0.8;
const styleRatio = 0.95;

await styleTransfer.stylizeImage(contentPath, stylePath, outputPath, styleRatio);
console.log('Stylized image saved to:', outputPath);
console.log('Single style transfer: Stylized image saved to:', outputPath);

// Combined style transfer
const style1Path = './paint.jpg';
const style2Path = './flowers.jpg';
const combinedOutputPath = './combined_stylized_image.jpg';
const combinedStyleRatio = 0.5; // Equal mix of both styles

await styleTransfer.combineStyles(contentPath, style1Path, style2Path, combinedOutputPath, combinedStyleRatio);
console.log('Combined style transfer: Stylized image saved to:', combinedOutputPath);
}

main().catch(console.error);
Binary file added combined_stylized_image.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added paint.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified stylized_image.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 1b9cf7c

Please sign in to comment.