Commit 7470d27

Refactor registry.

knighton committed Jan 22, 2024
1 parent db3041f commit 7470d27
Showing 4 changed files with 77 additions and 125 deletions.
4 changes: 2 additions & 2 deletions streaming/base/coord/job/dir.py
@@ -52,12 +52,12 @@ def manual_unregister(self) -> None:
         This job must be registered when this is called.
         """
-        self.registry.ensure_unregistered(self.job_hash, self.world)
+        self.registry.unregister(self.job_hash, self.world, True)
 
     def __del__(self) -> None:
         """Destructor.
 
         You may unregister the job explicitly ahead of time (to ensure it happens synchronously
         instead of eventually).
         """
-        self.registry.ensure_unregistered(self.job_hash, self.world)
+        self.registry.unregister(self.job_hash, self.world, False)
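The two call sites differ only in the new `strict` flag: `manual_unregister` passes `True`, so unregistering a job that is not registered raises, while `__del__` passes `False`, so best-effort cleanup during interpreter teardown cannot throw. A minimal sketch of this explicit-close-plus-forgiving-destructor pattern (the `Resource` class and `close` method are illustrative, not from the commit):

```py
class Resource:
    """Explicit, strict cleanup paired with a forgiving destructor."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        # Strict path: closing twice is a caller bug, so raise loudly.
        if self.closed:
            raise ValueError('Already closed.')
        self.closed = True

    def __del__(self) -> None:
        # Non-strict path: the destructor may run arbitrarily late, possibly
        # during interpreter shutdown, so it tolerates an already-closed state.
        if not self.closed:
            self.close()
```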
6 changes: 5 additions & 1 deletion streaming/base/coord/job/file.py
@@ -32,7 +32,11 @@ def __init__(self, jobs: List[JobEntry]) -> None:
     @classmethod
     def read(cls, filename: str) -> Self:
         if os.path.exists(filename):
-            obj = json.load(open(filename))
+            try:
+                obj = json.load(open(filename))
+            except (ValueError, OSError):
+                os.remove(filename)
+                obj = {}
         else:
             obj = {}
         jobs = obj.get('jobs') or []
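The new `try`/`except` treats an unreadable registry file as corrupt: rather than crashing every later job, it deletes the file and falls back to an empty database. A self-contained sketch of the same recover-by-deleting pattern, using a context manager for the file handle (the `read_state` name is illustrative, not from the repo):

```py
import json
import os


def read_state(filename: str) -> dict:
    """Load a JSON state file, discarding it if it is corrupt or unreadable."""
    if not os.path.exists(filename):
        return {}
    try:
        with open(filename) as file:
            return json.load(file)
    except (ValueError, OSError):
        # A partial or interrupted write left garbage behind: drop the file
        # and start over with empty state.
        os.remove(filename)
        return {}
```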
190 changes: 68 additions & 122 deletions streaming/base/coord/job/registry.py
@@ -6,7 +6,6 @@
 Useful for detecting collisions between different jobs' local dirs.
 """
 
-import gc
 import os
 from hashlib import sha3_224
 from shutil import rmtree
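`sha3_224` is what `_hash_streams` (next hunk) uses to derive the job hash from the dataset's stream locals. A sketch of how such a digest might be computed; the diff does not show `_hash_streams`'s body, so the serialization below is an assumption:

```py
from hashlib import sha3_224
from typing import List


def hash_locals(stream_locals: List[str]) -> str:
    """Derive a stable hex digest from a list of stream local paths."""
    text = ','.join(sorted(stream_locals))  # Assumed serialization, not the repo's.
    return sha3_224(text.encode('utf-8')).hexdigest()


print(hash_locals(['/tmp/cache/b', '/tmp/cache/a']))
```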
@@ -130,7 +129,7 @@ def _hash_streams(self, streams: Sequence[Stream]) -> Tuple[List[str], List[str], str]:
 
         return stream_locals, stream_hashes, job_hash
 
-    def _make_dir(self, job_hash: str) -> None:
+    def _make_job_dir(self, job_hash: str) -> None:
         """Create a Streaming job config dir.
 
         Args:
@@ -139,7 +138,7 @@ def _make_dir(self, job_hash: str) -> None:
         dirname = os.path.join(self.config_root, job_hash)
         os.makedirs(dirname)
 
-    def _remove_dir(self, job_hash: str) -> None:
+    def _remove_job_dir(self, job_hash: str) -> None:
         """Delete a Streaming job config dir.
 
         Args:
@@ -148,101 +147,60 @@ def _remove_dir(self, job_hash: str) -> None:
         dirname = os.path.join(self.config_root, job_hash)
         rmtree(dirname)
 
-    def _do_register(self, streams: Sequence[Stream]) -> str:
-        """Register this collection of StreamingDataset replicas.
+    def register(self, streams: Sequence[Stream], world: World) -> str:
+        """Register or look up this collection of StreamingDataset replicas.
 
-        Called by the local leader.
+        Called by all ranks.
 
         Args:
             streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in
                 combination with process IDs and creation times lets us uniquely identify a
                 Streaming job.
+            world (World): Rank-wise world state.
 
         Returns:
-            str: Streaming config subdir for this job.
+            str: Subdir for this collection of StreamingDataset replicas.
         """
-        register_time = time_ns()
-        pid2create_time = self._get_live_procs()
-        pid = os.getpid()
-        create_time = pid2create_time.get(pid)
-        if create_time is None:
-            raise RuntimeError('`psutil` thinks we are dead, and yet here we are: pid = {pid}.')
+        if not world.is_local_leader:
+            _, _, job_hash = self._hash_streams(streams)
+            dirname = os.path.join(self.config_root, job_hash)
+            wait_for_creation(dirname, self.timeout, self.tick, self.lock)
+            return job_hash
 
         # Collect our stream locals and hash them, resulting in a job hash.
         stream_locals, stream_hashes, job_hash = self._hash_streams(streams)
 
-        entry = JobEntry(job_hash=job_hash,
-                         stream_hashes=stream_hashes,
-                         stream_locals=stream_locals,
-                         process_id=pid,
-                         register_time=register_time)
-
         with self.lock:
-            conf = RegistryFile.read(self.registry_filename)
-            conf.add(entry)
-            del_job_hashes = conf.filter(pid2create_time)
-            conf.write(self.registry_filename)
-            map(self._remove_dir, del_job_hashes)
-            self._make_dir(job_hash)
-
-        return job_hash
-
-    def _lookup(self, streams: Sequence[Stream]) -> str:
-        """Look up this collection of StreamingDataset replicas.
-
-        Called by the local leader.
-
-        Args:
-            streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in
-                combination with process IDs and creation times lets us uniquely identify a
-                Streaming job.
-
-        Returns:
-            str: Streaming config subdir for this job.
-        """
-        _, _, job_hash = self._hash_streams(streams)
-        return job_hash
-
-    def register(self, streams: Sequence[Stream], world: World) -> str:
-        """Register or look up this collection of StreamingDataset replicas.
-
-        Called by all ranks.
-
-        Note: we explicitly garbage collect right before registration. This is to save us from the
-        following scenario:
-
-        ```py
-        dataset = StreamingDataset(...)
-
-        del dataset
-
-        # The dataset is marked deleted, but the python garbage collector does not execute
-        # dataset.__del__ in time, much less dataset.job.__del__ in time, which would have
-        # automatically ensured the job was un-registered.
-
-        dataset = StreamingDataset(...)  # Same locals as before.
-        # *Boom*, due to "reused" locals, matching the locals still registered from the first
-        # time the dataset was created.
-        ```
-
-        Args:
-            streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in
-                combination with process IDs and creation times lets us uniquely identify a
-                Streaming job.
-            world (World): Rank-wise world state.
-
-        Returns:
-            str: Subdir for this collection of StreamingDataset replicas.
-        """
-        gc.collect()
-        if world.is_local_leader:
-            job_hash = self._do_register(streams)
-        else:
-            job_hash = self._lookup(streams)
-            dirname = os.path.join(self.config_root, job_hash)
-            wait_for_creation(dirname, self.timeout, self.tick, self.lock)
-        return job_hash
+            # Get registration time.
+            register_time = time_ns()
+
+            # Load the job database.
+            db = RegistryFile.read(self.registry_filename)
+
+            # Perform liveness checks on the jobs we have registered.
+            pid2create_time = self._get_live_procs()
+            del_job_hashes = db.filter(pid2create_time)
+
+            # Add an entry for this job.
+            pid = os.getpid()
+            create_time = pid2create_time.get(pid)
+            if create_time is None:
+                raise RuntimeError(f'`psutil` thinks we are dead, and yet here we are: pid {pid}.')
+            entry = JobEntry(job_hash=job_hash,
+                             stream_hashes=stream_hashes,
+                             stream_locals=stream_locals,
+                             process_id=pid,
+                             register_time=register_time)
+            db.add(entry)
+
+            # Save the new db to disk.
+            db.write(self.registry_filename)
+
+            # Add and remove job directories accordingly.
+            self._make_job_dir(job_hash)
+            for del_job_hash in del_job_hashes:
+                self._remove_job_dir(del_job_hash)
+
+        return job_hash

     def is_registered(self, job_hash: str) -> bool:
         """Tell whether the given job_hash is registered.
@@ -255,62 +213,50 @@ def is_registered(self, job_hash: str) -> bool:
         Returns:
             bool: Whether the job hash is registered.
         """
-        with self.lock:
-            conf = RegistryFile.read(self.registry_filename)
-            return conf.contains(job_hash)
-
-    def _do_unregister(self, job_hash: str) -> None:
-        """Unregister this collection of StreamingDataset replicas.
-
-        Called by the local leader.
-
-        Args:
-            job_hash (str): Subdir identifying this Streaming job.
-        """
-        pid2create_time = self._get_live_procs()
-
-        with self.lock:
-            conf = RegistryFile.read(self.registry_filename)
-            conf.remove(job_hash)
-            del_job_hashes = conf.filter(pid2create_time)
-            conf.write(self.registry_filename)
-            map(self._remove_dir, del_job_hashes)
-            self._remove_dir(job_hash)
+        dirname = os.path.join(self.config_root, job_hash)
+        return os.path.isdir(dirname)
 
-    def unregister(self, job_hash: str, world: World) -> None:
+    def unregister(self, job_hash: str, world: World, strict: bool = True) -> None:
         """Unregister this collection of StreamingDataset replicas.
 
         Called by all ranks.
 
         Args:
             job_hash (str): Subdir identifying this Streaming job.
             world (World): Rank-wise world state.
+            strict (bool): If strict, require the job to be currently registered at start.
         """
-        if world.is_local_leader:
-            self._do_unregister(job_hash)
-        else:
+        if not world.is_local_leader:
             dirname = os.path.join(self.config_root, job_hash)
             wait_for_deletion(dirname, self.timeout, self.tick, self.lock)
+            return
 
-    def ensure_unregistered(self, job_hash: str, world: World) -> None:
-        """Ensure that this collection of StreamingDataset replicas is unregistered.
-
-        Called by all ranks.
-
-        Args:
-            job_hash (str): Subdir identifying this Streaming job.
-            world (World): Rank-wise world state.
-        """
-        pid2create_time = self._get_live_procs()
-
-        with self.lock:
-            conf = RegistryFile.read(self.registry_filename)
-            is_registered = conf.contains(job_hash)
-            if not is_registered:
-                return
-
-            conf.remove(job_hash)
-            del_job_hashes = conf.filter(pid2create_time)
-            conf.write(self.registry_filename)
-            map(self._remove_dir, del_job_hashes)
-            self._remove_dir(job_hash)
+        with self.lock:
+            # Load the job database.
+            db = RegistryFile.read(self.registry_filename)
+
+            # Check if the job hash is registered.
+            was_registered = db.contains(job_hash)
+
+            # If strict, require the job to be registered.
+            if strict and not was_registered:
+                raise ValueError(f'Attempted to unregister job {job_hash}, but it was not registered.')
+
+            # Unregister the job, if it is registered.
+            if was_registered:
+                db.remove(job_hash)
+                self._remove_job_dir(job_hash)
+
+            # Perform liveness checks on the jobs we have registered.
+            pid2create_time = self._get_live_procs()
+            del_job_hashes = db.filter(pid2create_time)
+
+            # If we unregistered the job and/or we garbage collected job(s), save the new jobs
+            # database back to disk.
+            if was_registered or del_job_hashes:
+                db.write(self.registry_filename)
+
+            # Remove each directory corresponding to a job that was garbage collected.
+            for del_job_hash in del_job_hashes:
+                self._remove_job_dir(del_job_hash)
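Both `register` and `unregister` now garbage-collect dead jobs inline: `_get_live_procs` (not shown in this diff) supplies a map of live PIDs to creation times, and `db.filter` drops entries whose registering process has exited. A sketch of how such a liveness map can be built with `psutil`; this is an assumption about `_get_live_procs`'s behavior inferred from how it is used here:

```py
from typing import Dict

import psutil


def get_live_procs() -> Dict[int, int]:
    """Map each live process's PID to its creation time in nanoseconds."""
    pid2create_time = {}
    for proc in psutil.process_iter(['pid', 'create_time']):
        try:
            pid2create_time[proc.info['pid']] = int(proc.info['create_time'] * 1e9)
        except psutil.NoSuchProcess:
            continue  # The process exited while we were iterating.
    return pid2create_time
```

Checking a registry entry's `process_id` against this map (and its registration time against the process's creation time) is one way to tell a live job from a dead PID that the OS has since reused.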
2 changes: 2 additions & 0 deletions streaming/base/dataset.py
@@ -1189,6 +1189,8 @@ def _request_pregen_epoch(self, epoch: int, sample: int) -> None:
     def _each_pregen_epoch_todo(self) -> Iterator[Tuple[int, int]]:
         lock_filename = self.job.get_filename(self.pregen_todos_lock_path)
         todo_filename = self.job.get_filename(self.pregen_todos_path)
+        dirname = os.path.dirname(lock_filename)
+        os.makedirs(dirname, exist_ok=True)
         lock = FileLock(lock_filename)
         while True:
             with lock:
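The two added lines guard against a missing parent directory: `FileLock` creates the lock file itself, but not intermediate directories, so `os.makedirs(..., exist_ok=True)` must run first. The general pattern with the `filelock` package (the path below is illustrative):

```py
import os

from filelock import FileLock

lock_filename = '/tmp/streaming/job/todos.lock'  # Illustrative path.
os.makedirs(os.path.dirname(lock_filename), exist_ok=True)

with FileLock(lock_filename):
    # Critical section: read-modify-write the shared todo state here.
    pass
```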