Skip to content

Commit

Permalink
Add type annotations for the object lists.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 567287094
  • Loading branch information
mjanusz authored and copybara-github committed Sep 21, 2023
1 parent 2e6d981 commit 82353df
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions ffn/utils/proofreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@
import copy
import itertools
import threading
from typing import Iterable, Mapping, Optional

import networkx as nx
import neuroglancer


Point = tuple[int, int, int]
ObjectItem = int | Iterable[int] | Mapping[str, Iterable[int]]


class Base:
"""Base class for proofreading workflows.
Expand All @@ -34,12 +39,19 @@ class Base:
The segmentation volume needs to be called `seg`.
"""

def __init__(self, num_to_prefetch: int = 10, locations=None, objects=None):
def __init__(
self,
num_to_prefetch: int = 10,
locations: Optional[Iterable[Point]] = None,
objects: Optional[Iterable[ObjectItem]] = None,
):
self.viewer = neuroglancer.Viewer()
self.num_to_prefetch = num_to_prefetch

self.managed_layers = set(['seg'])
self.todo = [] # items are maps from layer name to lists of segment IDs
self.todo: list[dict[str, list[int]]] = (
[]
) # items are maps from layer name to lists of segment IDs
if objects is not None:
self._set_todo(objects)

Expand All @@ -55,7 +67,7 @@ def __init__(self, num_to_prefetch: int = 10, locations=None, objects=None):

self.set_init_state()

def _set_todo(self, objects):
def _set_todo(self, objects: Iterable[ObjectItem]):
for o in objects:
if isinstance(o, collections.abc.Mapping):
self.todo.append(o)
Expand Down Expand Up @@ -239,7 +251,13 @@ def mark_removed_bad(self):
class ObjectClassification(Base):
"""Base class for object classification."""

def __init__(self, objects, key_to_class, num_to_prefetch=10, locations=None):
def __init__(
self,
objects: Iterable[ObjectItem],
key_to_class,
num_to_prefetch: int = 10,
locations=None,
):
"""Constructor.
Args:
Expand Down Expand Up @@ -309,7 +327,9 @@ class GraphUpdater(Base):
(according to the current state of the agglomeraton graph).
"""

def __init__(self, graph, objects, bad, num_to_prefetch=0):
def __init__(
self, graph, objects: Iterable[ObjectItem], bad, num_to_prefetch: int = 0
):
super().__init__(objects=objects, num_to_prefetch=num_to_prefetch)
self.graph = graph
self.split_objects = []
Expand Down

0 comments on commit 82353df

Please sign in to comment.