diff --git a/graphtools/__init__.py b/graphtools/__init__.py
index 8fc8a50..7384afc 100644
--- a/graphtools/__init__.py
+++ b/graphtools/__init__.py
@@ -1,2 +1,2 @@
-from .api import Graph, from_igraph
+from .api import Graph, from_igraph, read_pickle
 from .version import __version__
diff --git a/graphtools/api.py b/graphtools/api.py
index 5d48316..30a6f26 100644
--- a/graphtools/api.py
+++ b/graphtools/api.py
@@ -2,6 +2,8 @@
 import warnings
 import tasklogger
 from scipy import sparse
+import pickle
+import pygsp
 
 from . import base
 from . import graphs
@@ -283,3 +285,32 @@ def from_igraph(G, attribute="weight", **kwargs):
         K = G.get_adjacency(attribute=None).data
     return Graph(sparse.coo_matrix(K),
                  precomputed='adjacency', **kwargs)
+
+
+def read_pickle(path):
+    """Load pickled Graphtools object (or any object) from file.
+
+    Only load files from trusted sources: unpickling can execute
+    arbitrary code.
+
+    Parameters
+    ----------
+    path : str
+        File path from which the pickled object will be loaded.
+
+    Returns
+    -------
+    G : object
+        The unpickled object. A warning is issued if it is not a
+        graphtools.base.BaseGraph.
+    """
+    with open(path, 'rb') as f:
+        G = pickle.load(f)
+
+    if not isinstance(G, base.BaseGraph):
+        warnings.warn(
+            'Returning object that is not a graphtools.base.BaseGraph')
+    elif isinstance(G, base.PyGSPGraph) and isinstance(G.logger, str):
+        # to_pickle stores the logger by name on Python < 3.7; rebuild it
+        G.logger = pygsp.utils.build_logger(G.logger)
+    return G
diff --git a/graphtools/base.py b/graphtools/base.py
index 2dc8736..d0e1fa4 100644
--- a/graphtools/base.py
+++ b/graphtools/base.py
@@ -3,14 +3,15 @@
 import numpy as np
 import abc
 import pygsp
-from sklearn.utils.fixes import signature
+from inspect import signature
 from sklearn.decomposition import PCA, TruncatedSVD
 from sklearn.preprocessing import normalize
-from sklearn.utils.graph import graph_shortest_path
 from scipy import sparse
 import warnings
 import numbers
 import tasklogger
+import pickle
+import sys
 
 try:
     import pandas as pd
@@ -106,10 +107,10 @@ class Data(Base):
     def __init__(self, data, n_pca=None, random_state=None, **kwargs):
 
         self._check_data(data)
-        if n_pca is not None and data.shape[1] <= n_pca:
+        if n_pca is not None and np.min(data.shape) <= n_pca:
             warnings.warn("Cannot perform PCA to {} dimensions on "
-                          "data with {} dimensions".format(n_pca,
-                                                           data.shape[1]),
+                          "data with min(n_samples, n_features) = {}".format(
+                              n_pca, np.min(data.shape)),
                           RuntimeWarning)
             n_pca = None
         try:
@@ -316,7 +317,7 @@ class BaseGraph(with_metaclass(abc.ABCMeta, Base)):
         'theta' : min-max
         'none' : no symmetrization
 
-    theta: float (default: 0.5)
+    theta: float (default: 1)
         Min-max symmetrization constant.
         K = `theta * min(K, K.T) + (1 - theta) * max(K, K.T)`
 
