Skip to content

Commit

Permalink
Extract report data from the PDF document using AI (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrivals committed Oct 7, 2024
1 parent 4c07c7d commit e691b7d
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 123 deletions.
87 changes: 87 additions & 0 deletions server/services/aiService/aiService.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import OpenAI from 'openai';
import { translateToEnglishPrompt } from '../../prompts/translateToEnglish.prompt';
import config from '../../utils/config';

export interface ReferenceEmbedding<T> {
reference: T;
embedding: number[];
}

const openai = new OpenAI({
apiKey: config.apis.openai.apiKey,
});

const generateEmbeddings = async (labels: string[]): Promise<number[][]> => {
const response = await openai.embeddings.create({
model: 'text-embedding-ada-002',
input: labels,
});

return response.data.map(
(embedding: { embedding: number[] }) => embedding.embedding
);
};

const cosineSimilarity = (vecA: number[], vecB: number[]): number => {
const dotProduct = vecA.reduce((sum, a, i) => sum + a * vecB[i], 0);
const magnitudeA = Math.sqrt(vecA.reduce((sum, a) => sum + a * a, 0));
const magnitudeB = Math.sqrt(vecB.reduce((sum, b) => sum + b * b, 0));

return dotProduct / (magnitudeA * magnitudeB);
};

const normalizeText = (text: string) => {
return text
.normalize('NFD')
.replace(/[\u0300-\u036f]/g, '') // Supprimer les accents
.toLowerCase();
};

const translateToEnglish = async (text: string) => {
const completion = await openai.chat.completions.create({
model: 'gpt-4o-mini',
messages: translateToEnglishPrompt(text),
});

return completion.choices[0].message.content;
};

async function resolveReferenceWithEmbeddings<T>(
text: string,
referenceEmbeddings: ReferenceEmbedding<T>[]
): Promise<T | undefined> {
console.log('resolveReferenceWithEmbeddings for text', text);

const normalizedText = normalizeText(text);
const translatedText = await translateToEnglish(normalizedText);
const textEmbeddings = (
await generateEmbeddings([translatedText as string])
)[0];

let bestMatch: {
referenceEmbedding?: ReferenceEmbedding<T>;
similarity: number;
} = {
similarity: -Infinity,
};

for (const referenceEmbedding of referenceEmbeddings) {
const similarity = cosineSimilarity(
textEmbeddings,
referenceEmbedding.embedding
);

if (similarity > bestMatch.similarity) {
bestMatch = {
referenceEmbedding,
similarity,
};
}
}

console.log('bestMatch', bestMatch.referenceEmbedding);

return bestMatch.referenceEmbedding?.reference;
}

export { generateEmbeddings, cosineSimilarity, resolveReferenceWithEmbeddings };
227 changes: 104 additions & 123 deletions server/services/analysisService/analysisService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,68 @@ import OpenAI from 'openai';
import { zodResponseFormat } from 'openai/helpers/zod';
import z from 'zod';
import { OptionalBoolean } from '../../../shared/referential/OptionnalBoolean';
import { Analyte } from '../../../shared/referential/Residue/Analyte';
import { AnalyteLabels } from '../../../shared/referential/Residue/AnalyteLabels';
import { ComplexResidue } from '../../../shared/referential/Residue/ComplexResidue';
import { ComplexResidueAnalytes } from '../../../shared/referential/Residue/ComplexResidueAnalytes';
import { ComplexResidueLabels } from '../../../shared/referential/Residue/ComplexResidueLabels';
import { SimpleResidue } from '../../../shared/referential/Residue/SimpleResidue';
import { SimpleResidueLabels } from '../../../shared/referential/Residue/SimpleResidueLabels';
import { ResidueCompliance } from '../../../shared/schema/Analysis/Residue/ResidueCompliance';
import { ResidueKind } from '../../../shared/schema/Analysis/Residue/ResidueKind';
import { ResultKind } from '../../../shared/schema/Analysis/Residue/ResultKind';
import {
ResultKind,
ResultKindLabels,
} from '../../../shared/schema/Analysis/Residue/ResultKind';
import { extractAnalysisFromReportPrompt } from '../../prompts/extractAnalysisFromReport.prompt';
import { translateToEnglishPrompt } from '../../prompts/translateToEnglish.prompt';
import config from '../../utils/config';
import {
generateEmbeddings,
ReferenceEmbedding,
resolveReferenceWithEmbeddings,
} from '../aiService/aiService';
import documentService from '../documentService/documentService';
const openai = new OpenAI({
apiKey: config.apis.openai.apiKey,
});

const AnalyteExtraction = z.object({
analyteNumber: z.number().int(),
label: z.string(),
reference: z.string().nullish(),
resultKind: ResultKind.describe(ResultKindLabels.toString()),
result: z.number().nullish(),
});

const ResidueExtraction = z.object({
label: z.string(),
reference: z.string().nullish(),
residueNumber: z.number().int(),
kind: ResidueKind,
resultKind: ResultKind.nullish(),
result: z.number().nullish(),
LMR: z.number().nullish().describe('LMR value. Not the LQ value'),
resultKind: ResultKind.nullish().describe(ResultKindLabels.toString()),
result: z.number().nullish().describe('Result when resultKind is Q'),
LMR: z
.number()
.nullish()
.describe(
'LMR when resultKind is Q. Not the LQ value. This can be defined by a max value called specifications'
),
resultHigherThanArfd: OptionalBoolean,
notesOnResult: z.string().nullish(),
substanceApproved: OptionalBoolean,
substanceAuthorised: OptionalBoolean,
pollutionRisk: OptionalBoolean.nullish(),
notesOnPollutionRisk: z.string().nullish(),
compliance: ResidueCompliance,
analytes: z
.array(AnalyteExtraction)
.nullish()
.describe(
'Only for complex residue. Contains the list of the sub residues.'
),
});

const AnalyseExtraction = z.object({
const AnalysisExtraction = z.object({
kind: z.enum(['Mono', 'Multi']).optional(),
residues: z
.array(ResidueExtraction)
Expand All @@ -43,137 +75,86 @@ const AnalyseExtraction = z.object({
notesOnCompliance: z.string().nullable().optional(),
});

type AnalyseExtraction = z.infer<typeof AnalyseExtraction>;

interface ResidueEmbedding {
code: SimpleResidue;
embedding: number[];
}

// const retrieveReferences = async (
// analysisExtraction: AnalyseExtraction
// ): Promise<AnalyseExtraction> => {
// const prompt = `
// Tu es un assistant qui doit enrichir les labels dans un flux JSON avec les codes associés en te basant sur ${Object.entries(
// SimpleResidueLabels
// )}.
// Retrouve la ligne qui a le label le plus proche du texte sans accent traduit en anglais.
// Tu dois retourner le JSON modifié.
//
// Pour {"label": "Chlorméquat (+ sels)"} retourne {"label": "Chlormequat (sum of chlormequat and its salts, expressed as chlormequat-chloride)" ,"code":"RF-00005727-PAR"}.
//
// Réponds uniquement en JSON valide sans les balises \`\`\`json ou \`\`\`.
// `;
//
// const completion = await openai.chat.completions.create({
// model: 'gpt-4o-mini',
// messages: [
// {
// role: 'system',
// content: prompt,
// },
// { role: 'user', content: JSON.stringify(analysisExtraction) },
// ],
// });
//
// console.log(analysisExtraction, completion.choices[0].message.content);
//
// return JSON.parse(
// completion.choices[0].message.content ?? '{}'
// ) as AnalyseExtraction;
// };

// Fonction pour générer des embeddings pour vos labels
const generateEmbeddings = async (labels: string[]): Promise<number[][]> => {
const response = await openai.embeddings.create({
model: 'text-embedding-ada-002',
input: labels,
});
type AnalyseExtraction = z.infer<typeof AnalysisExtraction>;
type ResidueExtraction = z.infer<typeof ResidueExtraction>;
type AnalyteExtraction = z.infer<typeof AnalyteExtraction>;

return response.data.map(
(embedding: { embedding: number[] }) => embedding.embedding
const retrieveAnalytesReferences = async (
residue: ResidueExtraction,
reference: ComplexResidue
): Promise<ResidueExtraction> => {
const complexResidueAnalyteLabels = ComplexResidueAnalytes[reference].map(
(analyte) => AnalyteLabels[analyte]
);
};

const cosineSimilarity = (vecA: number[], vecB: number[]): number => {
const dotProduct = vecA.reduce((sum, a, i) => sum + a * vecB[i], 0);
const magnitudeA = Math.sqrt(vecA.reduce((sum, a) => sum + a * a, 0));
const magnitudeB = Math.sqrt(vecB.reduce((sum, b) => sum + b * b, 0));

return dotProduct / (magnitudeA * magnitudeB);
};

const resolveReference = async (
residueLabel: string,
residueEmbeddings: ResidueEmbedding[]
): Promise<SimpleResidue | undefined> => {
console.log('resolveReference RESIDUE LABEL', residueLabel);

const normalizedText = normalizeText(residueLabel);
const translatedText = await translateToEnglish(normalizedText);
const residueEmbedding = (
await generateEmbeddings([translatedText as string])
)[0];

let bestMatch = { code: '', similarity: -Infinity };

for (const reference of residueEmbeddings) {
const similarity = cosineSimilarity(residueEmbedding, reference.embedding);
const analytesEmbeddings: ReferenceEmbedding<Analyte>[] =
await generateEmbeddings(complexResidueAnalyteLabels).then((embeddings) =>
embeddings.map((embedding, index) => ({
reference: ComplexResidueAnalytes[reference][index],
embedding,
}))
);

const newAnalytes = await Promise.all(
(residue.analytes ?? []).map(async (analyte) => {
const reference = await resolveReferenceWithEmbeddings<Analyte>(
analyte.label,
analytesEmbeddings
);

if (similarity > bestMatch.similarity) {
bestMatch = {
code: reference.code,
similarity,
return {
...analyte,
label: reference ? AnalyteLabels[reference as Analyte] : analyte.label,
reference,
};
}
}

console.log('BEST MATCH', bestMatch);

return bestMatch.code as SimpleResidue | undefined;
};

const normalizeText = (text: string) => {
return text
.normalize('NFD')
.replace(/[\u0300-\u036f]/g, '') // Supprimer les accents
.toLowerCase();
};

const translateToEnglish = async (text: string) => {
const completion = await openai.chat.completions.create({
model: 'gpt-4o-mini',
messages: translateToEnglishPrompt(text),
});
})
);

return completion.choices[0].message.content;
return {
...residue,
analytes: newAnalytes,
};
};

const retrieveReferencesWithEmbeddings = async (
const retrieveResiduesReferences = async (
analysisExtraction: AnalyseExtraction
): Promise<AnalyseExtraction> => {
// Génération des embeddings pour les résidus (à faire une seule fois et stocker)
const embeddings = await generateEmbeddings(
Object.values(SimpleResidueLabels)
);

const residueEmbeddings: ResidueEmbedding[] = Object.keys(
SimpleResidueLabels
).map((code, index) => ({
code: code as SimpleResidue,
embedding: embeddings[index],
}));
const simpleResidueEmbeddings: ReferenceEmbedding<SimpleResidue>[] =
await generateEmbeddings(Object.values(SimpleResidueLabels)).then(
(embeddings) =>
embeddings.map((embedding, index) => ({
reference: Object.keys(SimpleResidueLabels)[index] as SimpleResidue,
embedding,
}))
);
const complexResidueEmbeddings: ReferenceEmbedding<ComplexResidue>[] =
await generateEmbeddings(Object.values(ComplexResidueLabels)).then(
(embeddings) =>
embeddings.map((embedding, index) => ({
reference: Object.keys(ComplexResidueLabels)[index] as ComplexResidue,
embedding,
}))
);

const newResidues = await Promise.all(
(analysisExtraction.residues ?? []).map(async (residue) => {
const reference = await resolveReference(
residue.label,
residueEmbeddings
);
const reference = await resolveReferenceWithEmbeddings(residue.label, [
...simpleResidueEmbeddings,
...complexResidueEmbeddings,
]);

const residueWithAnalyte = ComplexResidue.safeParse(reference).success
? await retrieveAnalytesReferences(residue, reference as ComplexResidue)
: residue;

return {
...residue,
label: reference ? SimpleResidueLabels[reference] : residue.label,
...residueWithAnalyte,
label: SimpleResidue.safeParse(reference).success
? SimpleResidueLabels[reference as SimpleResidue]
: ComplexResidue.safeParse(reference).success
? ComplexResidueLabels[reference as ComplexResidue]
: (reference as string),
reference,
};
})
Expand All @@ -193,7 +174,7 @@ export const extractFromReport = async (
const completion = await openai.chat.completions.create({
model: 'gpt-4o-mini-2024-07-18',
messages: extractAnalysisFromReportPrompt(content),
response_format: zodResponseFormat(AnalyseExtraction, 'analysis'),
response_format: zodResponseFormat(AnalysisExtraction, 'analysis'),
});

console.log(completion.choices[0].message.content);
Expand All @@ -202,5 +183,5 @@ export const extractFromReport = async (
completion.choices[0].message.content ?? '{}'
) as AnalyseExtraction;

return retrieveReferencesWithEmbeddings(analyseExtraction);
return retrieveResiduesReferences(analyseExtraction);
};

0 comments on commit e691b7d

Please sign in to comment.