Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 批量关联问题 #1235

Merged
merged 2 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions apps/common/event/listener_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ def is_save_function():
QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
max_kb.info(f'结束--->向量化段落:{paragraph_id}')

@staticmethod
def embedding_by_data_list(data_list: List, embedding_model: Embeddings):
    """
    Vectorize a batch of records in one call.

    :param data_list:       records to embed; passed unchanged to the vector store's ``batch_save``
    :param embedding_model: embedding model handed to the vector store
    """
    vector_store = VectorStore.get_embedding_vector()
    # The save predicate always answers True, so the caller never aborts the batch.
    vector_store.batch_save(data_list, embedding_model, lambda: True)

@staticmethod
def embedding_by_document(document_id, embedding_model: Embeddings):
"""
Expand Down
74 changes: 73 additions & 1 deletion apps/dataset/serializers/problem_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""
import os
import uuid
from functools import reduce
from typing import Dict, List

from django.db import transaction
Expand All @@ -21,7 +22,8 @@
from common.util.file_util import get_file_content
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
from embedding.task import delete_embedding_by_source_ids, update_problem_embedding
from embedding.models import SourceType
from embedding.task import delete_embedding_by_source_ids, update_problem_embedding, embedding_by_data_list
from smartdoc.conf import PROJECT_DIR


Expand Down Expand Up @@ -50,6 +52,35 @@ def get_request_body_api():
})


class AssociationParagraph(serializers.Serializer):
    # One entry of the batch-association request: a paragraph id paired with a document id.
    # Both values are mandatory and must be valid UUIDs.
    paragraph_id = serializers.UUIDField(error_messages=ErrMessage.uuid("段落id"), required=True)
    document_id = serializers.UUIDField(error_messages=ErrMessage.uuid("文档id"), required=True)


class BatchAssociation(serializers.Serializer):
    # Request body of the batch-association endpoint.
    # problem_id_list: mandatory list of problem UUIDs.
    problem_id_list = serializers.ListField(
        child=serializers.UUIDField(error_messages=ErrMessage.uuid("问题id"), required=True),
        error_messages=ErrMessage.list("问题id列表"),
        required=True,
    )
    # paragraph_list: list of AssociationParagraph entries (paragraph id + document id).
    paragraph_list = AssociationParagraph(many=True)


def _normalize_id(value):
    """Return *value* as a canonical UUID string, or ``str(value)`` when it is not a UUID."""
    if isinstance(value, uuid.UUID):
        return str(value)
    text = str(value)
    try:
        # Canonicalises upper-case and un-hyphenated spellings of the same UUID.
        return str(uuid.UUID(text))
    except ValueError:
        return text


def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping):
    """
    Tell whether an equivalent problem/paragraph mapping already exists.

    Two mappings are equivalent when their ``paragraph_id``, ``problem_id`` and
    ``dataset_id`` identify the same values. Ids are normalised on BOTH sides before
    comparing, so a ``uuid.UUID`` object and its string form (in any valid spelling)
    are treated as equal.

    :param exits_problem_paragraph_mapping_list: iterable of already stored mappings
    :param new_paragraph_mapping:                candidate mapping about to be created
    :return: True if at least one stored mapping matches the candidate, else False
    """
    fields = ('paragraph_id', 'problem_id', 'dataset_id')
    new_key = tuple(_normalize_id(getattr(new_paragraph_mapping, field)) for field in fields)
    # any() stops at the first match instead of materialising every match.
    return any(
        tuple(_normalize_id(getattr(existing, field)) for field in fields) == new_key
        for existing in exits_problem_paragraph_mapping_list
    )


def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, dataset_id: str):
    """
    Build a ``ProblemParagraphMapping`` instance linking *problem* to a paragraph.

    The instance is only constructed here, not saved.

    :param problem:      problem object; only its ``id`` is read
    :param document_id:  id stored on the mapping as ``document_id``
    :param paragraph_id: id stored on the mapping as ``paragraph_id``
    :param dataset_id:   id stored on the mapping as ``dataset_id``
    :return: tuple ``(mapping, problem)`` so callers keep the problem next to its mapping
    """
    mapping = ProblemParagraphMapping(
        id=uuid.uuid1(),
        document_id=document_id,
        paragraph_id=paragraph_id,
        dataset_id=dataset_id,
        problem_id=str(problem.id),
    )
    return mapping, problem


class ProblemSerializers(ApiMixin, serializers.Serializer):
class Create(serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
Expand Down Expand Up @@ -114,6 +145,47 @@ def delete(self, problem_id_list: List, with_valid=True):
delete_embedding_by_source_ids(source_ids)
return True

def association(self, instance: Dict, with_valid=True):
    """
    Associate every given problem with every given paragraph and embed the new links.

    A mapping is created for each (problem, paragraph) combination that is not already
    stored and has not already been produced earlier in the same request. Each new
    mapping is then vectorized using the problem's content.

    :param instance:   request body with ``problem_id_list`` and ``paragraph_list``
                       (entries carrying ``paragraph_id`` and ``document_id``)
    :param with_valid: validate this serializer and the request body first
    :return: True, matching the sibling batch operations
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        BatchAssociation(data=instance).is_valid(raise_exception=True)
    dataset_id = self.data.get('dataset_id')
    # `or []` keeps the unvalidated path (with_valid=False) from iterating None.
    paragraph_list = instance.get('paragraph_list') or []
    problem_id_list = instance.get('problem_id_list') or []
    # NOTE(review): problems are looked up by id only; confirm whether they should also
    # be restricted to `dataset_id`, since the permission check covers that dataset only.
    problem_list = QuerySet(Problem).filter(id__in=problem_id_list)
    existing_mappings = QuerySet(ProblemParagraphMapping).filter(
        problem_id__in=problem_id_list,
        paragraph_id__in=[paragraph.get('paragraph_id') for paragraph in paragraph_list])
    # One set lookup per candidate instead of rescanning the stored mappings each time.
    seen_keys = {(str(mapping.problem_id), str(mapping.paragraph_id), str(mapping.dataset_id))
                 for mapping in existing_mappings}
    problem_paragraph_mapping_list = []
    for problem in problem_list:
        for paragraph in paragraph_list:
            key = (str(problem.id), str(paragraph.get('paragraph_id')), str(dataset_id))
            if key in seen_keys:
                continue
            # Remember the key so a paragraph repeated in the request is linked only once.
            seen_keys.add(key)
            problem_paragraph_mapping_list.append(
                to_problem_paragraph_mapping(problem,
                                             paragraph.get('document_id'),
                                             paragraph.get('paragraph_id'),
                                             dataset_id))
    if not problem_paragraph_mapping_list:
        # Nothing new to store or embed.
        return True
    QuerySet(ProblemParagraphMapping).bulk_create(
        [problem_paragraph_mapping for problem_paragraph_mapping, _ in problem_paragraph_mapping_list])
    data_list = [{'text': problem.content,
                  'is_active': True,
                  'source_type': SourceType.PROBLEM,
                  'source_id': str(problem_paragraph_mapping.id),
                  'document_id': str(problem_paragraph_mapping.document_id),
                  'paragraph_id': str(problem_paragraph_mapping.paragraph_id),
                  'dataset_id': dataset_id,
                  } for problem_paragraph_mapping, problem in problem_paragraph_mapping_list]
    model_id = get_embedding_model_id_by_dataset_id(dataset_id)
    embedding_by_data_list(data_list, model_id=model_id)
    return True

