Skip to content

Commit

Permalink
Merge pull request #65 from stephantul/64-feat-add-support-for-legacy…
Browse files Browse the repository at this point in the history
…-fast-format

Add legacy loading
  • Loading branch information
stephantul authored Jul 14, 2024
2 parents 894a4b4 + a72d0e2 commit 4ad5db9
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 12 deletions.
Empty file added reach/legacy/__init__.py
Empty file.
20 changes: 20 additions & 0 deletions reach/legacy/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import json
from pathlib import Path

import numpy as np
import numpy.typing as npt


def load_old_fast_format_data(
path: Path,
) -> tuple[npt.NDArray, list[str], str | None, str]:
"""Load data from fast format."""
with open(f"{path}_items.json") as file_handle:
items = json.load(file_handle)
tokens, unk_index, name = items["items"], items["unk_index"], items["name"]

with open(f"{path}_vectors.npy", "rb") as file_handle:
vectors = np.load(file_handle)

unk_token = tokens[unk_index] if unk_index is not None else None
return vectors, tokens, unk_token, name
41 changes: 29 additions & 12 deletions reach/reach.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from numpy import typing as npt
from tqdm import tqdm

from reach.legacy.load import load_old_fast_format_data


Dtype: TypeAlias = str | np.dtype
File = Path | TextIOWrapper
PathLike = str | Path
Expand Down Expand Up @@ -1041,20 +1044,34 @@ def load_fast_format(
"""
filename_path = Path(filename)
with open(filename) as file_handle:
data: dict[str, Any] = json.load(file_handle)
items: list[str] = data["items"]

metadata: dict[str, Any] = data["metadata"]
unk_token = metadata.pop("unk_token")
name = metadata.pop("name")
numpy_path = filename_path.parent / Path(data["vectors_path"])

if not numpy_path.exists():
raise ValueError(f"Could not find the vectors file at {numpy_path}")
try:
with open(filename) as file_handle:
data: dict[str, Any] = json.load(file_handle)
items: list[str] = data["items"]

metadata: dict[str, Any] = data["metadata"]
unk_token = metadata.pop("unk_token")
name = metadata.pop("name")
numpy_path = filename_path.parent / Path(data["vectors_path"])

if not numpy_path.exists():
raise ValueError(f"Could not find the vectors file at {numpy_path}")

with open(numpy_path, "rb") as file_handle:
vectors: npt.NDArray = np.load(file_handle)
except FileNotFoundError as exc:
logger.warning("Attempting to load from old format.")
try:
vectors, items, unk_token, name = load_old_fast_format_data(
filename_path
)
metadata = {}
except FileNotFoundError:
logger.warning("Loading from old format failed")
# NOTE: reraise old exception.
raise exc

with open(numpy_path, "rb") as file_handle:
vectors: npt.NDArray = np.load(file_handle)
vectors = vectors.astype(desired_dtype)
instance = cls(vectors, items, name=name, metadata=metadata)
instance.unk_token = unk_token
Expand Down
45 changes: 45 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import unittest
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
Expand Down Expand Up @@ -249,6 +250,50 @@ def test_save_load_fast_format(self) -> None:
self.assertEqual(instance._unk_index, instance_2._unk_index)
self.assertEqual(instance.name, instance_2.name)

def test_save_load_fast_format_old(self) -> None:
with TemporaryDirectory() as temp_folder:
lines = self.lines()

temp_folder_path = Path(temp_folder)

temp_file_name = temp_folder_path / "test.vec"
with open(temp_file_name, "w") as tempfile:
tempfile.write(lines)
tempfile.seek(0)

instance = Reach.load(temp_file_name)
fast_format_file = temp_folder_path / "temp"

items_dict = {
"items": instance.sorted_items,
"unk_index": instance._unk_index,
"name": instance.name,
}

json.dump(items_dict, open(f"{fast_format_file}_items.json", "w"))
np.save(f"{fast_format_file}_vectors.npy", instance.vectors)

instance_2 = Reach.load_fast_format(fast_format_file)

self.assertEqual(instance.size, instance_2.size)
self.assertEqual(len(instance), len(instance_2))
self.assertEqual(instance.items, instance_2.items)
self.assertTrue(np.allclose(instance.vectors, instance_2.vectors))
self.assertEqual(instance._unk_index, instance_2._unk_index)
self.assertEqual(instance.name, instance_2.name)

fast_format_file_2 = temp_folder_path / "temp.reach"

instance.save_fast_format(fast_format_file_2)
instance_3 = Reach.load_fast_format(fast_format_file_2)

self.assertEqual(instance.size, instance_3.size)
self.assertEqual(len(instance), len(instance_3))
self.assertEqual(instance.items, instance_3.items)
self.assertTrue(np.allclose(instance.vectors, instance_3.vectors))
self.assertEqual(instance._unk_index, instance_3._unk_index)
self.assertEqual(instance.name, instance_3.name)

def test_save_load(self) -> None:
with NamedTemporaryFile("w+") as tempfile:
lines = self.lines()
Expand Down

0 comments on commit 4ad5db9

Please sign in to comment.