Create a Skin Condition classification task #678

Merged
Merged 20 commits on May 31, 2024
Changes from all commits (20 commits)
d78e049
Integrate skin disease classification task
JulienVig May 27, 2024
19612e0
Move mobilenet model definitions to discojs/src/models folder
JulienVig May 28, 2024
37a3849
Clean dataset_builder and make dataset shuffling true by default
JulienVig May 28, 2024
8ec3439
Improve skin_mnist model definition
JulienVig May 28, 2024
e6ef9de
Make image data selection mode default depend on the number of labels
JulienVig May 28, 2024
3ba6149
Use pretrained mobilenet for transfer learning
JulienVig May 29, 2024
f08a308
Save transfer learning implementation
JulienVig May 29, 2024
8a720e0
Fix validator getLabel for multi-class tasks
JulienVig May 29, 2024
3d41da3
Make test page display image label names rather than label indices
JulienVig May 30, 2024
351aa65
Update skin mnist model architecture
JulienVig May 30, 2024
c649016
Add a CSV creation example notebook to docs/
JulienVig May 30, 2024
323865e
Link CSV creation notebook into the webapp data selection UI
JulienVig May 30, 2024
06cbf2b
Add link to resulting sample dataset
JulienVig May 30, 2024
039d4f7
Remove mobilenet weights
JulienVig May 30, 2024
8a06fcc
Change skin mnist into skin condition task
JulienVig May 30, 2024
4002797
Add the preprocessing notebook scin_dataset.ipynb to the examples
JulienVig May 30, 2024
d2e68c5
Make webapp GROUP data selection mode default if there are less than …
JulienVig May 30, 2024
76b35fc
Reword data format description
JulienVig May 30, 2024
fda5a4f
Handle label name overflow on the test page
JulienVig May 30, 2024
4c6a137
Link resulting sample dataset
JulienVig May 30, 2024
5 changes: 3 additions & 2 deletions .gitignore
@@ -1,7 +1,8 @@
# dependencies
/node_modules/
# disco models
**/models
# model.json files
server/models
cli/models
# built
dist/

71 changes: 22 additions & 49 deletions discojs/src/dataset/dataset_builder.ts
@@ -13,16 +13,11 @@ export class DatasetBuilder<Source> {
/**
* The buffer of unlabelled file sources.
*/
private _sources: Source[]
private _unlabeledSources: Source[]
/**
* The buffer of labelled file sources.
*/
private labelledSources: Map<string, Source[]>
/**
* Whether a dataset was already produced.
*/
// TODO useless, responsibility on callers
private _built: boolean
private _labeledSources: Map<string, Source[]>

constructor (
/**
@@ -34,9 +29,9 @@ export class DatasetBuilder<Source> {
*/
public readonly task: Task
) {
this._sources = []
this.labelledSources = Map()
this._built = false
this._unlabeledSources = []
// Map from label to sources
this._labeledSources = Map()
}

/**
@@ -46,17 +41,14 @@ export class DatasetBuilder<Source> {
* @param label The file sources label
*/
addFiles (sources: Source[], label?: string): void {
if (this.built) {
this.resetBuiltState()
}
if (label === undefined) {
this._sources = this._sources.concat(sources)
this._unlabeledSources = this._unlabeledSources.concat(sources)
} else {
const currentSources = this.labelledSources.get(label)
const currentSources = this._labeledSources.get(label)
if (currentSources === undefined) {
this.labelledSources = this.labelledSources.set(label, sources)
this._labeledSources = this._labeledSources.set(label, sources)
} else {
this.labelledSources = this.labelledSources.set(label, currentSources.concat(sources))
this._labeledSources = this._labeledSources.set(label, currentSources.concat(sources))
}
}
}
@@ -67,28 +59,19 @@ export class DatasetBuilder<Source> {
* @param label The file sources label
*/
clearFiles (label?: string): void {
if (this.built) {
this.resetBuiltState()
}
if (label === undefined) {
this._sources = []
this._unlabeledSources = []
} else {
this.labelledSources = this.labelledSources.delete(label)
this._labeledSources = this._labeledSources.delete(label)
}
}

// If files are added or removed, then this should be called since the latest
// version of the dataset_builder has not yet been built.
private resetBuiltState (): void {
this._built = false
}

private getLabels (): string[] {
// We need to duplicate the labels as we need one for each source.
// Say for label A we have sources [img1, img2, img3], then we
// need labels [A, A, A].
let labels: string[][] = []
this.labelledSources.forEach((sources, label) => {
this._labeledSources.forEach((sources, label) => {
const sourcesLabels = Array.from({ length: sources.length }, (_) => label)
labels = labels.concat(sourcesLabels)
})
@@ -97,56 +80,46 @@ export class DatasetBuilder<Source> {

async build (config?: DataConfig): Promise<DataSplit> {
// Require that at least one source collection is non-empty, but not both
if ((this._sources.length > 0) === (this.labelledSources.size > 0)) {
throw new Error('Please provide dataset input files') // This error message is parsed in DatasetInput.vue
if (this._unlabeledSources.length + this._labeledSources.size === 0) {
throw new Error('No input files connected') // This error message is parsed in Trainer.vue
}

let dataTuple: DataSplit
if (this._sources.length > 0) {
if (this._unlabeledSources.length > 0) {
let defaultConfig: DataConfig = {}

if (config?.inference === true) {
// Inferring model, no labels needed
defaultConfig = {
features: this.task.trainingInformation.inputColumns,
shuffle: false
shuffle: true
}
} else {
// Labels are contained in the given sources
defaultConfig = {
features: this.task.trainingInformation.inputColumns,
labels: this.task.trainingInformation.outputColumns,
shuffle: false
shuffle: true
}
}

dataTuple = await this.dataLoader.loadAll(this._sources, { ...defaultConfig, ...config })
dataTuple = await this.dataLoader.loadAll(this._unlabeledSources, { ...defaultConfig, ...config })
} else {
// Labels are inferred from the file selection boxes
const defaultConfig = {
labels: this.getLabels(),
shuffle: false
shuffle: true
}
const sources = this.labelledSources.valueSeq().toArray().flat()
const sources = this._labeledSources.valueSeq().toArray().flat()
dataTuple = await this.dataLoader.loadAll(sources, { ...defaultConfig, ...config })
}
// TODO @s314cy: Support .csv labels for image datasets (supervised training or testing)
this._built = true
return dataTuple
}

/**
* Whether the dataset builder has already been consumed to produce a dataset.
*/
get built (): boolean {
return this._built
}

get size (): number {
return Math.max(this._sources.length, this.labelledSources.size)
return Math.max(this._unlabeledSources.length, this._labeledSources.size)
}

get sources (): Source[] {
return this._sources.length > 0 ? this._sources : this.labelledSources.valueSeq().toArray().flat()
return this._unlabeledSources.length > 0 ? this._unlabeledSources : this._labeledSources.valueSeq().toArray().flat()
}
}
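
For context, a minimal sketch of how a caller might drive the reworked builder after this change. It is not part of the diff: the `@epfml/discojs` import path, the exact `data.*` namespacing, the loader argument and all variable names are assumptions; only `addFiles`, `build` and the new shuffle-by-default behaviour come from the code above.

```typescript
import { data, type Task } from '@epfml/discojs' // import path assumed

async function buildLabeledSplit(
  loader: data.DataLoader<File>,      // e.g. the webapp's image loader (assumed type)
  task: Task,                         // e.g. the new skin_condition task
  filesByLabel: Map<string, File[]>   // 'Eczema' -> [file1, file2, ...], etc.
): Promise<data.DataSplit> {
  const builder = new data.DatasetBuilder<File>(loader, task)

  // One addFiles call per label; the builder keeps a label -> sources map internally.
  for (const [label, files] of filesByLabel) {
    builder.addFiles(files, label)
  }

  // Shuffling now defaults to true; pass a config only to override it.
  return await builder.build()
}
```
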
discojs/src/default_tasks/cifar10/index.ts → discojs/src/default_tasks/cifar10.ts
@@ -1,9 +1,9 @@
import * as tf from '@tensorflow/tfjs'

import type { Model, Task, TaskProvider } from '../../index.js'
import { data, models } from '../../index.js'
import type { Model, Task, TaskProvider } from '../index.js'
import { data, models } from '../index.js'

import baseModel from './model.js'
import baseModel from '../models/mobileNet_v1_025_224.js'

export const cifar10: TaskProvider = {
getTask (): Task {
5 changes: 3 additions & 2 deletions discojs/src/default_tasks/index.ts
@@ -1,6 +1,7 @@
export { cifar10 } from './cifar10/index.js'
export { cifar10 } from './cifar10.js'
export { lusCovid } from './lus_covid.js'
export { skinCondition } from './skin_condition.js'
export { mnist } from './mnist.js'
export { simpleFace } from './simple_face/index.js'
export { simpleFace } from './simple_face.js'
export { titanic } from './titanic.js'
export { wikitext } from './wikitext.js'
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/lus_covid.ts
@@ -11,7 +11,7 @@ export const lusCovid: TaskProvider = {
taskTitle: 'COVID Lung Ultrasound',
summary: {
preview: 'Do you have a data of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task.',
overview: "Dont have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
overview: "Don't have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
},
model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
discojs/src/default_tasks/simple_face/index.ts → discojs/src/default_tasks/simple_face.ts
@@ -1,9 +1,8 @@
import * as tf from '@tensorflow/tfjs'

import type { Model, Task, TaskProvider } from '../../index.js'
import { data, models } from '../../index.js'

import baseModel from './model.js'
import type { Model, Task, TaskProvider } from '../index.js'
import { data, models } from '../index.js'
import baseModel from '../models/mobileNetV2_35_alpha_2_classes.js'

export const simpleFace: TaskProvider = {
getTask (): Task {
@@ -16,7 +15,7 @@ export const simpleFace: TaskProvider = {
overview: 'Simple face is a small subset of face_task from Kaggle'
},
dataFormatInformation: '',
dataExampleText: 'Below you find an example',
dataExampleText: 'Below you can find an example',
dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png'
},
trainingInformation: {
95 changes: 95 additions & 0 deletions discojs/src/default_tasks/skin_condition.ts
@@ -0,0 +1,95 @@
import * as tf from '@tensorflow/tfjs'

import type { Model, Task, TaskProvider } from '../index.js'
import { data, models } from '../index.js'

const IMAGE_SIZE = 128
const LABELS = ['Eczema', 'Allergic Contact Dermatitis', 'Urticaria']

export const skinCondition: TaskProvider = {
getTask (): Task {
return {
id: 'skin_condition',
displayInformation: {
taskTitle: 'Skin Condition Classification',
summary: {
preview: "Identify common skin conditions from volunteer image contributions. You can find a sample dataset of 400 images <a class='underline text-primary-dark dark:text-primary-light' href='https://storage.googleapis.com/deai-313515.appspot.com/scin_sample.zip'>here</a> or see the full <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/google-research-datasets/scin/tree/main'>SCIN dataset</a>. You can find how to download and preprocess the dataset <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/epfml/disco/blob/develop/docs/examples/scin_dataset.ipynb'>in this notebook</a>.",
overview: "The <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/google-research-datasets/scin/tree/main'>SCIN (Skin Condition Image Network) open access dataset</a> aims to supplement publicly available dermatology datasets from health system sources with representative images from internet users. To this end, the SCIN dataset was collected from Google Search users in the United States through a voluntary, consented image donation application. The SCIN dataset is intended for health education and research, and to increase the diversity of dermatology images available for public use. The SCIN dataset contains 5,000+ volunteer contributions (10,000+ images) of common dermatology conditions. Contributions include Images, self-reported demographic, history, and symptom information, and self-reported Fitzpatrick skin type (sFST). In addition, dermatologist labels of the skin condition are provided for each contribution. You can find more information on the dataset and classification task <a class='underline text-primary-dark dark:text-primary-light' href='https://arxiv.org/abs/2402.18545'>here</a>."
},
dataFormatInformation: "There are hundreds of skin condition labels in the SCIN dataset. For the sake of simplicity, we only include the 3 most common conditions in the sample dataset: 'Eczema', 'Allergic Contact Dermatitis' and 'Urticaria'. Therefore, each image is expected to be labeled with one of these three categories.",
sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/scin_sample.zip'
},
trainingInformation: {
modelID: 'skin-condition-model',
epochs: 10,
roundDuration: 2,
validationSplit: 0.3,
batchSize: 8,
preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
dataType: 'image',
IMAGE_H: IMAGE_SIZE,
IMAGE_W: IMAGE_SIZE,
LABEL_LIST: LABELS,
scheme: 'federated',
noiseScale: undefined,
clippingRadius: undefined
}
}
},

async getModel(): Promise<Model> {
const imageChannels = 3
const numOutputClasses = LABELS.length

const model = tf.sequential()

model.add(
tf.layers.conv2d({
inputShape: [IMAGE_SIZE, IMAGE_SIZE, imageChannels],
filters: 8,
kernelSize: 3,
strides: 1,
kernelInitializer: 'varianceScaling',
activation: 'relu'
})
)
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2]}))
model.add(tf.layers.dropout({ rate: 0.2 }))

const convFilters = [16, 32, 64, 128]
for (const filters of convFilters) {
model.add(
tf.layers.conv2d({
filters: filters,
kernelSize: 3,
strides: 1,
kernelInitializer: 'varianceScaling',
activation: 'relu'
})
)

model.add(tf.layers.maxPooling2d({ poolSize: [2, 2]}))
model.add(tf.layers.dropout({ rate: 0.2 }))
}

model.add(tf.layers.flatten())
model.add(tf.layers.dense({
units: 64,
kernelInitializer: 'varianceScaling',
activation: 'relu',
}))

model.add(tf.layers.dense({
units: numOutputClasses,
kernelInitializer: 'varianceScaling',
activation: 'softmax'
}))

model.compile({
optimizer: tf.train.adam(),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
})
return Promise.resolve(new models.TFJS(model))
}
}
@@ -1,3 +1,5 @@
// Source: https://storage.googleapis.com/deai-313515.appspot.com/models/mobileNetV2_35_alpha_2_classes/model.json
// This model was converted using the tensorflow.js converter
export default {
format: "layers-model",
generatedBy: "keras v2.6.0",
@@ -1,3 +1,4 @@
// Source: https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json
export default {
modelTopology: {
keras_version: "2.1.4",
21 changes: 10 additions & 11 deletions discojs/src/validation/validator.ts
@@ -21,23 +21,22 @@ export class Validator {
}
}

private async getLabel (ys: tf.Tensor): Promise<Float32Array | Int32Array | Uint8Array> {
switch (ys.shape[1]) {
case 1:
return await ys.greaterEqual(tf.scalar(0.5)).data()
case 2:
return await ys.argMax(1).data()
default:
throw new Error(`unable to reduce tensor of shape: ${ys.shape.toString()}`)
private async getLabel(ys: tf.Tensor): Promise<Float32Array | Int32Array | Uint8Array> {
// Binary classification
if (ys.shape[1] == 1) {
return await ys.greaterEqual(tf.scalar(0.5)).data()
// Multi-class classification
} else {
return await ys.argMax(-1).data()
}
// Multi-label classification is not supported
}

async assess (data: data.Data, useConfusionMatrix: boolean = false): Promise<Array<{ groundTruth: number, pred: number, features: Features }>> {
const batchSize = this.task.trainingInformation?.batchSize
if (batchSize === undefined) {
throw new TypeError('Batch size is undefined')
}

const model = await this.getModel()

let features: Features[] = []
@@ -52,7 +51,7 @@
const xs = e.xs as tf.Tensor
const ys = await this.getLabel(e.ys as tf.Tensor)
const pred = await this.getLabel(await model.predict(xs))

const currentFeatures = await xs.array()
if (Array.isArray(currentFeatures)) {
features = features.concat(currentFeatures)
@@ -70,7 +69,7 @@
throw new Error('Input data is missing a feature or the label')
}
}).toArray()).flat()

this.logger.success(`Obtained validation accuracy of ${this.accuracy}`)
this.logger.success(`Visited ${this.visitedSamples} samples`)

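
A quick illustration of the two branches in the updated `getLabel`, using made-up tensors; the values below are examples only, not outputs of the task model.

```typescript
import * as tf from '@tensorflow/tfjs'

// Binary head, output shape [batch, 1]: threshold the sigmoid output at 0.5.
const binary = tf.tensor2d([[0.9], [0.2]])
console.log(await binary.greaterEqual(tf.scalar(0.5)).data()) // [1, 0]

// Multi-class head, output shape [batch, nClasses]: take the argmax per row.
const multi = tf.tensor2d([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
console.log(await multi.argMax(-1).data()) // [1, 0]
```
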
4 changes: 4 additions & 0 deletions docs/examples/README.md
@@ -43,3 +43,7 @@ You can run the custom task example with:
cd docs/examples
npm run custom_task # compiles TypeScript and runs custom_task.ts
```

### Creating a CSV file to connect a dataset in DISCO

DISCO allows connecting data through a CSV file that maps filenames to labels. The python notebook `dataset_csv_creation.ipynb` shows how to create such a CSV and the `scin_dataset.ipynb` shows how to download the [SCIN dataset](https://github.com/google-research-datasets/scin/tree/main?tab=readme-ov-file) and how to preprocess it in a format accepted by DISCO.
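
For illustration only, such a CSV could look like the snippet below; the exact column names DISCO expects are defined in `dataset_csv_creation.ipynb`, so treat these headers and filenames as assumptions rather than the required format.

```csv
filename,label
img_0001.png,Eczema
img_0002.png,Allergic Contact Dermatitis
img_0003.png,Urticaria
```
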