class Operate(serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))

Expand Down
30 changes: 30 additions & 0 deletions apps/dataset/swagger_api/problem_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,36 @@ def get_response_body_api():
}
)

class BatchAssociation(ApiMixin):
    """OpenAPI description of the batch problem/paragraph association request."""

    @staticmethod
    def get_request_params_api():
        """Reuse the request parameters declared by ``ProblemApi.BatchOperate``."""
        return ProblemApi.BatchOperate.get_request_params_api()

    @staticmethod
    def get_request_body_api():
        """
        Describe the request body.

        Both lists are marked required so the schema agrees with the
        ``BatchAssociation`` serializer, which rejects a body without ``paragraph_list``.
        """
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['problem_id_list', 'paragraph_list'],
            properties={
                'problem_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="问题id列表",
                                                  description="问题id列表",
                                                  items=openapi.Schema(type=openapi.TYPE_STRING)),
                'paragraph_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="关联段落信息列表",
                                                 description="关联段落信息列表",
                                                 items=openapi.Schema(type=openapi.TYPE_OBJECT,
                                                                      required=['paragraph_id', 'document_id'],
                                                                      properties={
                                                                          'paragraph_id': openapi.Schema(
                                                                              type=openapi.TYPE_STRING,
                                                                              title="段落id"),
                                                                          'document_id': openapi.Schema(
                                                                              type=openapi.TYPE_STRING,
                                                                              title="文档id")
                                                                      }))

            }
        )

