diff --git a/docker/presets/llama-2/Dockerfile b/docker/presets/llama-2/Dockerfile
index 250d525de..1accbdcb6 100644
--- a/docker/presets/llama-2/Dockerfile
+++ b/docker/presets/llama-2/Dockerfile
@@ -17,6 +17,8 @@ RUN git clone https://github.com/facebookresearch/llama
 
 WORKDIR /workspace/llama
 
+RUN sed -i $'/torch.distributed.init_process_group("nccl")/c\\\t\t\timport datetime\\\n\\\t\t\ttorch.distributed.init_process_group("nccl", timeout=datetime.timedelta(days=365*100))' /workspace/llama/llama/generation.py
+
 RUN pip install -e .
 RUN pip install fastapi pydantic
 RUN pip install 'uvicorn[standard]'
diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py
index fcb1c7cc1..3f2313bd3 100644
--- a/presets/llama-2-chat/inference-api.py
+++ b/presets/llama-2-chat/inference-api.py
@@ -212,6 +212,8 @@ def worker_listen_tasks():
                 os.killpg(os.getpgrp(), signal.SIGTERM)
         except Exception as e:
             print(f"Error in Worker Listen Task", e)
+            if 'Socket Timeout' in str(e):
+                print("A socket timeout occurred.")
             os.killpg(os.getpgrp(), signal.SIGTERM)
 
 if __name__ == "__main__":
diff --git a/presets/llama-2/inference-api.py b/presets/llama-2/inference-api.py
index 7e1810049..0d7b61ec6 100644
--- a/presets/llama-2/inference-api.py
+++ b/presets/llama-2/inference-api.py
@@ -201,6 +201,8 @@ def worker_listen_tasks():
                 os.killpg(os.getpgrp(), signal.SIGTERM)
         except Exception as e:
             print(f"Error in Worker Listen Task", e)
+            if 'Socket Timeout' in str(e):
+                print("A socket timeout occurred.")
             os.killpg(os.getpgrp(), signal.SIGTERM)
 
 if __name__ == "__main__":
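
For reference, the `sed` rewrite in the Dockerfile replaces the `torch.distributed.init_process_group("nccl")` call in `llama/generation.py` with a version that effectively never times out. A minimal sketch of the resulting call site, assuming the file's single call to `init_process_group` (surrounding code abridged; the real file indents with the three tabs the sed expression inserts):

```python
# Sketch of the patched call site in llama/generation.py; not the
# verbatim upstream code, just the shape the sed rewrite produces.
import datetime

import torch.distributed

# torch.distributed's default process-group timeout is 30 minutes; a
# 100-year timedelta effectively disables it, so a slow checkpoint load
# on one worker doesn't tear down the process group while the other
# workers wait at the first collective.
torch.distributed.init_process_group(
    "nccl",
    timeout=datetime.timedelta(days=365 * 100),
)
```

The matching `inference-api.py` changes complement this on the worker side: when a broadcast fails with a message containing `'Socket Timeout'`, the worker logs the timeout explicitly before the existing `os.killpg(os.getpgrp(), signal.SIGTERM)` shuts the process group down, making that failure mode identifiable in the logs.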