From fa5ad86564c0e499570133600dc21a25038e928c Mon Sep 17 00:00:00 2001 From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com> Date: Tue, 27 Sep 2022 09:14:27 -0400 Subject: [PATCH] Fix console error with compressed indexes, closes #347 --- src/python/txtai/console/base.py | 96 +++++++++++++++++++------------- test/python/testconsole.py | 8 ++- 2 files changed, 63 insertions(+), 41 deletions(-) diff --git a/src/python/txtai/console/base.py b/src/python/txtai/console/base.py index a30c430ed..4459f2f1b 100644 --- a/src/python/txtai/console/base.py +++ b/src/python/txtai/console/base.py @@ -99,24 +99,6 @@ def config(self): self.console.print(self.app.config) - def load(self, path): - """ - Processes .load command. - - Args: - path: path to configuration - """ - - if os.path.isfile(path): - self.console.print(f"Loading application {path}") - self.app = Application(path) - else: - self.console.print(f"Loading index {path}") - - # Load embeddings index - self.app = Embeddings() - self.app.load(path) - def highlight(self, command): """ Processes .highlight command. @@ -141,32 +123,23 @@ def limit(self, command): self.vlimit = int(action) self.console.print(f"Set limit to {self.vlimit}") - def workflow(self, command): + def load(self, path): """ - Processes .workflow command. + Processes .load command. Args: - command: command line - """ - - command = shlex.split(command) - if isinstance(self.app, Application): - self.console.print(list(self.app.workflow(command[1], command[2:]))) - - def split(self, command, default=None): + path: path to configuration """ - Splits command by whitespace. - Args: - command: command line - default: default command action - - Returns: - command action - """ + if self.isyaml(path): + self.console.print(f"Loading application {path}") + self.app = Application(path) + else: + self.console.print(f"Loading index {path}") - values = command.split(" ", 1) - return values if len(values) > 1 else (command, default) + # Load embeddings index + self.app = Embeddings() + self.app.load(path) def search(self, query): """ @@ -205,6 +178,53 @@ def search(self, query): # Print table to console self.console.print(table) + def workflow(self, command): + """ + Processes .workflow command. + + Args: + command: command line + """ + + command = shlex.split(command) + if isinstance(self.app, Application): + self.console.print(list(self.app.workflow(command[1], command[2:]))) + + def isyaml(self, path): + """ + Checks if file at path is a valid YAML file. + + Args: + path: file to check + + Returns: + True if file is valid YAML, False otherwise + """ + + if os.path.exists(path) and os.path.isfile(path): + try: + return Application.read(path) + # pylint: disable=W0702 + except: + pass + + return False + + def split(self, command, default=None): + """ + Splits command by whitespace. + + Args: + command: command line + default: default command action + + Returns: + command action + """ + + values = command.split(" ", 1) + return values if len(values) > 1 else (command, default) + def render(self, result, column, value): """ Renders a search result column value. diff --git a/test/python/testconsole.py b/test/python/testconsole.py index f71444d8f..0df82bb7d 100644 --- a/test/python/testconsole.py +++ b/test/python/testconsole.py @@ -55,8 +55,9 @@ def setUpClass(cls): with open(cls.apppath, "w", encoding="utf-8") as out: out.write(APPLICATION % cls.embedpath) - # Save index + # Save index as uncompressed and compressed cls.embeddings.save(cls.embedpath) + cls.embeddings.save(f"{cls.embedpath}.tar.gz") # Create console cls.console = Console(cls.embedpath) @@ -66,7 +67,7 @@ def testApplication(self): Test application """ - self.assertIn("console.yml", self.command(f".load {self.apppath}")) + self.assertNotIn("Traceback", self.command(f".load {self.apppath}")) self.assertIn("1", self.command(".limit 1")) self.assertIn("Maine man wins", self.command("feel good story")) @@ -82,7 +83,8 @@ def testEmbeddings(self): Test embeddings index """ - self.assertIn("embeddings", self.command(f".load {self.embedpath}")) + self.assertNotIn("Traceback", self.command(f".load {self.embedpath}.tar.gz")) + self.assertNotIn("Traceback", self.command(f".load {self.embedpath}")) self.assertIn("1", self.command(".limit 1")) self.assertIn("Maine man wins", self.command("feel good story"))