@@ -385,7 +386,7 @@ def _check_symmetrization(self, kernel_symm, theta):
             if theta is None:
                 warnings.warn("kernel_symm='theta' but theta not given. "
-                              "Defaulting to theta=0.5.")
-                self.theta = theta = 0.5
+                              "Defaulting to theta=1.")
+                self.theta = theta = 1
             elif not isinstance(theta, numbers.Number) or \
                     theta < 0 or theta > 1:
                 raise ValueError("theta {} not recognized. Expected "
@@ -636,6 +637,29 @@ def to_igraph(self, attribute="weight", **kwargs):
         return ig.Graph.Weighted_Adjacency(utils.to_dense(W).tolist(),
                                            attr=attribute, **kwargs)
 
+    def to_pickle(self, path):
+        """Save the current Graph to a pickle.
+
+        Parameters
+        ----------
+        path : str
+            File path where the pickled object will be stored.
+        """
+        # Loggers are only picklable from Python 3.7 on, so the PyGSP
+        # logger is swapped for its name while the graph is dumped.
+        swap_logger = (sys.version_info < (3, 7) and
+                       isinstance(self, pygsp.graphs.Graph))
+        if swap_logger:
+            logger = self.logger
+            self.logger = logger.name
+        try:
+            with open(path, 'wb') as f:
+                pickle.dump(self, f)
+        finally:
+            if swap_logger:
+                # restore the live logger even if open() or dump() failed
+                self.logger = logger
+
 
 class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph, Base)):
     """Interface between BaseGraph and PyGSP.
diff --git a/graphtools/graphs.py b/graphtools/graphs.py
index 298bc7e..157da60 100644
--- a/graphtools/graphs.py
+++ b/graphtools/graphs.py
@@ -14,8 +14,6 @@
 import tasklogger
 
 from .utils import (set_diagonal,
-                    elementwise_minimum,
-                    elementwise_maximum,
                     set_submatrix)
 from .base import DataGraph, PyGSPGraph
 
@@ -245,7 +243,7 @@ def _check_duplicates(self, distances, indices):
                     "Detected zero distance between {} pairs of samples. "
                     "Consider removing duplicates to avoid errors in "
                     "downstream processing.".format(
-                        np.sum(np.sum(distances[:, 1:]))),
+                        np.sum(np.sum(distances[:, 1:] == 0))),
                     RuntimeWarning)
 
     def build_kernel_to_data(self, Y, knn=None, bandwidth=None,
diff --git a/graphtools/version.py b/graphtools/version.py
index 5becc17..6849410 100644
--- a/graphtools/version.py
+++ b/graphtools/version.py
@@ -1 +1 @@
-__version__ = "1.0.0"
+__version__ = "1.1.0"
diff --git a/test/test_api.py b/test/test_api.py
index 5379cf5..e5ae0d8 100644
--- a/test/test_api.py
+++ b/test/test_api.py
@@ -9,6 +9,8 @@
 import igraph
 import numpy as np
 import graphtools
+import tempfile
+import os
 
 
 def test_from_igraph():
@@ -81,6 +83,56 @@ def test_to_igraph():
         attribute="weight").data) == G.W)
 
 
+def test_pickle_io_knngraph():
+    G = build_graph(data, knn=5, decay=None)
+    with tempfile.TemporaryDirectory() as tempdir:
+        path = os.path.join(tempdir, 'tmp.pkl')
+        G.to_pickle(path)
+        G_prime = graphtools.read_pickle(path)
+    assert isinstance(G_prime, type(G))
+
+
+def test_pickle_io_traditionalgraph():
+    G = build_graph(data, knn=5, decay=10, thresh=0)
+    with tempfile.TemporaryDirectory() as tempdir:
+        path = os.path.join(tempdir, 'tmp.pkl')
+        G.to_pickle(path)
+        G_prime = graphtools.read_pickle(path)
+    assert isinstance(G_prime, type(G))
+
+
+def test_pickle_io_landmarkgraph():
+    G = build_graph(data, knn=5, decay=None,
+                    n_landmark=data.shape[0] // 2)
+    L = G.landmark_op
+    with tempfile.TemporaryDirectory() as tempdir:
+        path = os.path.join(tempdir, 'tmp.pkl')
+        G.to_pickle(path)
+        G_prime = graphtools.read_pickle(path)
+    assert isinstance(G_prime, type(G))
+    np.testing.assert_array_equal(L, G_prime._landmark_op)
+
+
+def test_pickle_io_pygspgraph():
+    G = build_graph(data, knn=5, decay=None, use_pygsp=True)
+    with tempfile.TemporaryDirectory() as tempdir:
+        path = os.path.join(tempdir, 'tmp.pkl')
+        G.to_pickle(path)
+        G_prime = graphtools.read_pickle(path)
+    assert isinstance(G_prime, type(G))
+    assert G_prime.logger.name == G.logger.name
+
+
+@warns(UserWarning)
+def test_pickle_bad_pickle():
+    import pickle
+    with tempfile.TemporaryDirectory() as tempdir:
+        path = os.path.join(tempdir, 'tmp.pkl')
+        with open(path, 'wb') as f:
+            pickle.dump('hello world', f)
+        graphtools.read_pickle(path)
+
+
 @warns(UserWarning)
 def test_to_pygsp_invalid_precomputed():
     G = build_graph(data)
diff --git a/test/test_data.py b/test/test_data.py
index dfa0889..aed139c 100644
--- a/test/test_data.py
+++ b/test/test_data.py
@@ -44,6 +44,12 @@ def test_too_many_n_pca():
     build_graph(data, n_pca=data.shape[1])
 
 
+@warns(RuntimeWarning)
+def test_too_many_n_pca_for_n_samples():
+    build_graph(data[:data.shape[1] - 1],
+                n_pca=data.shape[1] - 1)
+
+
 @warns(RuntimeWarning)
 def test_precomputed_with_pca():
     build_graph(squareform(pdist(data)),