Skip to content

Commit

Permalink
Format utils/proofreading with pyink.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 558616937
  • Loading branch information
mjanusz authored and copybara-github committed Aug 20, 2023
1 parent a37f069 commit 95da2dc
Showing 1 changed file with 33 additions and 22 deletions.
55 changes: 33 additions & 22 deletions ffn/utils/proofreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Base:
The segmentation volume needs to be called `seg`.
"""

def __init__(self, num_to_prefetch=10, locations=None, objects=None):
def __init__(self, num_to_prefetch: int = 10, locations=None, objects=None):
self.viewer = neuroglancer.Viewer()
self.num_to_prefetch = num_to_prefetch

Expand Down Expand Up @@ -81,7 +81,7 @@ def update_segments(self, segments, loc=None, layer='seg'):
l.equivalences.clear()
else:
l.equivalences.clear()
for a in self.todo[self.index:self.index + self.batch]:
for a in self.todo[self.index : self.index + self.batch]:
a = [aa[layer] for aa in a]
l.equivalences.union(*a)

Expand Down Expand Up @@ -120,7 +120,10 @@ def list_segments(self, index=None, layer='seg'):
return list(
set(
itertools.chain(
*[x[layer] for x in self.todo[index:index + self.batch]])))
*[x[layer] for x in self.todo[index : index + self.batch]]
)
)
)

def custom_msg(self):
return ''
Expand All @@ -133,8 +136,10 @@ def update_batch(self, update=True):

for layer in self.managed_layers:
self.update_segments(self.list_segments(layer=layer), loc, layer=layer)
self.update_msg('index:%d/%d batch:%d %s' %
(self.index, len(self.todo), self.batch, self.custom_msg()))
self.update_msg(
'index:%d/%d batch:%d %s'
% (self.index, len(self.todo), self.batch, self.custom_msg())
)

def prefetch(self):
prefetch_states = []
Expand All @@ -145,7 +150,8 @@ def prefetch(self):
prefetch_state = copy.deepcopy(self.viewer.state)
for layer in self.managed_layers:
prefetch_state.layers[layer].segments = self.list_segments(
idx, layer=layer)
idx, layer=layer
)
prefetch_state.layout = '3d'
if self.locations is not None:
prefetch_state.position = self.locations[idx]
Expand Down Expand Up @@ -180,15 +186,18 @@ def __init__(self, objects, bad, num_to_prefetch=10, locations=None):
the current object if batch == 1.
"""
super().__init__(
num_to_prefetch=num_to_prefetch, locations=locations, objects=objects)
num_to_prefetch=num_to_prefetch, locations=locations, objects=objects
)
self.bad = bad

self.viewer.actions.add('next-batch', lambda s: self.next_batch())
self.viewer.actions.add('prev-batch', lambda s: self.prev_batch())
self.viewer.actions.add('dec-batch', lambda s: self.batch_dec())
self.viewer.actions.add('inc-batch', lambda s: self.batch_inc())
self.viewer.actions.add('mark-bad', lambda s: self.mark_bad())
self.viewer.actions.add('mark-removed-bad', lambda s: self.mark_removed_bad())
self.viewer.actions.add(
'mark-removed-bad', lambda s: self.mark_removed_bad()
)
self.viewer.actions.add('toggle-equiv', lambda s: self.toggle_equiv())

with self.viewer.config_state.txn() as s:
Expand Down Expand Up @@ -216,15 +225,15 @@ def mark_bad(self):
else:
self.bad.add(frozenset(sids))

self.update_msg('marked bad: %r' % (sids, ))
self.update_msg('marked bad: %r' % (sids,))
self.next_batch()

def mark_removed_bad(self):
original = set(self.list_segments())
new_bad = original - set(self.viewer.state.layers['seg'].segments)
if new_bad:
self.bad |= new_bad
self.update_msg('marked bad: %r' % (new_bad, ))
self.update_msg('marked bad: %r' % (new_bad,))


class ObjectClassification(Base):
Expand All @@ -239,7 +248,8 @@ def __init__(self, objects, key_to_class, num_to_prefetch=10, locations=None):
num_to_prefetch: number of `objects` to prefetch
"""
super().__init__(
num_to_prefetch=num_to_prefetch, locations=locations, objects=objects)
num_to_prefetch=num_to_prefetch, locations=locations, objects=objects
)

self.results = defaultdict(set) # class -> ids

Expand All @@ -249,7 +259,8 @@ def __init__(self, objects, key_to_class, num_to_prefetch=10, locations=None):

for key, cls in key_to_class.items():
self.viewer.actions.add(
'classify-%s' % cls, lambda s, cls=cls: self.classify(cls))
'classify-%s' % cls, lambda s, cls=cls: self.classify(cls)
)

with self.viewer.config_state.txn() as s:
for key, cls in key_to_class.items():
Expand Down Expand Up @@ -332,7 +343,8 @@ def __init__(self, graph, objects, bad, num_to_prefetch=0):

with self.viewer.txn() as s:
s.layers['split'] = neuroglancer.SegmentationLayer(
source=s.layers['seg'].source)
source=s.layers['seg'].source
)
s.layers['split'].visible = False

def merge_segments(self):
Expand All @@ -341,7 +353,7 @@ def merge_segments(self):

def update_split(self):
s = copy.deepcopy(self.viewer.state)
s.layers['split'].segments = list(self.split_path)[:self.split_index]
s.layers['split'].segments = list(self.split_path)[: self.split_index]
self.viewer.set_state(s)

def inc_split(self):
Expand All @@ -363,7 +375,7 @@ def add_ccs(self):
self.sem.release()

def accept_split(self):
edge = self.split_path[self.split_index - 1:self.split_index + 1]
edge = self.split_path[self.split_index - 1 : self.split_index + 1]
if len(edge) < 2:
return

Expand All @@ -380,11 +392,11 @@ def clear_splits(self):
self.viewer.set_state(s)

def start_split(self):
self.split_path = nx.shortest_path(self.graph, self.split_objects[0],
self.split_objects[1])
self.split_path = nx.shortest_path(
self.graph, self.split_objects[0], self.split_objects[1]
)
self.split_index = 1
self.update_msg(
'splitting: %s' % ('-'.join(str(x) for x in self.split_path)))
self.update_msg('splitting: %s' % '-'.join(str(x) for x in self.split_path))

s = copy.deepcopy(self.viewer.state)
s.layers['seg'].visible = False
Expand All @@ -395,8 +407,7 @@ def start_split(self):
def add_split(self, s):
if len(self.split_objects) < 2:
self.split_objects.append(s.selected_values['seg'].value)
self.update_msg(
'split: %s' % (':'.join(str(x) for x in self.split_objects)))
self.update_msg('split: %s' % ':'.join(str(x) for x in self.split_objects))

if len(self.split_objects) == 2:
self.start_split()
Expand All @@ -412,5 +423,5 @@ def mark_bad(self):
else:
self.bad.add(frozenset(sids))

self.update_msg('marked bad: %r' % (sids, ))
self.update_msg('marked bad: %r' % (sids,))
self.next_batch()

0 comments on commit 95da2dc

Please sign in to comment.