class BatchOperate(ApiMixin):
@staticmethod
def get_request_params_api():
Expand Down
14 changes: 14 additions & 0 deletions apps/dataset/views/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,20 @@ def delete(self, request: Request, dataset_id: str):
return result.success(
ProblemSerializers.BatchOperate(data={'dataset_id': dataset_id}).delete(request.data))

@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="批量关联段落",
                     operation_id="批量关联段落",
                     request_body=ProblemApi.BatchAssociation.get_request_body_api(),
                     manual_parameters=ProblemApi.BatchOperate.get_request_params_api(),
                     responses=result.get_default_response(),
                     tags=["知识库/文档/段落/问题"])
@has_permissions(
    lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                            dynamic_tag=k.get('dataset_id')))
def post(self, request: Request, dataset_id: str):
    # Batch-associate problems with paragraphs for the dataset named in the URL.
    # The request body is passed as-is; validation happens inside `association`.
    serializer = ProblemSerializers.BatchOperate(data={'dataset_id': dataset_id})
    outcome = serializer.association(request.data)
    return result.success(outcome)

class Operate(APIView):
authentication_classes = [TokenAuth]

Expand Down
5 changes: 5 additions & 0 deletions apps/embedding/task/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def embedding_by_problem(args, model_id):
ListenerManagement.embedding_by_problem(args, embedding_model)


def embedding_by_data_list(args: List, model_id):
    """
    Vectorize a list of records with the model identified by *model_id*.

    :param args:     records to embed, forwarded unchanged to ``ListenerManagement``
    :param model_id: id used to look up the embedding model
    """
    ListenerManagement.embedding_by_data_list(args, get_embedding_model(model_id))


