Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow selection from index map values #297

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions draco/analysis/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,22 +1354,35 @@ def process(self, data: containers.ContainerBase) -> containers.ContainerBase:
Container of same type as the input with specific axis selections.
Any datasets not included in the selections will not be initialized.
"""
# Re-format selections to only use axis name
for ax_sel in list(self._sel):
ax = ax_sel.replace("_sel", "")
self._sel[ax] = self._sel.pop(ax_sel)
sel = {}

# Parse axes with selections and reformat to use only
# the axis name
for k in self.selections:
*axis, type_ = k.split("_")
axis_name = "_".join(axis)

ax_sel = self._sel.get(f"{axis_name}_sel")

if type_ == "map":
# Use index map to get the correct axis indices
imap = list(data.index_map[axis_name])
ax_sel = [imap.index(x) for x in ax_sel]

if ax_sel is not None:
sel[axis_name] = ax_sel

# Figure out the axes for the new container and
# apply the downselections to each axis index_map
output_axes = {
ax: mpiarray._apply_sel(data.index_map[ax], sel, 0)
for ax, sel in self._sel.items()
ax: mpiarray._apply_sel(data.index_map[ax], ax_sel, 0)
for ax, ax_sel in sel.items()
}
# Create the output container without initializing any datasets.
out = data.__class__(
axes_from=data, attrs_from=data, skip_datasets=True, **output_axes
)
containers.copy_datasets_filter(data, out, selection=self._sel)
containers.copy_datasets_filter(data, out, selection=sel)

return out

Expand Down
36 changes: 27 additions & 9 deletions draco/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import os.path
import shutil
import subprocess
from functools import partial
from typing import ClassVar, Optional, Union

import numpy as np
Expand Down Expand Up @@ -351,6 +352,11 @@ class SelectionsMixin:
selections : dict, optional
A dictionary of axis selections. See below for details.

allow_index_map : bool, optional
If true, selections can be made based on an index_map dataset.
This cannot be implemented when reading from disk. See below for
details. Default is False.

Selections
----------
Selections can be given to limit the data read to specified subsets. They can be
Expand All @@ -359,11 +365,13 @@ class SelectionsMixin:
Selections can be given as a slice with an `<axis name>_range` key with either
`[start, stop]` or `[start, stop, step]` as the value. Alternatively a list of
explicit indices to extract can be given with the `<axis name>_index` key, and
the value is a list of the indices. If both `<axis name>_range` and `<axis
name>_index` keys are given the former will take precedence, but you should
clearly avoid doing this.
the value is a list of the indices. Finally, a selection based on an `index_map`
can be given with specific `index_map` entries using the `<axis name>_map` key,
which will be converted to axis indices. `<axis name>_range` will take precedence
over `<axis name>_index`, which will in turn take precedence over `<axis name>_map`,
but you should avoid mixing selection types for the same axis.

Additionally index based selections currently don't work for distributed reads.
Additionally, index-based selections currently don't work for distributed reads.

Here's an example in the YAML format that the pipeline uses:

Expand All @@ -373,9 +381,11 @@ class SelectionsMixin:
freq_range: [256, 512, 4] # A strided slice
stack_index: [1, 2, 4, 9, 16, 25, 36, 49, 64] # A sparse selection
stack_range: [1, 14] # Will override the selection above
pol_map: ["XX", "YY"] # Select the indices corresponding to these entries
"""

selections = config.Property(proptype=dict, default=None)
allow_index_map = config.Property(proptype=bool, default=False)

def setup(self):
"""Resolve the selections."""
Expand All @@ -386,7 +396,14 @@ def _resolve_sel(self):

sel = {}

sel_parsers = {"range": self._parse_range, "index": self._parse_index}
sel_parsers = {
"range": self._parse_range,
"index": partial(self._parse_index, type_=int),
"map": self._parse_index,
}

if not self.allow_index_map:
del sel_parsers["map"]

    # To enforce the precedence of the selection types we rely on the fact
    # that sorting the keys orders them axis_index < axis_map < axis_range,
    # so later entries overwrite earlier ones for the same axis.
    # NOTE(review): this makes `_map` overwrite `_index` for the same axis,
    # which conflicts with the documented precedence (index over map) — confirm
    # which order is intended.
Expand All @@ -398,7 +415,8 @@ def _resolve_sel(self):

if type_ not in sel_parsers:
raise ValueError(
f'Unsupported selection type "{type_}", or invalid key "{k}"'
f'Unsupported selection type "{type_}", or invalid key "{k}". '
"Note that map-type selections require `allow_index_map=True`."
)

sel[f"{axis_name}_sel"] = sel_parsers[type_](self.selections[k])
Expand All @@ -419,15 +437,15 @@ def _parse_range(self, x):

return slice(*x)

def _parse_index(self, x):
def _parse_index(self, x, type_=object):
# Parse and validate an index type selection

if not isinstance(x, (list, tuple)) or len(x) == 0:
raise ValueError(f"Index spec must be a non-empty list or tuple. Got {x}.")

for v in x:
if not isinstance(v, int):
raise ValueError(f"All elements of index spec must be ints. Got {x}")
if not isinstance(v, type_):
raise ValueError(f"All elements of index spec must be {type_}. Got {x}")

return list(x)

Expand Down