Skip to content

Commit

Permalink
Dev (#302)
Browse files Browse the repository at this point in the history
* formatting

* use cache

* skip test_dcnt_r01

* remove separator

* Add numpy as np
  • Loading branch information
tanghaibao authored Jun 18, 2024
1 parent 2b47d8e commit d0ba53e
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 41 deletions.
91 changes: 59 additions & 32 deletions goatools/godag/go_tasks.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
"""item-DAG tasks."""

__copyright__ = "Copyright (C) 2010-present, DV Klopfenstein, H Tang, All rights reserved."
__copyright__ = (
"Copyright (C) 2010-present, DV Klopfenstein, H Tang, All rights reserved."
)
__author__ = "DV Klopfenstein"

from goatools.godag.consts import RELATIONSHIP_SET
from ..godag.consts import RELATIONSHIP_SET


# ------------------------------------------------------------------------------------
def get_go2parents(go2obj, relationships):
"""Get set of parents GO IDs, including parents through user-specfied relationships"""
if go2obj and not hasattr(next(iter(go2obj.values())), 'relationship') or not relationships:
if (
go2obj
and not hasattr(next(iter(go2obj.values())), "relationship")
or not relationships
):
return get_go2parents_isa(go2obj)
go2parents = {}
for goid_main, goterm in go2obj.items():
Expand All @@ -21,10 +26,14 @@ def get_go2parents(go2obj, relationships):
go2parents[goid_main] = parents_goids
return go2parents

# ------------------------------------------------------------------------------------

def get_go2children(go2obj, relationships):
"""Get set of children GO IDs, including children through user-specfied relationships"""
if go2obj and not hasattr(next(iter(go2obj.values())), 'relationship') or not relationships:
if (
go2obj
and not hasattr(next(iter(go2obj.values())), "relationship")
or not relationships
):
return get_go2children_isa(go2obj)
go2children = {}
for goid_main, goterm in go2obj.items():
Expand All @@ -36,7 +45,7 @@ def get_go2children(go2obj, relationships):
go2children[goid_main] = children_goids
return go2children

# ------------------------------------------------------------------------------------

def get_go2parents_isa(go2obj):
"""Get set of immediate parents GO IDs"""
go2parents = {}
Expand All @@ -46,7 +55,7 @@ def get_go2parents_isa(go2obj):
go2parents[goid_main] = parents_goids
return go2parents

# ------------------------------------------------------------------------------------

def get_go2children_isa(go2obj):
"""Get set of immediate children GO IDs"""
go2children = {}
Expand All @@ -56,84 +65,96 @@ def get_go2children_isa(go2obj):
go2children[goid_main] = children_goids
return go2children

# ------------------------------------------------------------------------------------

def get_go2ancestors(terms, relationships, prt=None):
"""Get GO-to- ancestors (all parents)"""
if not relationships:
if prt is not None:
prt.write('up: is_a\n')
prt.write("up: is_a\n")
return get_id2parents(terms)
if relationships == RELATIONSHIP_SET or relationships is True:
if prt is not None:
prt.write('up: is_a and {Rs}\n'.format(
Rs=' '.join(sorted(RELATIONSHIP_SET))))
prt.write(
"up: is_a and {Rs}\n".format(Rs=" ".join(sorted(RELATIONSHIP_SET)))
)
return get_id2upper(terms)
if prt is not None:
prt.write('up: is_a and {Rs}\n'.format(
Rs=' '.join(sorted(relationships))))
prt.write("up: is_a and {Rs}\n".format(Rs=" ".join(sorted(relationships))))
return get_id2upperselect(terms, relationships)


def get_go2descendants(terms, relationships, prt=None):
"""Get GO-to- descendants"""
if not relationships:
if prt is not None:
prt.write('down: is_a\n')
prt.write("down: is_a\n")
return get_id2children(terms)
if relationships == RELATIONSHIP_SET or relationships is True:
if prt is not None:
prt.write('down: is_a and {Rs}\n'.format(
Rs=' '.join(sorted(RELATIONSHIP_SET))))
prt.write(
"down: is_a and {Rs}\n".format(Rs=" ".join(sorted(RELATIONSHIP_SET)))
)
return get_id2lower(terms)
if prt is not None:
prt.write('down: is_a and {Rs}\n'.format(
Rs=' '.join(sorted(relationships))))
prt.write("down: is_a and {Rs}\n".format(Rs=" ".join(sorted(relationships))))
return get_id2lowerselect(terms, relationships)

# ------------------------------------------------------------------------------------

def get_go2depth(goobjs, relationships):
"""Get depth of each object"""
if not relationships:
return {o.item_id:o.depth for o in goobjs}
return {o.item_id: o.depth for o in goobjs}
from goatools.godag.reldepth import get_go2reldepth

return get_go2reldepth(goobjs, relationships)

# ------------------------------------------------------------------------------------

def get_id2parents(objs):
"""Get all parent IDs up the hierarchy"""
id2parents = {}
for obj in objs:
_get_id2parents(id2parents, obj.item_id, obj)
return {e:es for e, es in id2parents.items() if es}
return {e: es for e, es in id2parents.items() if es}


def get_id2children(objs):
"""Get all child IDs down the hierarchy"""
id2children = {}
for obj in objs:
_get_id2children(id2children, obj.item_id, obj)
return {e:es for e, es in id2children.items() if es}
return {e: es for e, es in id2children.items() if es}


def get_id2upper(objs):
"""Get all ancestor IDs, including all parents and IDs up all relationships"""
id2upper = {}
for obj in objs:
_get_id2upper(id2upper, obj.item_id, obj)
return {e:es for e, es in id2upper.items() if es}
return {e: es for e, es in id2upper.items() if es}


def get_id2lower(objs):
"""Get all descendant IDs, including all children and IDs down all relationships"""
id2lower = {}
cache = set()
for obj in objs:
_get_id2lower(id2lower, obj.item_id, obj)
return {e:es for e, es in id2lower.items() if es}
item_id = obj.item_id
if item_id in cache:
continue
_get_id2lower(id2lower, obj.item_id, obj, cache)
return {e: es for e, es in id2lower.items() if es}


def get_id2upperselect(objs, relationship_set):
"""Get all ancestor IDs, including all parents and IDs up selected relationships"""
return IdToUpperSelect(objs, relationship_set).id2upperselect


def get_id2lowerselect(objs, relationship_set):
"""Get all descendant IDs, including all children and IDs down selected relationships"""
return IdToLowerSelect(objs, relationship_set).id2lowerselect


def get_relationship_targets(item_ids, relationships, id2rec):
"""Get item ID set of item IDs in a relationship target set"""
# Requirements to use this function:
Expand All @@ -148,7 +169,7 @@ def get_relationship_targets(item_ids, relationships, id2rec):
reltgt_objs_all.update(reltgt_objs_cur)
return reltgt_objs_all

# ------------------------------------------------------------------------------------

# pylint: disable=too-few-public-methods
class IdToUpperSelect:
"""Get all ancestor IDs, including all parents and IDs up selected relationships"""
Expand Down Expand Up @@ -178,6 +199,7 @@ def _get_id2upperselect(self, item_id, item_obj):
id2upperselect[item_id] = parent_ids
return parent_ids


class IdToLowerSelect:
"""Get all descendant IDs, including all children and IDs down selected relationships"""

Expand Down Expand Up @@ -206,7 +228,6 @@ def _get_id2lowerselect(self, item_id, item_obj):
id2lowerselect[item_id] = child_ids
return child_ids

# ------------------------------------------------------------------------------------

def _get_id2parents(id2parents, item_id, item_obj):
"""Add the parent item IDs for one item object and their parents."""
Expand All @@ -220,6 +241,7 @@ def _get_id2parents(id2parents, item_id, item_obj):
id2parents[item_id] = parent_ids
return parent_ids


def _get_id2children(id2children, item_id, item_obj):
"""Add the child item IDs for one item object and their children."""
if item_id in id2children:
Expand All @@ -232,6 +254,7 @@ def _get_id2children(id2children, item_id, item_obj):
id2children[item_id] = child_ids
return child_ids


def _get_id2upper(id2upper, item_id, item_obj):
"""Add the parent item IDs for one item object and their upper."""
if item_id in id2upper:
Expand All @@ -244,19 +267,23 @@ def _get_id2upper(id2upper, item_id, item_obj):
id2upper[item_id] = upper_ids
return upper_ids

def _get_id2lower(id2lower, item_id, item_obj):

def _get_id2lower(id2lower, item_id, item_obj, cache: set):
"""Add the lower item IDs for one item object and the objects below them."""
if item_id in id2lower:
return id2lower[item_id]
lower_ids = set()
cache.add(item_id)
for lower_obj in item_obj.get_goterms_lower():
lower_id = lower_obj.item_id
lower_ids.add(lower_id)
lower_ids |= _get_id2lower(id2lower, lower_id, lower_obj)
if lower_id in cache:
continue
lower_ids |= _get_id2lower(id2lower, lower_id, lower_obj, cache)
id2lower[item_id] = lower_ids
return lower_ids

# ------------------------------------------------------------------------------------

class CurNHigher:
"""Fill id2obj with item IDs in relationships."""

Expand Down
36 changes: 27 additions & 9 deletions goatools/nt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import datetime
import collections as cx


def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""):
"""Return a new dict of namedtuples by combining "dicts" of namedtuples or objects."""
assert len(ids) == len(set(ids)), "NOT ALL IDs ARE UNIQUE: {IDs}".format(IDs=ids)
assert len(flds) == len(set(flds)), "DUPLICATE FIELDS: {IDs}".format(
IDs=cx.Counter(flds).most_common())
IDs=cx.Counter(flds).most_common()
)
usr_id_nt = []
# 1. Instantiate namedtuple object
ntobj = cx.namedtuple("Nt", " ".join(flds))
Expand All @@ -23,6 +25,7 @@ def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""):
usr_id_nt.append((item_id, ntobj._make(vals)))
return cx.OrderedDict(usr_id_nt)


