From e691b7d5317b6af77fb159ba8e7ec6542c4b0f4c Mon Sep 17 00:00:00 2001 From: jrivals Date: Mon, 7 Oct 2024 14:01:09 +0200 Subject: [PATCH] Extract report data from the PDF document using AI (WIP) --- server/services/aiService/aiService.ts | 87 +++++++ .../analysisService/analysisService.ts | 227 ++++++++---------- 2 files changed, 191 insertions(+), 123 deletions(-) create mode 100644 server/services/aiService/aiService.ts diff --git a/server/services/aiService/aiService.ts b/server/services/aiService/aiService.ts new file mode 100644 index 00000000..b6859ca6 --- /dev/null +++ b/server/services/aiService/aiService.ts @@ -0,0 +1,87 @@ +import OpenAI from 'openai'; +import { translateToEnglishPrompt } from '../../prompts/translateToEnglish.prompt'; +import config from '../../utils/config'; + +export interface ReferenceEmbedding { + reference: T; + embedding: number[]; +} + +const openai = new OpenAI({ + apiKey: config.apis.openai.apiKey, +}); + +const generateEmbeddings = async (labels: string[]): Promise => { + const response = await openai.embeddings.create({ + model: 'text-embedding-ada-002', + input: labels, + }); + + return response.data.map( + (embedding: { embedding: number[] }) => embedding.embedding + ); +}; + +const cosineSimilarity = (vecA: number[], vecB: number[]): number => { + const dotProduct = vecA.reduce((sum, a, i) => sum + a * vecB[i], 0); + const magnitudeA = Math.sqrt(vecA.reduce((sum, a) => sum + a * a, 0)); + const magnitudeB = Math.sqrt(vecB.reduce((sum, b) => sum + b * b, 0)); + + return dotProduct / (magnitudeA * magnitudeB); +}; + +const normalizeText = (text: string) => { + return text + .normalize('NFD') + .replace(/[\u0300-\u036f]/g, '') // Supprimer les accents + .toLowerCase(); +}; + +const translateToEnglish = async (text: string) => { + const completion = await openai.chat.completions.create({ + model: 'gpt-4o-mini', + messages: translateToEnglishPrompt(text), + }); + + return completion.choices[0].message.content; +}; + +async function resolveReferenceWithEmbeddings( + text: string, + referenceEmbeddings: ReferenceEmbedding[] +): Promise { + console.log('resolveReferenceWithEmbeddings for text', text); + + const normalizedText = normalizeText(text); + const translatedText = await translateToEnglish(normalizedText); + const textEmbeddings = ( + await generateEmbeddings([translatedText as string]) + )[0]; + + let bestMatch: { + referenceEmbedding?: ReferenceEmbedding; + similarity: number; + } = { + similarity: -Infinity, + }; + + for (const referenceEmbedding of referenceEmbeddings) { + const similarity = cosineSimilarity( + textEmbeddings, + referenceEmbedding.embedding + ); + + if (similarity > bestMatch.similarity) { + bestMatch = { + referenceEmbedding, + similarity, + }; + } + } + + console.log('bestMatch', bestMatch.referenceEmbedding); + + return bestMatch.referenceEmbedding?.reference; +} + +export { generateEmbeddings, cosineSimilarity, resolveReferenceWithEmbeddings }; diff --git a/server/services/analysisService/analysisService.ts b/server/services/analysisService/analysisService.ts index c3c61e13..8fb71ef0 100644 --- a/server/services/analysisService/analysisService.ts +++ b/server/services/analysisService/analysisService.ts @@ -2,26 +2,52 @@ import OpenAI from 'openai'; import { zodResponseFormat } from 'openai/helpers/zod'; import z from 'zod'; import { OptionalBoolean } from '../../../shared/referential/OptionnalBoolean'; +import { Analyte } from '../../../shared/referential/Residue/Analyte'; +import { AnalyteLabels } from '../../../shared/referential/Residue/AnalyteLabels'; +import { ComplexResidue } from '../../../shared/referential/Residue/ComplexResidue'; +import { ComplexResidueAnalytes } from '../../../shared/referential/Residue/ComplexResidueAnalytes'; +import { ComplexResidueLabels } from '../../../shared/referential/Residue/ComplexResidueLabels'; import { SimpleResidue } from '../../../shared/referential/Residue/SimpleResidue'; import { SimpleResidueLabels } from '../../../shared/referential/Residue/SimpleResidueLabels'; import { ResidueCompliance } from '../../../shared/schema/Analysis/Residue/ResidueCompliance'; import { ResidueKind } from '../../../shared/schema/Analysis/Residue/ResidueKind'; -import { ResultKind } from '../../../shared/schema/Analysis/Residue/ResultKind'; +import { + ResultKind, + ResultKindLabels, +} from '../../../shared/schema/Analysis/Residue/ResultKind'; import { extractAnalysisFromReportPrompt } from '../../prompts/extractAnalysisFromReport.prompt'; -import { translateToEnglishPrompt } from '../../prompts/translateToEnglish.prompt'; import config from '../../utils/config'; +import { + generateEmbeddings, + ReferenceEmbedding, + resolveReferenceWithEmbeddings, +} from '../aiService/aiService'; import documentService from '../documentService/documentService'; const openai = new OpenAI({ apiKey: config.apis.openai.apiKey, }); +const AnalyteExtraction = z.object({ + analyteNumber: z.number().int(), + label: z.string(), + reference: z.string().nullish(), + resultKind: ResultKind.describe(ResultKindLabels.toString()), + result: z.number().nullish(), +}); + const ResidueExtraction = z.object({ label: z.string(), + reference: z.string().nullish(), residueNumber: z.number().int(), kind: ResidueKind, - resultKind: ResultKind.nullish(), - result: z.number().nullish(), - LMR: z.number().nullish().describe('LMR value. Not the LQ value'), + resultKind: ResultKind.nullish().describe(ResultKindLabels.toString()), + result: z.number().nullish().describe('Result when resultKind is Q'), + LMR: z + .number() + .nullish() + .describe( + 'LMR when resultKind is Q. Not the LQ value. This can be defined by a max value called specifications' + ), resultHigherThanArfd: OptionalBoolean, notesOnResult: z.string().nullish(), substanceApproved: OptionalBoolean, @@ -29,9 +55,15 @@ const ResidueExtraction = z.object({ pollutionRisk: OptionalBoolean.nullish(), notesOnPollutionRisk: z.string().nullish(), compliance: ResidueCompliance, + analytes: z + .array(AnalyteExtraction) + .nullish() + .describe( + 'Only for complex residue. Contains the list of the sub residues.' + ), }); -const AnalyseExtraction = z.object({ +const AnalysisExtraction = z.object({ kind: z.enum(['Mono', 'Multi']).optional(), residues: z .array(ResidueExtraction) @@ -43,137 +75,86 @@ const AnalyseExtraction = z.object({ notesOnCompliance: z.string().nullable().optional(), }); -type AnalyseExtraction = z.infer; - -interface ResidueEmbedding { - code: SimpleResidue; - embedding: number[]; -} - -// const retrieveReferences = async ( -// analysisExtraction: AnalyseExtraction -// ): Promise => { -// const prompt = ` -// Tu es un assistant qui doit enrichir les labels dans un flux JSON avec les codes associés en te basant sur ${Object.entries( -// SimpleResidueLabels -// )}. -// Retrouve la ligne qui a le label le plus proche du texte sans accent traduit en anglais. -// Tu dois retourner le JSON modifié. -// -// Pour {"label": "Chlorméquat (+ sels)"} retourne {"label": "Chlormequat (sum of chlormequat and its salts, expressed as chlormequat-chloride)" ,"code":"RF-00005727-PAR"}. -// -// Réponds uniquement en JSON valide sans les balises \`\`\`json ou \`\`\`. -// `; -// -// const completion = await openai.chat.completions.create({ -// model: 'gpt-4o-mini', -// messages: [ -// { -// role: 'system', -// content: prompt, -// }, -// { role: 'user', content: JSON.stringify(analysisExtraction) }, -// ], -// }); -// -// console.log(analysisExtraction, completion.choices[0].message.content); -// -// return JSON.parse( -// completion.choices[0].message.content ?? '{}' -// ) as AnalyseExtraction; -// }; - -// Fonction pour générer des embeddings pour vos labels -const generateEmbeddings = async (labels: string[]): Promise => { - const response = await openai.embeddings.create({ - model: 'text-embedding-ada-002', - input: labels, - }); +type AnalyseExtraction = z.infer; +type ResidueExtraction = z.infer; +type AnalyteExtraction = z.infer; - return response.data.map( - (embedding: { embedding: number[] }) => embedding.embedding +const retrieveAnalytesReferences = async ( + residue: ResidueExtraction, + reference: ComplexResidue +): Promise => { + const complexResidueAnalyteLabels = ComplexResidueAnalytes[reference].map( + (analyte) => AnalyteLabels[analyte] ); -}; - -const cosineSimilarity = (vecA: number[], vecB: number[]): number => { - const dotProduct = vecA.reduce((sum, a, i) => sum + a * vecB[i], 0); - const magnitudeA = Math.sqrt(vecA.reduce((sum, a) => sum + a * a, 0)); - const magnitudeB = Math.sqrt(vecB.reduce((sum, b) => sum + b * b, 0)); - - return dotProduct / (magnitudeA * magnitudeB); -}; -const resolveReference = async ( - residueLabel: string, - residueEmbeddings: ResidueEmbedding[] -): Promise => { - console.log('resolveReference RESIDUE LABEL', residueLabel); - - const normalizedText = normalizeText(residueLabel); - const translatedText = await translateToEnglish(normalizedText); - const residueEmbedding = ( - await generateEmbeddings([translatedText as string]) - )[0]; - - let bestMatch = { code: '', similarity: -Infinity }; - - for (const reference of residueEmbeddings) { - const similarity = cosineSimilarity(residueEmbedding, reference.embedding); + const analytesEmbeddings: ReferenceEmbedding[] = + await generateEmbeddings(complexResidueAnalyteLabels).then((embeddings) => + embeddings.map((embedding, index) => ({ + reference: ComplexResidueAnalytes[reference][index], + embedding, + })) + ); + + const newAnalytes = await Promise.all( + (residue.analytes ?? []).map(async (analyte) => { + const reference = await resolveReferenceWithEmbeddings( + analyte.label, + analytesEmbeddings + ); - if (similarity > bestMatch.similarity) { - bestMatch = { - code: reference.code, - similarity, + return { + ...analyte, + label: reference ? AnalyteLabels[reference as Analyte] : analyte.label, + reference, }; - } - } - - console.log('BEST MATCH', bestMatch); - - return bestMatch.code as SimpleResidue | undefined; -}; - -const normalizeText = (text: string) => { - return text - .normalize('NFD') - .replace(/[\u0300-\u036f]/g, '') // Supprimer les accents - .toLowerCase(); -}; - -const translateToEnglish = async (text: string) => { - const completion = await openai.chat.completions.create({ - model: 'gpt-4o-mini', - messages: translateToEnglishPrompt(text), - }); + }) + ); - return completion.choices[0].message.content; + return { + ...residue, + analytes: newAnalytes, + }; }; -const retrieveReferencesWithEmbeddings = async ( +const retrieveResiduesReferences = async ( analysisExtraction: AnalyseExtraction ): Promise => { // Génération des embeddings pour les résidus (à faire une seule fois et stocker) - const embeddings = await generateEmbeddings( - Object.values(SimpleResidueLabels) - ); - - const residueEmbeddings: ResidueEmbedding[] = Object.keys( - SimpleResidueLabels - ).map((code, index) => ({ - code: code as SimpleResidue, - embedding: embeddings[index], - })); + const simpleResidueEmbeddings: ReferenceEmbedding[] = + await generateEmbeddings(Object.values(SimpleResidueLabels)).then( + (embeddings) => + embeddings.map((embedding, index) => ({ + reference: Object.keys(SimpleResidueLabels)[index] as SimpleResidue, + embedding, + })) + ); + const complexResidueEmbeddings: ReferenceEmbedding[] = + await generateEmbeddings(Object.values(ComplexResidueLabels)).then( + (embeddings) => + embeddings.map((embedding, index) => ({ + reference: Object.keys(ComplexResidueLabels)[index] as ComplexResidue, + embedding, + })) + ); const newResidues = await Promise.all( (analysisExtraction.residues ?? []).map(async (residue) => { - const reference = await resolveReference( - residue.label, - residueEmbeddings - ); + const reference = await resolveReferenceWithEmbeddings(residue.label, [ + ...simpleResidueEmbeddings, + ...complexResidueEmbeddings, + ]); + + const residueWithAnalyte = ComplexResidue.safeParse(reference).success + ? await retrieveAnalytesReferences(residue, reference as ComplexResidue) + : residue; return { - ...residue, - label: reference ? SimpleResidueLabels[reference] : residue.label, + ...residueWithAnalyte, + label: SimpleResidue.safeParse(reference).success + ? SimpleResidueLabels[reference as SimpleResidue] + : ComplexResidue.safeParse(reference).success + ? ComplexResidueLabels[reference as ComplexResidue] + : (reference as string), reference, }; }) @@ -193,7 +174,7 @@ export const extractFromReport = async ( const completion = await openai.chat.completions.create({ model: 'gpt-4o-mini-2024-07-18', messages: extractAnalysisFromReportPrompt(content), - response_format: zodResponseFormat(AnalyseExtraction, 'analysis'), + response_format: zodResponseFormat(AnalysisExtraction, 'analysis'), }); console.log(completion.choices[0].message.content); @@ -202,5 +183,5 @@ export const extractFromReport = async ( completion.choices[0].message.content ?? '{}' ) as AnalyseExtraction; - return retrieveReferencesWithEmbeddings(analyseExtraction); + return retrieveResiduesReferences(analyseExtraction); };