Skip to content

Commit

Permalink
Add type annotations for the object lists.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 567287094
  • Loading branch information
mjanusz authored and copybara-github committed Sep 21, 2023
1 parent 2e6d981 commit 337d4a9
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions ffn/utils/proofreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@
import copy
import itertools
import threading
from typing import Iterable, Optional

import networkx as nx
import neuroglancer


Point = tuple[int, int, int]
ObjectItem = int | Iterable[int] | dict[str, Iterable[int]]


class Base:
"""Base class for proofreading workflows.
Expand All @@ -34,12 +39,19 @@ class Base:
The segmentation volume needs to be called `seg`.
"""

def __init__(self, num_to_prefetch: int = 10, locations=None, objects=None):
def __init__(
self,
num_to_prefetch: int = 10,
locations: Optional[Iterable[Point]] = None,
objects: Optional[Iterable[ObjectItem]] = None,
):
self.viewer = neuroglancer.Viewer()
self.num_to_prefetch = num_to_prefetch

self.managed_layers = set(['seg'])
self.todo = [] # items are maps from layer name to lists of segment IDs
self.todo: list[dict[str, list[int]]] = (
[]
) # items are maps from layer name to lists of segment IDs
if objects is not None:
self._set_todo(objects)

Expand All @@ -49,19 +61,19 @@ def __init__(self, num_to_prefetch: int = 10, locations=None, objects=None):

if locations is not None:
self.locations = list(locations)
assert len(self.todo) == len(locations)
assert len(self.todo) == len(self.locations)
else:
self.locations = None

self.set_init_state()

def _set_todo(self, objects):
def _set_todo(self, objects: Iterable[ObjectItem]):
for o in objects:
if isinstance(o, collections.abc.Mapping):
self.todo.append(o)
self.todo.append({k: list(v) for k, v in o.items()})
self.managed_layers |= set(o.keys())
elif isinstance(o, collections.abc.Iterable):
self.todo.append({'seg': o})
self.todo.append({'seg': list(o)})
else:
self.todo.append({'seg': [o]})

Expand Down Expand Up @@ -239,7 +251,13 @@ def mark_removed_bad(self):
class ObjectClassification(Base):
"""Base class for object classification."""

def __init__(self, objects, key_to_class, num_to_prefetch=10, locations=None):
def __init__(
self,
objects: Iterable[ObjectItem],
key_to_class,
num_to_prefetch: int = 10,
locations=None,
):
"""Constructor.
Args:
Expand Down Expand Up @@ -309,7 +327,9 @@ class GraphUpdater(Base):
(according to the current state of the agglomeraton graph).
"""

def __init__(self, graph, objects, bad, num_to_prefetch=0):
def __init__(
self, graph, objects: Iterable[ObjectItem], bad, num_to_prefetch: int = 0
):
super().__init__(objects=objects, num_to_prefetch=num_to_prefetch)
self.graph = graph
self.split_objects = []
Expand Down

0 comments on commit 337d4a9

Please sign in to comment.