Skip to content

Commit

Permalink
soft error threshold, hopeful fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc committed Aug 7, 2024
1 parent 3283161 commit a4574fb
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 37 deletions.
62 changes: 37 additions & 25 deletions glue/sample/src/sinter/_collection/_collection_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import collections
import contextlib
import math
import multiprocessing
import os
import pathlib
import queue
import sys
import tempfile
from typing import Any, Optional, List, Dict, Iterable, Callable, Tuple
from typing import Union
Expand All @@ -23,26 +23,26 @@ def __init__(self, worker_id: int, *, cpu_pin: Optional[int] = None):
self.process: Optional[multiprocessing.Process] = None
self.input_queue: Optional[multiprocessing.Queue[Tuple[str, Any]]] = None
self.assigned_work_key: Any = None
self.assigned_shots: int = 0
self.assigned_shots_remote: int = 0
self.asked_to_drop_shots: int = 0
self.cpu_pin = cpu_pin

def send_message(self, message: Any):
self.input_queue.put(message)

def ask_to_return_all_shots(self):
if self.asked_to_drop_shots == 0 and self.assigned_shots > 0:
if self.asked_to_drop_shots == 0 and self.assigned_shots_remote > 0:
self.send_message((
'return_shots',
(
self.assigned_work_key,
self.assigned_shots - self.asked_to_drop_shots,
self.assigned_shots_remote,
),
))
self.asked_to_drop_shots = self.assigned_shots
self.asked_to_drop_shots = self.assigned_shots_remote

def has_returned_all_shots(self) -> bool:
return self.assigned_shots == 0 and self.asked_to_drop_shots == 0
return self.assigned_shots_remote == 0 and self.asked_to_drop_shots == 0

