Skip to content

Commit

Permalink
Merge pull request #229 from JamesKunstle/fix-interruptable-subprocess
Browse files Browse the repository at this point in the history
changes StreamablePopen to return a process and implement listening
  • Loading branch information
mergify[bot] authored Sep 26, 2024
2 parents 7bc49bb + 5d1aee3 commit 960b9d8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
27 changes: 20 additions & 7 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,33 +726,46 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
f"--fsdp_sharding_strategy={train_args.fsdp_options.sharding_strategy.value}"
)

print(f"\033[92mRunning command: {' '.join(command)}\033[0m")
print(f"\033[92mRunning training command as subprocess: {' '.join(command)}\033[0m")
process = None
interrupt: KeyboardInterrupt | Exception | None = None
try:
process = StreamablePopen(
f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log",
command,
)

except KeyboardInterrupt:
print("Process interrupted by user")
process.listen()
except KeyboardInterrupt as e:
print("Training subprocess interrupted by user.")
interrupt = e
except Exception as e:
print(f"An error occurred: {str(e)}")
interrupt = e
finally:
if "process" not in locals() or process is None:
return
if process.poll() == 0:
print("\033[92mOperation completed successfully! 🎉\033[0m")
print("\033[92mTraining subprocess exited successfully! 🎉\033[0m")
else:
print("\033[91mOperation failed, terminating process.\033[0m")
print(
"\033[91mTraining subprocess has not exited yet. Sending SIGTERM.\033[0m"
)

print("Sending interrupt signal to Training subprocess.")
process.terminate()
try:
print("Waiting for process to exit, 60s...")
process.wait(timeout=60)
except subprocess.TimeoutExpired:
print("\033[91mProcess did not terminate in time, killing it.\033[0m")
print(
"\033[91mTraining subprocess did not terminate before timeout, sending SIGKILL.\033[0m"
)
process.kill()

if interrupt:
print(f"Error caught from training subprocess.: {interrupt}")
raise interrupt


if __name__ == "__main__":
# TODO(osilkin): Configure a type that these args must adhere to for the sake of type checking
Expand Down
5 changes: 4 additions & 1 deletion src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,14 @@ def __init__(self, output_file, *args, **kwargs):
# remove the stderr and stdout from kwargs
kwargs.pop("stderr", None)
kwargs.pop("stdout", None)
self.output_file = output_file

super().__init__(
*args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs
)
with open(output_file, "wb") as full_log_file:

def listen(self):
with open(self.output_file, "wb") as full_log_file:
while True:
byte = self.stdout.read(1)
if byte:
Expand Down

0 comments on commit 960b9d8

Please sign in to comment.