def delete_embedding_by_document(document_id):
"""
删除指定文档id的向量
Expand Down
36 changes: 15 additions & 21 deletions apps/embedding/vector/base_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,27 +87,21 @@ def save(self, text, source_type: SourceType, dataset_id: str, document_id: str,
self._batch_save(child_array, embedding, lambda: True)

def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function):
    """
    Insert records in batches.

    :param data_list:        records to store
    :param embedding:        embedding model passed on to ``_batch_save``
    :param is_save_function: zero-argument callable consulted before every batch;
                             a falsy answer stops the remaining batches
    :return: True once the loop has ended, whether it ran to completion or was stopped
    """
    self.save_pre_handler()
    chunk_list = chunk_data_list(data_list)
    for child_array in sub_array(chunk_list):
        if not is_save_function():
            break
        self._batch_save(child_array, embedding, is_save_function)
    # Return the bool promised by the docstring instead of falling through to None.
    return True

@abstractmethod
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
Expand Down
19 changes: 18 additions & 1 deletion ui/src/api/problem.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,28 @@ const getDetailProblems: (
return get(`${prefix}/${dataset_id}/problem/${problem_id}/paragraph`, undefined, loading)
}

/**
 * Associate several problems with several paragraphs in one request.
 * @param dataset_id id of the dataset, used in the request path
 * @param data request body:
 * {
 *   "problem_id_list": "Array",
 *   "paragraph_list": "Array"
 * }
 * @param loading optional ref forwarded to `post`
 */
const postMulAssociationProblem: (
  dataset_id: string,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
  const url = `${prefix}/${dataset_id}/problem/_batch`
  return post(url, data, undefined, loading)
}

export default {
getProblems,
postProblems,
delProblems,
putProblems,
getDetailProblems,
delMulProblem
delMulProblem,
postMulAssociationProblem
}
90 changes: 64 additions & 26 deletions ui/src/views/problem/component/RelateProblemDialog.vue
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@
</el-scrollbar>
</el-col>
</el-row>
<template #footer v-if="isMul">
<div class="dialog-footer">
<el-button @click="dialogVisible = false"> 取消</el-button>
<el-button type="primary" @click="mulAssociation"> 确认 </el-button>
</div>
</template>
</el-dialog>
</template>
<script setup lang="ts">
Expand All @@ -99,6 +105,7 @@ import { useRoute } from 'vue-router'
import problemApi from '@/api/problem'
import paragraphApi from '@/api/paragraph'
import useStore from '@/stores'
import { MsgSuccess } from '@/utils/message'

const { problem, document } = useStore()

Expand All @@ -116,6 +123,7 @@ const documentList = ref<any[]>([])
const cloneDocumentList = ref<any[]>([])
const paragraphList = ref<any[]>([])
const currentProblemId = ref<String>('')
const currentMulProblemId = ref<string[]>([])

// 回显
const associationParagraph = ref<any[]>([])
Expand All @@ -124,38 +132,62 @@ const currentDocument = ref<String>('')
const search = ref('')
const searchType = ref('title')
const filterDoc = ref('')
// 批量
const isMul = ref(false)

const paginationConfig = reactive({
current_page: 1,
page_size: 50,
total: 0
})

// Send the batch association: every selected problem is linked to every selected paragraph.
function mulAssociation() {
  const problem_id_list = currentMulProblemId.value
  // Reduce each selected paragraph to the two ids the endpoint expects.
  const paragraph_list = associationParagraph.value.map(({ id: paragraph_id, document_id }) => {
    return { paragraph_id, document_id }
  })
  const payload = { problem_id_list, paragraph_list }
  problemApi.postMulAssociationProblem(id, payload, loading).then(() => {
    MsgSuccess('批量关联分段成功')
    dialogVisible.value = false
  })
}

// Toggle the association state of the clicked paragraph.
// Batch mode only edits the local selection (sent later by the confirm button);
// single mode calls the store immediately and then reloads the problem's record.
function associationClick(item: any) {
  if (isMul.value) {
    if (isAssociation(item.id)) {
      // The selection holds paragraph objects, so look the entry up by its id.
      // indexOf(item.id) would return -1 and splice(-1, 1) would drop the LAST entry.
      const index = associationParagraph.value.findIndex((selected) => selected.id === item.id)
      if (index !== -1) {
        associationParagraph.value.splice(index, 1)
      }
    } else {
      associationParagraph.value.push(item)
    }
  } else {
    if (isAssociation(item.id)) {
      problem
        .asyncDisassociationProblem(
          id,
          item.document_id,
          item.id,
          currentProblemId.value as string,
          loading
        )
        .then(() => {
          getRecord(currentProblemId.value)
        })
    } else {
      problem
        .asyncAssociationProblem(
          id,
          item.document_id,
          item.id,
          currentProblemId.value as string,
          loading
        )
        .then(() => {
          getRecord(currentProblemId.value)
        })
    }
  }
}

Expand Down Expand Up @@ -216,6 +248,7 @@ watch(dialogVisible, (bool) => {
cloneDocumentList.value = []
paragraphList.value = []
associationParagraph.value = []
isMul.value = false

currentDocument.value = ''
search.value = ''
Expand All @@ -232,10 +265,15 @@ watch(filterDoc, (val) => {
currentDocument.value = documentList.value?.length > 0 ? documentList.value[0].id : ''
})

// Open the dialog for one problem (single mode) or for several problems (batch mode).
// Accepts either an array of ids or a single id string (the original call style).
const open = (problemId: any) => {
  // Normalise to an array so a plain string id is not mistaken for a batch via its length.
  const problemIds: string[] = Array.isArray(problemId)
    ? problemId
    : problemId == null
      ? []
      : [problemId]
  getDocument()
  if (problemIds.length === 1) {
    currentProblemId.value = problemIds[0]
    // Pass the id itself, not the wrapping array.
    getRecord(problemIds[0])
  } else if (problemIds.length > 1) {
    currentMulProblemId.value = problemIds
    isMul.value = true
  }
  dialogVisible.value = true
}

Expand Down
Loading
Loading