def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""):
"""Return a new list of namedtuples by combining "dicts" of namedtuples or objects."""
combined_nt_list = []
Expand All @@ -36,48 +39,61 @@ def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""):
combined_nt_list.append(ntobj._make(vals))
return combined_nt_list


def combine_nt_lists(lists, flds, dflt_null=""):
"""Return a new list of namedtuples by zipping "lists" of namedtuples or objects."""
combined_nt_list = []
# Check that all lists are the same length
lens = [len(lst) for lst in lists]
assert len(set(lens)) == 1, \
"LIST LENGTHS MUST BE EQUAL: {Ls}".format(Ls=" ".join(str(l) for l in lens))
assert len(set(lens)) == 1, "LIST LENGTHS MUST BE EQUAL: {Ls}".format(
Ls=" ".join(str(l) for l in lens)
)
# 1. Instantiate namedtuple object
ntobj = cx.namedtuple("Nt", " ".join(flds))
# 2. Loop through zipped list
for lst0_lstn in zip(*lists):
# 2a. Combine various namedtuples into a single namedtuple
combined_nt_list.append(ntobj._make(_combine_nt_vals(lst0_lstn, flds, dflt_null)))
combined_nt_list.append(
ntobj._make(_combine_nt_vals(lst0_lstn, flds, dflt_null))
)
return combined_nt_list


