Skip to content

Commit

Permalink
Merge pull request #8 from tlambert03/ruff
Browse files Browse the repository at this point in the history
use ruff, fix mypy
  • Loading branch information
tlambert03 authored Nov 7, 2024
2 parents 732a77f + ef4a9fa commit c432f40
Show file tree
Hide file tree
Showing 14 changed files with 74 additions and 80 deletions.
17 changes: 0 additions & 17 deletions .github/workflows/black.yaml

This file was deleted.

18 changes: 0 additions & 18 deletions .github/workflows/mypy.yaml

This file was deleted.

13 changes: 10 additions & 3 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
name: Test

on:
push:
on: [push, pull_request]

jobs:
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Run pre-commit
run: pipx run pre-commit run --all-files

test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v2
Expand Down
12 changes: 7 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files

- repo: https://github.com/psf/black
rev: 23.1.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2
hooks:
- id: black
- id: ruff
args: [--fix, --unsafe-fixes]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.1
rev: v1.13.0
hooks:
- id: mypy
11 changes: 8 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
name = "spatial_graph"
description = "A spatial graph datastructure for python."
readme = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
]
Expand All @@ -18,13 +18,14 @@ authors = [
dynamic = ["version"]
dependencies = [
"witty @ git+https://github.com/funkelab/[email protected]",
"cheetah3"
"cheetah3",
"numpy",
]

[project.optional-dependencies]
dev = [
'pytest',
'black',
'ruff',
'mypy',
'pdoc',
'pre-commit',
Expand All @@ -38,3 +39,7 @@ repository = "https://github.com/funkelab/spatial_graph"
[tool.setuptools]
packages = ["spatial_graph"]
package-data = { "spatial_graph" = ["*.pyi"] }

[tool.ruff]
target-version = "py39"
src = ["spatial_graph"]
3 changes: 3 additions & 0 deletions spatial_graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
from .rtree import LineRTree
from .graph import Graph
from .spatial_graph import SpatialGraph


__all__ = ["PointRTree", "LineRTree", "Graph", "SpatialGraph"]
2 changes: 2 additions & 0 deletions spatial_graph/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .graph import Graph

__all__ = ["Graph"]
6 changes: 3 additions & 3 deletions spatial_graph/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, graph, nodes):
_ = len(nodes)
# if so, convert to ndarray
nodes = np.array(nodes, dtype=graph.node_dtype)
except Exception as e:
except Exception:
# must be a single node
for name in graph.node_attr_dtypes.keys():
super().__setattr__(
Expand Down Expand Up @@ -152,10 +152,10 @@ def __init__(self, graph, edges):
# edges should be an iteratable
try:
# does it have a length?
num_edges = len(edges)
len(edges)
# case 2 and 3
edges = np.array(edges, dtype=graph.node_dtype)
except Exception as e:
except Exception:
raise RuntimeError(f"Can not handle edges type {type(edges)}")

if isinstance(edges, np.ndarray):
Expand Down
3 changes: 3 additions & 0 deletions spatial_graph/rtree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from .point_rtree import PointRTree
from .line_rtree import LineRTree


__all__ = ["PointRTree", "LineRTree"]
8 changes: 4 additions & 4 deletions spatial_graph/rtree/line_rtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@


class LineRTree(RTree):
pyx_item_t_declaration = f"""
pyx_item_t_declaration = """
cdef struct item_t:
item_base_t u
item_base_t v
bool corner_mask[DIMS]
"""

c_item_t_declaration = f"""
typedef struct item_t {{
c_item_t_declaration = """
typedef struct item_t {
item_base_t u;
item_base_t v;
bool corner_mask[DIMS];
}} item_t;
} item_t;
"""

c_converter_functions = """
Expand Down
13 changes: 6 additions & 7 deletions spatial_graph/rtree/rtree.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import ClassVar
import witty
import numpy as np
from Cheetah.Template import Template
Expand Down Expand Up @@ -50,17 +51,17 @@ class RTree:
"""

# overwrite in subclasses for custom item_t structures
pyx_item_t_declaration = None
c_item_t_declaration = None
pyx_item_t_declaration: ClassVar[str] = ""
c_item_t_declaration: ClassVar[str] = ""

# overwrite in subclasses for custom converters
c_converter_functions = None
c_converter_functions: ClassVar[str] = ""

# overwrite in subclasses for custom item comparison code
c_equal_function = None
c_equal_function: ClassVar[str] = ""

# overwrite in subclasses for custom distance computation
c_distance_function = None
c_distance_function: ClassVar[str] = ""

def __new__(
cls,
Expand All @@ -69,8 +70,6 @@ def __new__(
dims,
):
item_dtype = DType(item_dtype)
item_is_array = item_dtype.is_array
item_size = item_dtype.size
coord_dtype = DType(coord_dtype)

############################################
Expand Down
3 changes: 1 addition & 2 deletions spatial_graph/spatial_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def remove_nodes(self, nodes):
edges = self.edges_by_nodes(nodes)
else:
edges = np.concatenate(
self.in_edges_by_nodes(nodes),
self.out_edges_by_nodes(nodes)
self.in_edges_by_nodes(nodes), self.out_edges_by_nodes(nodes)
)
positions_u = getattr(self.node_attrs[edges[:, 0]], self.position_attr)
positions_v = getattr(self.node_attrs[edges[:, 1]], self.position_attr)
Expand Down
43 changes: 26 additions & 17 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
@pytest.mark.parametrize("edge_attr_dtypes", edge_attr_dtypes)
@pytest.mark.parametrize("directed", [True, False])
def test_construction(node_dtype, node_attr_dtypes, edge_attr_dtypes, directed):
graph = sg.Graph(node_dtype, node_attr_dtypes, edge_attr_dtypes, directed)
sg.Graph(node_dtype, node_attr_dtypes, edge_attr_dtypes, directed)


@pytest.mark.parametrize("directed", [True, False])
Expand Down Expand Up @@ -119,12 +119,16 @@ def test_attribute_modification():
np.array([1, 2, 3, 4, 5], dtype="uint64"),
attr1=np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype="double"),
attr2=np.array([1, 2, 3, 4, 5], dtype="int"),
attr3=np.array([
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3]], dtype="float32")
attr3=np.array(
[
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3],
],
dtype="float32",
),
)

graph.add_edges(
Expand Down Expand Up @@ -165,10 +169,14 @@ def test_attribute_modification():
assert graph.node_attrs[3].attr2 == 60
assert graph.node_attrs[4].attr2 == 80

graph.node_attrs[[2, 3, 4]].attr3 += np.array([
[1, 1, 1],
[2, 2, 2],
[3, 3, 3]], dtype="float32")
graph.node_attrs[[2, 3, 4]].attr3 += np.array(
[
[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
],
dtype="float32",
)
np.testing.assert_array_almost_equal(graph.node_attrs[2].attr3, [1.3, 1.6, 1.9])
np.testing.assert_array_almost_equal(graph.node_attrs[3].attr3, [2.3, 2.6, 2.9])
np.testing.assert_array_almost_equal(graph.node_attrs[4].attr3, [3.3, 3.6, 3.9])
Expand All @@ -178,25 +186,26 @@ def test_attribute_modification():
graph.edge_attrs[[[1, 2], [5, 1]]].attr1,
[
[1, 2, 3, 4],
[3, 4, 5, 6]
]
[3, 4, 5, 6],
],
)
graph.edge_attrs[[[1, 2], [5, 1]]].attr1 = np.array(
[
[11, 22, 33, 44],
[30, 40, 50, 60]
[30, 40, 50, 60],
],
dtype="int"
dtype="int",
)
np.testing.assert_array_equal(
graph.edge_attrs[[[1, 2], [3, 4], [5, 1]]].attr1,
[
[11, 22, 33, 44],
[2, 3, 4, 5],
[30, 40, 50, 60]
]
[30, 40, 50, 60],
],
)


def test_missing_nodes_edges():
graph = sg.Graph(
"uint64", {"node_attr": "float32"}, {"edge_attr": "float32"}, directed=False
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def test_line_rtree_nearest():
)
np.testing.assert_almost_equal(distances[0], 0.5)

def test_line_rtree_delete():

def test_line_rtree_delete():
line_rtree = LineRTree("uint64[2]", "double", 2)

line_rtree.insert_lines(
Expand Down

0 comments on commit c432f40

Please sign in to comment.