def is_available_to_reassign(self) -> bool:
return self.assigned_work_key is None
Expand All @@ -56,7 +56,8 @@ def __init__(self, *, partial_task: Task, strong_id: str, shots_left: int, error
self.errors_left = errors_left
self.shots_unassigned = shots_left
self.shot_return_requests = 0
self.workers_assigned = []
self.assigned_soft_error_flush_threshold: int = errors_left
self.workers_assigned: list[int] = []

def is_completed(self) -> bool:
return self.shots_left <= 0 or self.errors_left <= 0
Expand Down Expand Up @@ -259,7 +260,7 @@ def _handle_task_progress(self, task_id: Any):
del self.task_states[task_id]
for worker_id in task_state.workers_assigned:
w = self.worker_states[worker_id]
assert w.assigned_shots == 0
assert w.assigned_shots_remote == 0
assert w.asked_to_drop_shots == 0
w.assigned_work_key = None
self._distribute_work()
Expand All @@ -277,7 +278,7 @@ def state_summary(self) -> str:
for worker_id, worker in enumerate(self.worker_states):
lines.append(f'worker {worker_id}:'
f' asked_to_drop_shots={worker.asked_to_drop_shots}'
f' assigned_shots={worker.assigned_shots}'
f' assigned_shots_remote={worker.assigned_shots_remote}'
f' assigned_work_key={worker.assigned_work_key}')
for task in self.task_states.values():
lines.append(f'task {task.strong_id=}:\n'
Expand All @@ -302,11 +303,8 @@ def process_message(self) -> bool:
assert isinstance(anon_stat, AnonTaskStats)
assert worker_state.assigned_work_key == task_strong_id
task_state = self.task_states[task_strong_id]
worker_state.assigned_shots -= anon_stat.shots
if worker_state.assigned_shots < 0:
# Overachieving sampler did extra shots.
task_state.shots_unassigned += worker_state.assigned_shots
worker_state.assigned_shots -= worker_state.assigned_shots

worker_state.assigned_shots_remote -= anon_stat.shots
task_state.shots_left -= anon_stat.shots
if self.custom_error_count_key is None:
task_state.errors_left -= anon_stat.errors
Expand Down Expand Up @@ -348,8 +346,9 @@ def process_message(self) -> bool:
worker_state.asked_to_drop_shots = 0
worker_state.asked_to_drop_errors = 0
task_state.shots_unassigned += shots_returned
worker_state.assigned_shots -= shots_returned
assert worker_state.assigned_shots >= 0
worker_state.assigned_shots_remote -= shots_returned
if worker_state.assigned_shots_remote < 0:
worker_state.assigned_shots_remote = 0
self._handle_task_progress(task_key)

elif message_type == 'stopped_due_to_exception':
Expand Down Expand Up @@ -405,7 +404,7 @@ def _distribute_unassigned_workers_to_jobs(self):
worker_state.assigned_work_key = task_state.strong_id
worker_state.send_message((
'change_job',
(task_state.partial_task, CollectionOptions(max_errors=task_state.errors_left)),
(task_state.partial_task, CollectionOptions(max_errors=task_state.errors_left), task_state.assigned_soft_error_flush_threshold),
))

def _distribute_unassigned_work_to_workers_within_a_job(self, task_state: _ManagedTaskState):
Expand All @@ -416,14 +415,14 @@ def _distribute_unassigned_work_to_workers_within_a_job(self, task_state: _Manag
expected_shots_per_worker = (task_state.shots_left + num_task_workers - 1) // num_task_workers

# Give unassigned shots to idle workers.
for worker_id in sorted(task_state.workers_assigned, key=lambda wid: self.worker_states[wid].assigned_shots):
for worker_id in sorted(task_state.workers_assigned, key=lambda wid: self.worker_states[wid].assigned_shots_remote):
worker_state = self.worker_states[worker_id]
if worker_state.assigned_shots < expected_shots_per_worker:
shots_to_assign = min(expected_shots_per_worker - worker_state.assigned_shots,
if worker_state.assigned_shots_remote < expected_shots_per_worker:
shots_to_assign = min(expected_shots_per_worker - worker_state.assigned_shots_remote,
task_state.shots_unassigned)
if shots_to_assign > 0:
task_state.shots_unassigned -= shots_to_assign
worker_state.assigned_shots += shots_to_assign
worker_state.assigned_shots_remote += shots_to_assign
worker_state.send_message((
'accept_shots',
(task_state.strong_id, shots_to_assign),
Expand Down Expand Up @@ -502,22 +501,35 @@ def status_message(self) -> str:
lines.append(' ... (' + str(len(skipped_lines)) + ' more tasks) ...')
return f'{tasks_left} tasks left:\n' + '\n'.join(lines)

def _update_soft_error_threshold_for_a_job(self, task_state: _ManagedTaskState):
if task_state.errors_left <= len(task_state.workers_assigned):
desired_threshold = 1
elif task_state.errors_left <= task_state.assigned_soft_error_flush_threshold * self.num_workers:
desired_threshold = max(1, math.ceil(task_state.errors_left * 0.5 / self.num_workers))
else:
return

if task_state.assigned_soft_error_flush_threshold != desired_threshold:
task_state.assigned_soft_error_flush_threshold = desired_threshold
for wid in task_state.workers_assigned:
self.worker_states[wid].send_message(('set_soft_error_flush_threshold', desired_threshold))

def _take_work_if_unsatisfied_workers_within_a_job(self, task_state: _ManagedTaskState):
if not self.started or not task_state.workers_assigned or task_state.shots_left <= 0:
return

if all(self.worker_states[w].assigned_shots for w in task_state.workers_assigned):
if all(self.worker_states[w].assigned_shots_remote for w in task_state.workers_assigned):
return

w = len(task_state.workers_assigned)
expected_shots_per_worker = (task_state.shots_left + w - 1) // w

# There are idle workers that couldn't be given any shots. Take shots from other workers.
for worker_id in sorted(task_state.workers_assigned, key=lambda w: self.worker_states[w].assigned_shots, reverse=True):
for worker_id in sorted(task_state.workers_assigned, key=lambda w: self.worker_states[w].assigned_shots_remote, reverse=True):
worker_state = self.worker_states[worker_id]
if worker_state.asked_to_drop_shots or worker_state.assigned_shots <= expected_shots_per_worker:
if worker_state.asked_to_drop_shots or worker_state.assigned_shots_remote <= expected_shots_per_worker:
continue
shots_to_take = worker_state.assigned_shots - expected_shots_per_worker
shots_to_take = worker_state.assigned_shots_remote - expected_shots_per_worker
assert shots_to_take > 0
worker_state.asked_to_drop_shots = shots_to_take
task_state.shot_return_requests += 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import os
import pathlib
from typing import Optional, TYPE_CHECKING
from typing import Union

from sinter._decoding import Decoder, Sampler
from sinter._decoding import Sampler
from sinter._collection._collection_worker_state import CollectionWorkerState

if TYPE_CHECKING:
Expand Down
33 changes: 24 additions & 9 deletions glue/sample/src/sinter/_collection/_collection_worker_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,11 @@ def __init__(
self.current_task_shots_left: int = 0
self.unflushed_results: AnonTaskStats = AnonTaskStats()
self.last_flush_message_time = time.monotonic()
self.soft_error_flush_threshold: int = 1

def flush_results(self):
if self.unflushed_results.shots > 0:
self.last_flush_message_time = time.monotonic()
if self.current_task_shots_left < 0:
self.current_task_shots_left = 0
self.out.put((
'flushed_results',
self.worker_id,
Expand All @@ -98,6 +97,8 @@ def return_shots(self, *, requested_shots: int):
self.current_task_shots_left -= returned_shots
if self.current_task_shots_left <= 0:
self.flush_results()
if self.current_task_shots_left < 0:
self.current_task_shots_left = 0
self.out.put((
'returned_shots',
self.worker_id,
Expand Down Expand Up @@ -150,11 +151,16 @@ def process_messages(self) -> int:
self.compute_strong_id(new_task=message_body)

elif message_type == 'change_job':
new_task, new_collection_options, soft_error_flush_threshold = message_body
self.cur_flush_period = 0.01
new_task, new_collection_options = message_body
self.soft_error_flush_threshold = soft_error_flush_threshold
assert isinstance(new_task, Task)
self.change_job(new_task=new_task, new_collection_options=new_collection_options)

elif message_type == 'set_soft_error_flush_threshold':
soft_error_flush_threshold = message_body
self.soft_error_flush_threshold = soft_error_flush_threshold

elif message_type == 'accept_shots':
job_key, shots_delta = message_body
assert isinstance(shots_delta, int)
Expand All @@ -170,9 +176,15 @@ def process_messages(self) -> int:
else:
raise NotImplementedError(f'{message_type=}')

def num_unflushed_errors(self) -> int:
if self.custom_error_count_key is not None:
return self.unflushed_results.custom_counts[self.custom_error_count_key]
return self.unflushed_results.errors

def do_some_work(self) -> bool:
did_some_work = False

# Sample some stats.
if self.current_task_shots_left > 0:
# Don't keep sampling if we've exceeded the number of errors needed.
if self.current_error_cutoff is not None and self.current_error_cutoff <= 0:
Expand All @@ -184,17 +196,20 @@ def do_some_work(self) -> bool:
assert isinstance(some_work_done, AnonTaskStats)
self.current_task_shots_left -= some_work_done.shots
if self.current_error_cutoff is not None:
if self.custom_error_count_key is not None:
self.current_error_cutoff -= some_work_done.custom_counts[self.custom_error_count_key]
else:
self.current_error_cutoff -= some_work_done.errors
self.current_error_cutoff -= self.num_unflushed_errors()
self.unflushed_results += some_work_done
did_some_work = True

# Report them periodically.
should_flush = False
if self.num_unflushed_errors() >= self.soft_error_flush_threshold:
should_flush = True
if self.unflushed_results.shots > 0:
if self.current_task_shots_left <= 0 or self.last_flush_message_time + self.cur_flush_period < time.monotonic():
self.cur_flush_period = min(self.cur_flush_period * 1.4, self.max_flush_period)
did_some_work |= self.flush_results()
should_flush = True
if should_flush:
self.cur_flush_period = min(self.cur_flush_period * 1.4, self.max_flush_period)
did_some_work |= self.flush_results()

return did_some_work

Expand Down
1 change: 1 addition & 0 deletions glue/sample/src/sinter/_plotting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_better_sorted_str_terms():
assert f('a1.5.3b2') == ('a', (1, 5, 3), 'b', 2)
assert f(1) < f(None)
assert f(1) < f('2')
assert f('2') > f(1)
assert sorted([
"planar d=10 r=30",
"planar d=16 r=36",
Expand Down

0 comments on commit a4574fb

Please sign in to comment.