def wr_py_nts(fout_py, nts, docstring=None, varname="nts"):
"""Save namedtuples into a Python module."""
if nts:
with open(fout_py, 'w') as prt:
with open(fout_py, "w") as prt:
prt.write('"""{DOCSTRING}"""\n\n'.format(DOCSTRING=docstring))
prt.write("# Created: {DATE}\n".format(DATE=str(datetime.date.today())))
prt_nts(prt, nts, varname)
sys.stdout.write(" {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py))
sys.stdout.write(
" {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py)
)

def prt_nts(prt, nts, varname, spc=' '):

def prt_nts(prt, nts, varname, spc=" "):
"""Print namedtuples into a Python module."""
first_nt = nts[0]
nt_name = type(first_nt).__name__
prt.write("import collections as cx\n\n")
prt.write("import numpy as np\n\n")
prt.write("NT_FIELDS = [\n")
for fld in first_nt._fields:
prt.write('{SPC}"{F}",\n'.format(SPC=spc, F=fld))
prt.write("]\n\n")
prt.write('{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format(
NtName=nt_name))
prt.write(
'{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format(
NtName=nt_name
)
)
prt.write("# {N:,} items\n".format(N=len(nts)))
prt.write("# pylint: disable=line-too-long\n")
prt.write("{VARNAME} = [\n".format(VARNAME=varname))
for ntup in nts:
prt.write("{SPC}{NT},\n".format(SPC=spc, NT=ntup))
prt.write("]\n")


def get_unique_fields(fld_lists):
"""Get unique namedtuple fields, despite potential duplicates in lists of fields."""
flds = []
Expand All @@ -93,6 +109,7 @@ def get_unique_fields(fld_lists):
assert len(flds) == len(fld_set)
return flds


# -- Internal methods ----------------------------------------------------------------
def _combine_nt_vals(lst0_lstn, flds, dflt_null):
"""Given a list of lists of nts, return a single namedtuple."""
Expand All @@ -110,4 +127,5 @@ def _combine_nt_vals(lst0_lstn, flds, dflt_null):
vals.append(dflt_null)
return vals


# Copyright (C) 2016-2018, DV Klopfenstein, H Tang. All rights reserved.
3 changes: 3 additions & 0 deletions tests/test_dcnt_r01.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sys
import timeit
import numpy as np
import pytest

from numpy.random import shuffle
from scipy import stats

Expand All @@ -14,6 +16,7 @@
from goatools.obo_parser import GODag


@pytest.mark.skip(reason="Latest obo (`releases/2024-06-10`) is not DAG")
def test_go_pools():
"""Print a comparison of GO terms from different species in two different comparisons."""
objr = _Run()
Expand Down

0 comments on commit d0ba53e

Please sign in to comment.