Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

patch to loomcat #247

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ docs/_build/
*_minted-*
*.pdf

# vim artifacts
*.swp

# specific directories
examples/malawi/resources/
examples/satellites/splits/
Expand Down
19 changes: 11 additions & 8 deletions src/crosscat/loomcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,20 @@ def _retrieve_column_partition(path, sample):
]))


def _retrieve_featureid_to_cgpm(path):
def _retrieve_featureid_to_cgpm(path, colname_colno_mapping=None):
"""Returns a dict mapping loom's 0-based featureid to cgpm.outputs."""
# Loom orders features alphabetically based on statistical types:
# i.e. 'bb' < 'dd' < 'nich'. The ordering is stored in
# `ingest/encoding.json.gz`.
encoding_in = os.path.join(path, 'ingest', 'encoding.json.gz')
features = json_load(encoding_in)
def colname_to_output(cname):
# Convert dummy column name from 'c00012' to the integer 12.
return int(cname.replace('c', ''))
if colname_colno_mapping is not None:
def colname_to_output(cname):
return colname_colno_mapping[cname]
else:
def colname_to_output(cname):
# Convert dummy column name from 'c00012' to the integer 12.
return int(cname.replace('c', ''))
return {
i: colname_to_output(f['name']) for i, f in enumerate(features)
}
Expand All @@ -136,7 +140,7 @@ def _retrieve_row_partitions(path, sample):
}


def _update_state(state, path, sample):
def _update_state(state, path, sample, colname_colno_mapping=None):
"""Updates `state` to match the CrossCat `sample` at `path`.

Only the row and column partitions are updated; parameter inference
Expand All @@ -145,15 +149,14 @@ def _update_state(state, path, sample):

Wild errors will occur if the Loom object is incompatible with `state`.
"""

# Retrieve the new column partition from loom.
Zv_new_raw = _retrieve_column_partition(path, sample)
assert sorted(Zv_new_raw.keys()) == range(len(state.outputs))

# The keys of Zv_new_raw are the contiguous loom feature ids
# 0..len(outputs)-1, while state.outputs are arbitrary integers, so we
# need to map the loom feature ids to the corresponding cgpm outputs.
output_mapping = _retrieve_featureid_to_cgpm(path)
output_mapping = _retrieve_featureid_to_cgpm(path, colname_colno_mapping)

assert sorted(output_mapping.values()) == sorted(state.outputs)
Zv_new = {output_mapping[f]: Zv_new_raw[f] for f in Zv_new_raw}

Expand Down