Skip to content

Commit

Permalink
Enable uploading list of extra info
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Sep 7, 2023
1 parent 4954297 commit 861e685
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
14 changes: 11 additions & 3 deletions abcd/backends/atoms_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def save_bulk(self, actions: Iterable):
def push(
self,
atoms: Union[Atoms, Iterable],
extra_info: Union[dict, str, None] = None,
extra_info: Union[dict, str, list, None] = None,
store_calc: bool = True,
):
"""
Expand All @@ -406,6 +406,10 @@ def push(
"""
if extra_info and isinstance(extra_info, str):
extra_info = extras.parser.parse(extra_info) # type: ignore
if extra_info and isinstance(extra_info, list):
for i, info in enumerate(extra_info):
if isinstance(info, str):
extra_info[i] = extras.parser.parse(info)

if isinstance(atoms, Atoms):
data = AtomsModel.from_atoms(
Expand All @@ -419,12 +423,16 @@ def push(

elif isinstance(atoms, Generator) or isinstance(atoms, list):
actions = []
for item in atoms:
for i, item in enumerate(atoms):
if isinstance(extra_info, list):
info = extra_info[i]
else:
info = extra_info
data = AtomsModel.from_atoms(
self.client,
self.index_name,
item,
extra_info=extra_info, # type: ignore
extra_info=info, # type: ignore
store_calc=store_calc,
)
actions.append(data.data)
Expand Down
8 changes: 6 additions & 2 deletions abcd/backends/atoms_pymongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,13 @@ def push(self, atoms: Union[Atoms, Iterable], extra_info=None, store_calc=True):
# self.collection.insert_one(data)

elif isinstance(atoms, types.GeneratorType) or isinstance(atoms, list):
for item in atoms:
for i, item in enumerate(atoms):
if isinstance(extra_info, list):
info = extra_info[i]
else:
info = extra_info
data = AtomsModel.from_atoms(
self.collection, item, extra_info=extra_info, store_calc=store_calc
self.collection, item, extra_info=info, store_calc=store_calc
)
data.save()

Expand Down

0 comments on commit 861e685

Please sign in to comment.