Skip to content

Commit

Permalink
Move slack to chunker v2 structure (#2744)
Browse files Browse the repository at this point in the history
* Revert "Revert "Move slack to chunker v2 structure (#2741)" (#2743)"

This reverts commit 45e89cf.

* Fix core
  • Loading branch information
spolu authored Dec 1, 2023
1 parent 06c6934 commit 134730a
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 60 deletions.
30 changes: 19 additions & 11 deletions connectors/src/connectors/slack/bot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
ConversationType,
ModelId,
RetrievalDocumentType,
sectionFullText,
UserMessageType,
} from "@dust-tt/types";
import { WebClient } from "@slack/web-api";
Expand Down Expand Up @@ -688,12 +689,27 @@ async function makeContentFragment(
return new Ok(null);
}

const text = await formatMessagesForUpsert(
const channel = await slackClient.conversations.info({
channel: channelId,
});

if (channel.error) {
return new Err(
new Error(`Could not retrieve channel name: ${channel.error}`)
);
}
if (!channel.channel || !channel.channel.name) {
return new Err(new Error("Could not retrieve channel name"));
}

const content = await formatMessagesForUpsert(
channelId,
channel.channel.name,
allMessages,
connector.id,
slackClient
);

let url: string | null = null;
if (allMessages[0]?.ts) {
const permalinkRes = await slackClient.chat.getPermalink({
Expand All @@ -705,17 +721,9 @@ async function makeContentFragment(
}
url = permalinkRes.permalink;
}
const channel = await slackClient.conversations.info({
channel: channelId,
});

if (channel.error) {
return new Err(new Error(channel.error));
}

return new Ok({
title: `Thread content from #${channel.channel?.name}`,
content: text,
title: `Thread content from #${channel.channel.name}`,
content: sectionFullText(content),
url: url,
contentType: "slack_thread_content",
context: null,
Expand Down
84 changes: 51 additions & 33 deletions connectors/src/connectors/slack/temporal/activities.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ModelId } from "@dust-tt/types";
import { CoreAPIDataSourceDocumentSection, ModelId } from "@dust-tt/types";
import {
CodedError,
ErrorCode,
Expand Down Expand Up @@ -405,8 +405,9 @@ export async function syncNonThreaded(
}
messages.reverse();

const text = await formatMessagesForUpsert(
const content = await formatMessagesForUpsert(
channelId,
channelName,
messages,
connectorId,
client
Expand Down Expand Up @@ -435,20 +436,18 @@ export async function syncNonThreaded(
: undefined;

const tags = getTagsForPage(documentId, channelId, channelName);

await SlackMessages.upsert({
connectorId: connectorId,
channelId: channelId,
messageTs: undefined,
documentId: documentId,
});

await upsertToDatasource({
dataSourceConfig,
documentId,
documentContent: {
prefix: null,
content: text,
sections: [],
},
documentContent: content,
documentUrl: sourceUrl,
timestampMs: createdAt,
tags,
Expand Down Expand Up @@ -582,8 +581,9 @@ export async function syncThread(
return;
}

const text = await formatMessagesForUpsert(
const content = await formatMessagesForUpsert(
channelId,
channelName,
allMessages,
connectorId,
slackClient
Expand Down Expand Up @@ -618,11 +618,7 @@ export async function syncThread(
await upsertToDatasource({
dataSourceConfig,
documentId,
documentContent: {
prefix: null,
content: text,
sections: [],
},
documentContent: content,
documentUrl: sourceUrl,
timestampMs: createdAt,
tags,
Expand Down Expand Up @@ -659,31 +655,53 @@ async function processMessageForMentions(

export async function formatMessagesForUpsert(
channelId: string,
channelName: string,
messages: MessageElement[],
connectorId: ModelId,
slackClient: WebClient
) {
return (
await Promise.all(
messages.map(async (message) => {
const text = await processMessageForMentions(
message.text as string,
connectorId,
slackClient
);
): Promise<CoreAPIDataSourceDocumentSection> {
const data = await Promise.all(
messages.map(async (message) => {
const text = await processMessageForMentions(
message.text as string,
connectorId,
slackClient
);

const userName = await getUserName(
message.user as string,
connectorId,
slackClient
);
const messageDate = new Date(parseInt(message.ts as string, 10) * 1000);
const messageDateStr = formatDateForUpsert(messageDate);
const userName = await getUserName(
message.user as string,
connectorId,
slackClient
);
const messageDate = new Date(parseInt(message.ts as string, 10) * 1000);
const messageDateStr = formatDateForUpsert(messageDate);

return {
dateStr: messageDateStr,
userName,
text: text,
content: text + "\n",
sections: [],
};
})
);

return `>> @${userName} [${messageDateStr}]:\n${text}\n`;
})
)
).join("\n");
const first = data[0];
if (!first) {
throw new Error("Cannot format empty list of messages");
}

return {
prefix: `Thread in #${channelName} [${first.dateStr}]: ${
first.text.replace(/\s+/g, " ").trim().substring(0, 128) + "..."
}\n`,
content: null,
sections: data.map((d) => ({
prefix: `>> @${d.userName} [${d.dateStr}]:\n`,
content: d.text + "\n",
sections: [],
})),
};
}

export async function fetchUsers(connectorId: ModelId) {
Expand Down
91 changes: 75 additions & 16 deletions core/src/data_sources/splitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async fn split_text(
#[derive(Debug, Clone)]
pub struct TokenizedSection {
pub max_chunk_size: usize,
pub prefixes: Vec<TokenizedText>,
pub prefixes: Vec<(String, TokenizedText)>,
pub tokens_count: usize,
pub content: Option<TokenizedText>,
pub sections: Vec<TokenizedSection>,
Expand All @@ -141,9 +141,15 @@ impl TokenizedSection {
pub async fn from(
embedder: &Box<dyn Embedder + Sync + Send>,
max_chunk_size: usize,
mut prefixes: Vec<TokenizedText>,
mut prefixes: Vec<(String, TokenizedText)>,
section: &Section,
path: Option<String>,
) -> Result<Self> {
let path = match path.as_ref() {
Some(p) => p,
None => "",
};

let (prefix, mut content) = try_join!(
TokenizedText::from(embedder, section.prefix.as_ref()),
TokenizedText::from(embedder, section.content.as_ref())
Expand All @@ -152,12 +158,12 @@ impl TokenizedSection {
// Add the new prefix to the list of prefixes to be passed down children.
match prefix.as_ref() {
Some(prefix) => {
prefixes.push(prefix.clone());
prefixes.push((path.to_string(), prefix.clone()));
}
None => (),
};

let prefixes_tokens_count = prefixes.iter().map(|p| p.tokens.len()).sum::<usize>();
let prefixes_tokens_count = prefixes.iter().map(|(_, p)| p.tokens.len()).sum::<usize>();
if prefixes_tokens_count >= max_chunk_size / 2 {
Err(anyhow!(
"Could not tokenize the provided document,
Expand Down Expand Up @@ -201,12 +207,15 @@ impl TokenizedSection {
}

sections.extend(
futures::future::join_all(
section
.sections
.iter()
.map(|s| TokenizedSection::from(embedder, max_chunk_size, prefixes.clone(), s)),
)
futures::future::join_all(section.sections.iter().enumerate().map(|(i, s)| {
TokenizedSection::from(
embedder,
max_chunk_size,
prefixes.clone(),
s,
Some(format!("{}{}", path, i)),
)
}))
.await
.into_iter()
.collect::<Result<Vec<_>>>()?,
Expand Down Expand Up @@ -250,9 +259,9 @@ impl TokenizedSection {
let mut tokens: Vec<usize> = vec![];

for s in self.dfs() {
s.prefixes.iter().for_each(|p| {
if !seen_prefixes.contains(&p.text) {
seen_prefixes.insert(p.text.clone());
s.prefixes.iter().for_each(|(h, p)| {
if !seen_prefixes.contains(h) {
seen_prefixes.insert(h.clone());
tokens.extend(p.tokens.clone());
text += &p.text;
}
Expand Down Expand Up @@ -298,8 +307,11 @@ impl TokenizedSection {
let prefixes = self.prefixes.clone();
assert!(self.content.is_none());

let prefixes_tokens_count =
self.prefixes.iter().map(|p| p.tokens.len()).sum::<usize>();
let prefixes_tokens_count = self
.prefixes
.iter()
.map(|(_, p)| p.tokens.len())
.sum::<usize>();

let mut selection: Vec<TokenizedSection> = vec![];
let mut selection_tokens_count: usize = prefixes_tokens_count;
Expand Down Expand Up @@ -432,7 +444,7 @@ impl Splitter for BaseV0Splitter {
embedder.initialize(credentials).await?;

let tokenized_section =
TokenizedSection::from(&embedder, max_chunk_size, vec![], &section).await?;
TokenizedSection::from(&embedder, max_chunk_size, vec![], &section, None).await?;

// We filter out whitespace only or empty strings which is possible to obtain if the section
// passed have empty or whitespace only content.
Expand Down Expand Up @@ -850,4 +862,51 @@ mod tests {
.join("|")
)
}

// Regression test for a splitter bug found 2023-12-01: sibling sections that
// share an identical prefix string (here the two ">> @spolu [...]" message
// headers) had the repeated prefix dropped from the output, because prefix
// deduplication matched on the prefix *text* rather than on its position in
// the section tree. See the inline comment before the final assertion.
#[tokio::test]
async fn test_splitter_v0_bug_20231201() {
// Input shaped like an upserted Slack thread: a thread-level prefix with
// no content of its own, and one child section per message. The second
// and third children deliberately carry byte-identical prefixes.
let section = Section {
prefix: Some(
"Thread in #brand [20230908 10:16]: Should we make a poster?...\n".to_string(),
),
content: None,
sections: vec![
Section {
prefix: Some(">> @ed [20230908 10:16]:\n".to_string()),
content: Some("Should we make a poster?\n".to_string()),
sections: vec![],
},
Section {
prefix: Some(">> @spolu [20230908 10:16]:\n".to_string()),
content: Some(":100:\n".to_string()),
sections: vec![],
},
// Same prefix as the previous sibling — this is the case the bug
// mishandled (the duplicate header was skipped).
Section {
prefix: Some(">> @spolu [20230908 10:16]:\n".to_string()),
content: Some("\"Factory\" :p\n".to_string()),
sections: vec![],
},
],
};

let provider_id = ProviderID::OpenAI;
let model_id = "text-embedding-ada-002";
// NOTE(review): dummy API key — assumes the BaseV0 split path only needs
// local tokenization and never reaches the OpenAI API; confirm.
let credentials = Credentials::from([("OPENAI_API_KEY".to_string(), "abc".to_string())]);

// max_chunk_size = 256 tokens: large enough that the whole thread fits in
// a single chunk, so `splitted` joined with "|" is one contiguous string.
let splitted = splitter(SplitterID::BaseV0)
.split(credentials, provider_id, model_id, 256, section)
.await
.unwrap();

// Before the bug the second @spolu prefix would be skipped because we were doing string
// matching vs prefix position matching.

// Expected: thread prefix once, then every message with its own header —
// including the repeated ">> @spolu" header before "\"Factory\" :p".
assert_eq!(
splitted.join("|"),
"Thread in #brand [20230908 10:16]: Should we make a poster?...\n\
>> @ed [20230908 10:16]:\nShould we make a poster?\n\
>> @spolu [20230908 10:16]:\n:100:\n\
>> @spolu [20230908 10:16]:\n\"Factory\" :p\n"
)
}
}

0 comments on commit 134730a

Please sign in to comment.