diff --git a/Sources/TextRank/Sentence.swift b/Sources/TextRank/Sentence.swift index c4847cf..a9424a5 100644 --- a/Sources/TextRank/Sentence.swift +++ b/Sources/TextRank/Sentence.swift @@ -16,10 +16,11 @@ public struct Sentence: Hashable { public let originalTextIndex: Int - public init(text: String, originalTextIndex: Int) { + public init(text: String, originalTextIndex: Int, additionalStopwords: [String] = [String]()) { self.text = text self.originalTextIndex = originalTextIndex - words = Sentence.removeStopWords(from: Sentence.clean(self.text)) + words = Sentence.removeStopWords(from: Sentence.clean(self.text), + additionalStopwords: additionalStopwords) } public func hash(into hasher: inout Hasher) { @@ -37,9 +38,9 @@ public struct Sentence: Hashable { .words } - static func removeStopWords(from w: [String]) -> Set { + static func removeStopWords(from w: [String], additionalStopwords: [String] = [String]()) -> Set { var wordSet = Set(w) - wordSet.subtract(Stopwords.English) + wordSet.subtract(Stopwords.English + additionalStopwords) return wordSet } } diff --git a/Sources/TextRank/TextRank.swift b/Sources/TextRank/TextRank.swift index f4fdb22..2bf4cc2 100644 --- a/Sources/TextRank/TextRank.swift +++ b/Sources/TextRank/TextRank.swift @@ -10,14 +10,19 @@ import Foundation public class TextRank { public var text: String { didSet { - sentences = TextRank.splitIntoSentences(text).filter { $0.length > 0 } + textToSentences() } } - public var summarizationFraction: Float = 0.2 public var graph: TextGraph - public var graphDamping: Float = 0.85 public var sentences = [Sentence]() + public var summarizationFraction: Float = 0.2 + public var graphDamping: Float = 0.85 + public var stopwords = [String]() { + didSet { + textToSentences() + } + } public init() { text = "" @@ -26,16 +31,20 @@ public class TextRank { public init(text: String) { self.text = text - sentences = TextRank.splitIntoSentences(text).filter { $0.length > 0 } graph = TextGraph(damping: graphDamping) + textToSentences() } public init(text: String, summarizationFraction: Float = 0.2, graphDamping: Float = 0.85) { self.text = text self.summarizationFraction = summarizationFraction self.graphDamping = graphDamping - sentences = TextRank.splitIntoSentences(text).filter { $0.length > 0 } graph = TextGraph(damping: graphDamping) + textToSentences() + } + + func textToSentences() { + sentences = TextRank.splitIntoSentences(text, additionalStopwords: stopwords).filter { $0.length > 0 } } } @@ -78,13 +87,17 @@ extension TextRank { /// Split text into sentences. /// - Parameter text: Original text. /// - Returns: An array of sentences. - static func splitIntoSentences(_ text: String) -> [Sentence] { + static func splitIntoSentences(_ text: String, additionalStopwords stopwords: [String] = [String]()) -> [Sentence] { if text.isEmpty { return [] } var x = [Sentence]() text.enumerateSubstrings(in: text.range(of: text)!, options: [.bySentences, .localized]) { substring, _, _, _ in if let substring = substring, !substring.isEmpty { - x.append(Sentence(text: substring.trimmingCharacters(in: .whitespacesAndNewlines), originalTextIndex: x.count)) + x.append( + Sentence(text: substring.trimmingCharacters(in: .whitespacesAndNewlines), + originalTextIndex: x.count, + additionalStopwords: stopwords) + ) } } return Array(Set(x)) @@ -101,15 +114,12 @@ public extension TextRank { func filterTopSentencesFrom(_ results: TextGraph.PageRankResult, top percentile: Float) -> TextGraph.NodeList { let idx = Int(Float(results.results.count) * percentile) let cutoffScore: Float = results.results.values.sorted()[min(idx, results.results.count - 1)] - var filteredNodeList: TextGraph.NodeList = [:] - for (sentence, value) in results.results { if value >= cutoffScore { filteredNodeList[sentence] = value } } - return filteredNodeList } } diff --git a/Tests/TextRankTests/SentenceTests.swift b/Tests/TextRankTests/SentenceTests.swift index 78557a5..246a547 100644 --- a/Tests/TextRankTests/SentenceTests.swift +++ b/Tests/TextRankTests/SentenceTests.swift @@ -19,4 +19,38 @@ class SentenceTests: XCTestCase { XCTAssertEqual(s.words, Set(clean)) } } + + func testRemovalOfStopWords() { + // Given + let text = "here are some words to be" + + // When + let sentence = Sentence(text: text, originalTextIndex: 0) + + // Then + XCTAssertEqual(sentence.length, 0) + } + + func testRemovalOfStopWordsButNotMeaningfulWords() { + // Given + let text = "here are some words to be lion" + + // When + let sentence = Sentence(text: text, originalTextIndex: 0) + + // Then + XCTAssertEqual(sentence.length, 1) + XCTAssertEqual(sentence.words, Set(["lion"])) + } + + func testRemovalOfStopWordsAndAdditionalStopwords() { + // Given + let text = "here are some words to be lion" + + // When + let sentence = Sentence(text: text, originalTextIndex: 0, additionalStopwords: ["lion"]) + + // Then + XCTAssertEqual(sentence.length, 0) + } } diff --git a/Tests/TextRankTests/TextRankTests.swift b/Tests/TextRankTests/TextRankTests.swift index 6b2caa0..ef8f8bc 100644 --- a/Tests/TextRankTests/TextRankTests.swift +++ b/Tests/TextRankTests/TextRankTests.swift @@ -87,4 +87,40 @@ class TextRankTests: XCTestCase { XCTAssertTrue(filteredResults.count < results.results.count) XCTAssertTrue(filteredResults.count == 2) } + + func testStopwordsAreRemoved() { + // Given + let text = "Here are some sentences dog cat. With intentional stopwords gator. And some words that are not." + + // When + let textRank = TextRank(text: text) + + // Then + XCTAssertEqual(textRank.sentences.count, 2) + XCTAssertEqual(textRank.sentences[0].length, 3) + XCTAssertEqual(textRank.sentences.filter { $0.originalTextIndex == 0 }[0].words, + Set(["sentences", "dog", "cat"])) + XCTAssertEqual(textRank.sentences.filter { $0.originalTextIndex == 1 }[0].words, + Set(["intentional", "stopwords", "gator"])) + XCTAssertEqual(textRank.sentences[1].length, 3) + } + + func testAdditionalStopwords() { + // Given + let text = "Here are some sentences dog cat. With intentional stopwords gator. And some words that are not." + let additionalStopwords = ["dog", "gator"] + + // When + let textRank = TextRank(text: text) + textRank.stopwords = additionalStopwords + + // Then + XCTAssertEqual(textRank.sentences.count, 2) + XCTAssertEqual(textRank.sentences[0].length, 2) + XCTAssertEqual(textRank.sentences.filter { $0.originalTextIndex == 0 }[0].words, + Set(["sentences", "cat"])) + XCTAssertEqual(textRank.sentences.filter { $0.originalTextIndex == 1 }[0].words, + Set(["intentional", "stopwords"])) + XCTAssertEqual(textRank.sentences[1].length, 2) + } }