[BUG]: Segfault with large amount of data #380
(base) kent@kent-Super-Server:~$ nvidia-smi
Sun Dec 15 23:58:02 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 On | 00000000:17:00.0 Off | N/A |
| 60% 40C P8 37W / 250W | 382MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 3090 On | 00000000:31:00.0 Off | N/A |
| 60% 37C P8 30W / 250W | 18MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA GeForce RTX 3090 On | 00000000:4B:00.0 Off | N/A |
| 60% 35C P8 32W / 250W | 18MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA GeForce RTX 3090 On | 00000000:CA:00.0 On | N/A |
| 60% 40C P8 35W / 250W | 250MiB / 24576MiB | 4% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 27710 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 1923710 C /usr/local/bin/python 358MiB |
| 1 N/A N/A 27710 G /usr/lib/xorg/Xorg 4MiB |
| 2 N/A N/A 27710 G /usr/lib/xorg/Xorg 4MiB |
| 3 N/A N/A 27710 G /usr/lib/xorg/Xorg 126MiB |
| 3 N/A N/A 28022 G /usr/bin/gnome-shell 32MiB |
| 3 N/A N/A 2718704 G ...erProcess --variations-seed-version 35MiB |
| 3 N/A N/A 3856584 G /usr/bin/nautilus 27MiB |
+-----------------------------------------------------------------------------------------+
(base) kent@kent-Super-Server:~$
(base) kent@kent-Super-Server:~$ julia
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.11.1 (2024-10-16)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |
julia> versioninfo()
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × Intel(R) Xeon(R) Platinum 8336C CPU @ 2.30GHz
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, icelake-server)
Threads: 1 default, 0 interactive, 1 GC (on 128 virtual cores)
Environment:
LD_LIBRARY_PATH = /usr/lib/cuda/lib64:/usr/local/cuda/lib64:
julia>
thanks for the report
@wsmoses it seems like there is a hard limit on …
that's likely specific to the device. I think the bigger issue here is that seemingly we're not freeing objects that the GC clearly no longer needs
I checked the NVIDIA GeForce RTX 3090 specs and it has 24 GB of VRAM, so there should still be some space before running out of memory.
Although I'm pretty sure there is stuff we don't free correctly (like compiled functions), I'm not sure that's the problem, because their code is clearly saturating the memory: it fails after performing 2374 allocations of size …
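A minimal way to separate "still referenced" from "not freed" would be to drop every reference to a buffer and force a collection, then watch nvidia-smi. This is only a sketch, and it assumes ConcreteRArray releases its device buffer from a finalizer, which is exactly the part in question here:

using Reactant

x = Reactant.ConcreteRArray(randn(1_000_000))  # ~8 MB buffer on the device
x = nothing                                    # drop the only reference
GC.gc(true)                                    # force a full collection
# now check `nvidia-smi` from another shell: if device memory does not go
# down, the buffer is not freed even though Julia's GC no longer needs it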
There is a software limit on the max memory of the mem pool here https://github.com/openxla/xla/blob/2421ead495dbfc3b25e96c77c9b5e495b65d0f94/xla/tsl/framework/bfc_allocator.cc#L80C19-L80C31 ...which is called by https://github.com/openxla/xla/blob/2421ead495dbfc3b25e96c77c9b5e495b65d0f94/xla/pjrt/gpu/se_gpu_pjrt_client.cc#L987-L993 and sets a … which defaults to 0.75 and …
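For what it's worth, the numbers in the allocator dump further down in this thread line up with that 0.75 default almost exactly (treating the base as the VRAM that was free at client creation, which is an assumption on my part):

limit  = 18_995_773_440                   # memory_limit_ reported by the BFC allocator dump below
in_use = 2373 * 8_000_000 + 11_773_440    # 2373 buffers of 8 MB plus one 11.2 MiB chunk
in_use == limit                           # true: the pool is filled to the last byte
limit / (24 * 1024^3)                     # ≈ 0.737, i.e. roughly 75% of a 24 GiB card
                                          # once the memory already held by Xorg etc. is excluded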
It's definitely not the compiled functions, since that's CPU memory, not GPU memory. I think this is ConcreteRArray allocations that aren't freed.
But they are still referenced, so they can't be freed. Look at …
What's annoying me is that it seems like it's only allocating on 1 GPU? So automatic multi-GPU is not properly configured.
Ah, if they are still referenced, then yeah, I'm not sure what more can be done (besides making this throw a Julia error). But yeah, we should add more multi-GPU support.
Maybe the memory limit / memory fraction of the GPU mem pool could be set by the user when creating a GPU client? It would be more explicit, and the user would be freer to choose. What I saw inside XLA is that this BFCAllocator with the memory limit is only used for GPUs...
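A sketch of what that could look like on the Julia side; the constructor and the keyword names below are hypothetical, not an existing Reactant API:

# Hypothetical API: let the caller pick the pool size instead of the hard-coded 0.75.
client = Reactant.XLA.GPUClient(; memory_fraction = 0.9,  # fraction of free VRAM handed to the BFC pool
                                  preallocate = false)     # grow on demand instead of reserving up front
x = Reactant.ConcreteRArray(randn(1_000_000); client = client)  # hypothetical `client` keyword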
@x66ccff what happens if you run the same code in Python with JAX? Do you get the same error too?
Sorry, I've never used JAX in Python. Can you give me the Python code and JAX version? Or just give me the conda command. I will try it tomorrow.
@wsmoses look what I found: https://jax.readthedocs.io/gpu_memory_allocation.html

Mmm, since the problem we want to check is allocation, let's just allocate the same arrays:

import jax

in_dim = 1000000
n = 10000
random_x = []
random_y = []
for i in range(n):
    random_x.append(jax.numpy.ones(in_dim, dtype=jax.numpy.float64))
    random_y.append(jax.numpy.ones(in_dim, dtype=jax.numpy.float64))

You can just install the latest JAX version. Also, would you mind rerunning the same Julia and Python code afterwards but with env var …
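For reference, the knobs documented on that JAX page are XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION; whether Reactant's client picks them up at all is an open question, which is what the rerun should show. They would have to be set before the first GPU client is created, e.g.:

# Set before `using Reactant` (i.e. before any GPU client exists).
# The names come from the JAX GPU memory docs; it is not confirmed that Reactant honors them.
ENV["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # don't reserve ~75% of VRAM up front
ENV["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".90"    # or raise/lower the preallocated fraction
using Reactant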
import jax
from tqdm import tqdm

in_dim = 1000000
n = 10000
random_x = []
random_y = []
for i in tqdm(range(n)):
    random_x.append(jax.numpy.ones(in_dim, dtype=jax.numpy.float64))
    random_y.append(jax.numpy.ones(in_dim, dtype=jax.numpy.float64))

(jaxtest) kent@kent-Super-Server:~/_Project/PTSjl/PTS.jl$ python test_jax.py
0%| | 0/10000 [00:00<?, ?it/s]/home/kent/_Project/PTSjl/PTS.jl/test_jax.py:11: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in ones is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
random_x.append(jax.numpy.ones(in_dim, dtype=jax.numpy.float64))
/home/kent/_Project/PTSjl/PTS.jl/test_jax.py:12: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in ones is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
random_y.append(jax.numpy.ones(in_dim, dtype=jax.numpy.float64))
24%|█████████████████████████████████▋ | 2374/10000 [00:04<00:09, 843.86it/s]2024-12-17 10:09:05.162531: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.81MiB (rounded to 4000000)requested by op
2024-12-17 10:09:05.169419: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] ****************************************************************************************************
E1217 10:09:05.169490 860856 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4000000 bytes.
24%|█████████████████████████████████▋ | 2374/10000 [00:14<00:45, 167.95it/s]
Traceback (most recent call last):
File "/home/kent/_Project/PTSjl/PTS.jl/test_jax.py", line 11, in <module>
random_x.append(jax.numpy.ones(in_dim, dtype=jax.numpy.float64))
File "/home/kent/anaconda3/envs/jaxtest/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3398, in ones
return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
File "/home/kent/anaconda3/envs/jaxtest/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1296, in full
return broadcast(fill_value, shape)
File "/home/kent/anaconda3/envs/jaxtest/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 829, in broadcast
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
File "/home/kent/anaconda3/envs/jaxtest/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 858, in broadcast_in_dim
return broadcast_in_dim_p.bind(
File "/home/kent/anaconda3/envs/jaxtest/lib/python3.9/site-packages/jax/_src/core.py", line 416, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/kent/anaconda3/envs/jaxtest/lib/python3.9/site-packages/jax/_src/core.py", line 420, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/kent/anaconda3/envs/jaxtest/lib/python3.9/site-packages/jax/_src/core.py", line 921, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/kent/anaconda3/envs/jaxtest/lib/python3.9/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
outs = fun(*args)
ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4000000 bytes.
(jaxtest) kent@kent-Super-Server:~/_Project/PTSjl/PTS.jl$
with …

julia with env variable:

#undef
#undef
#undef
#undef
#undef
#undef
⋮
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
#undef
julia> println("Generating random numbers:")
Generating random numbers:
julia> p_rand = Progress(n; dt=0.5, barglyphs=BarGlyphs("[=> ]"), barlen=50, color=:yellow)
Progress(10000, 0, 50, BarGlyphs('[', '=', '>', ' ', ']'), ProgressMeter.ProgressCore(:yellow, "Progress: ", 0.5, true, 0, Base.TTY(RawFD(23) open, 0 bytes waiting), false, 1, 0, ReentrantLock(nothing, 0x00000000, 0x00, Base.GenericCondition{Base.Threads.SpinLock}(Base.IntrusiveLinkedList{Task}(nothing, nothing), Base.Threads.SpinLock(0)), (139876174128656, 139951551489632, 4)), 0, 1, false, false, 1.734401611764458e9, 1.734401611764458e9, 1.734401611764458e9))
julia> for i in 1:n
random_x[i] = Reactant.ConcreteRArray(randn(in_dim))
random_y[i] = Reactant.ConcreteRArray(randn(in_dim))
next!(p_rand)
end
Progress:  11%[=====>                                            ]  ETA: 0:01:25

it gets stuck here for a few seconds (same as before)

nvidia-smi:

(base) kent@kent-Super-Server:~$ nvidia-smi
Tue Dec 17 10:16:12 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 On | 00000000:17:00.0 Off | N/A |
| 60% 37C P8 37W / 250W | 18759MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 3090 On | 00000000:31:00.0 Off | N/A |
| 60% 34C P8 30W / 250W | 280MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA GeForce RTX 3090 On | 00000000:4B:00.0 Off | N/A |
| 60% 33C P8 31W / 250W | 280MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA GeForce RTX 3090 On | 00000000:CA:00.0 On | N/A |
| 60% 38C P8 35W / 250W | 511MiB / 24576MiB | 5% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 27710 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 867699 C julia 18372MiB |
| 0 N/A N/A 1923710 C /usr/local/bin/python 358MiB |
| 1 N/A N/A 27710 G /usr/lib/xorg/Xorg 4MiB |
| 1 N/A N/A 867699 C julia 256MiB |
| 2 N/A N/A 27710 G /usr/lib/xorg/Xorg 4MiB |
| 2 N/A N/A 867699 C julia 256MiB |
| 3 N/A N/A 27710 G /usr/lib/xorg/Xorg 126MiB |
| 3 N/A N/A 28022 G /usr/bin/gnome-shell 32MiB |
| 3 N/A N/A 867699 C julia 256MiB |
| 3 N/A N/A 2718704 G ...erProcess --variations-seed-version 35MiB |
| 3 N/A N/A 3856584 G /usr/bin/nautilus 27MiB |
+-----------------------------------------------------------------------------------------+

2024-12-17 10:13:53.177348: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bf57e6c00 of size 8000000 next 2311
2024-12-17 10:13:53.177361: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bf5f87e00 of size 8000000 next 2312
2024-12-17 10:13:53.177376: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bf6729000 of size 8000000 next 2313
2024-12-17 10:13:53.177392: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bf6eca200 of size 8000000 next 2314
2024-12-17 10:13:53.177407: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bf766b400 of size 8000000 next 2315
2024-12-17 10:13:53.177422: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bf7e0c600 of size 8000000 next 2316
2024-12-17 10:13:53.177436: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bf85ad800 of size 8000000 next 2317
2024-12-17 10:13:53.177451: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bf8d4ea00 of size 8000000 next 2318
2024-12-17 10:13:53.177465: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bf94efc00 of size 8000000 next 2319
2024-12-17 10:13:53.177480: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bf9c90e00 of size 8000000 next 2320
2024-12-17 10:13:53.177497: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfa432000 of size 8000000 next 2321
2024-12-17 10:13:53.177509: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfabd3200 of size 8000000 next 2322
2024-12-17 10:13:53.177522: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfb374400 of size 8000000 next 2323
2024-12-17 10:13:53.177534: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfbb15600 of size 8000000 next 2324
2024-12-17 10:13:53.177546: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfc2b6800 of size 8000000 next 2325
2024-12-17 10:13:53.177558: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfca57a00 of size 8000000 next 2326
2024-12-17 10:13:53.177571: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfd1f8c00 of size 8000000 next 2327
2024-12-17 10:13:53.177588: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfd999e00 of size 8000000 next 2328
2024-12-17 10:13:53.177602: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfe13b000 of size 8000000 next 2329
2024-12-17 10:13:53.177617: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfe8dc200 of size 8000000 next 2330
2024-12-17 10:13:53.177634: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bff07d400 of size 8000000 next 2331
2024-12-17 10:13:53.177649: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bff81e600 of size 8000000 next 2332
2024-12-17 10:13:53.177663: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3bfffbf800 of size 8000000 next 2333
2024-12-17 10:13:53.177677: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c00760a00 of size 8000000 next 2334
2024-12-17 10:13:53.231843: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c00f01c00 of size 8000000 next 2335
2024-12-17 10:13:53.231907: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c016a2e00 of size 8000000 next 2336
2024-12-17 10:13:53.231927: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c01e44000 of size 8000000 next 2337
2024-12-17 10:13:53.231949: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c025e5200 of size 8000000 next 2338
2024-12-17 10:13:53.231964: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c02d86400 of size 8000000 next 2339
2024-12-17 10:13:53.231980: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c03527600 of size 8000000 next 2340
2024-12-17 10:13:53.231992: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c03cc8800 of size 8000000 next 2341
2024-12-17 10:13:53.232005: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c04469a00 of size 8000000 next 2342
2024-12-17 10:13:53.232017: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c04c0ac00 of size 8000000 next 2343
2024-12-17 10:13:53.232030: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c053abe00 of size 8000000 next 2344
2024-12-17 10:13:53.232042: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c05b4d000 of size 8000000 next 2345
2024-12-17 10:13:53.232054: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c062ee200 of size 8000000 next 2346
2024-12-17 10:13:53.232090: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c06a8f400 of size 8000000 next 2347
2024-12-17 10:13:53.232103: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c07230600 of size 8000000 next 2348
2024-12-17 10:13:53.232115: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c079d1800 of size 8000000 next 2349
2024-12-17 10:13:53.232130: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c08172a00 of size 8000000 next 2350
2024-12-17 10:13:53.232145: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c08913c00 of size 8000000 next 2351
2024-12-17 10:13:53.232160: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c090b4e00 of size 8000000 next 2352
2024-12-17 10:13:53.232174: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c09856000 of size 8000000 next 2353
2024-12-17 10:13:53.232196: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c09ff7200 of size 8000000 next 2354
2024-12-17 10:13:53.232208: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0a798400 of size 8000000 next 2355
2024-12-17 10:13:53.232221: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0af39600 of size 8000000 next 2356
2024-12-17 10:13:53.232233: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0b6da800 of size 8000000 next 2357
2024-12-17 10:13:53.232245: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0be7ba00 of size 8000000 next 2358
2024-12-17 10:13:53.232258: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0c61cc00 of size 8000000 next 2359
2024-12-17 10:13:53.232270: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0cdbde00 of size 8000000 next 2360
2024-12-17 10:13:53.232282: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0d55f000 of size 8000000 next 2361
2024-12-17 10:13:53.232294: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0dd00200 of size 8000000 next 2362
2024-12-17 10:13:53.232312: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0e4a1400 of size 8000000 next 2363
2024-12-17 10:13:53.232328: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0ec42600 of size 8000000 next 2364
2024-12-17 10:13:53.232343: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0f3e3800 of size 8000000 next 2365
2024-12-17 10:13:53.232358: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c0fb84a00 of size 8000000 next 2366
2024-12-17 10:13:53.232372: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c10325c00 of size 8000000 next 2367
2024-12-17 10:13:53.232390: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c10ac6e00 of size 8000000 next 2368
2024-12-17 10:13:53.232405: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c11268000 of size 8000000 next 2369
2024-12-17 10:13:53.232420: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c11a09200 of size 8000000 next 2370
2024-12-17 10:13:53.232439: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c121aa400 of size 8000000 next 2371
2024-12-17 10:13:53.232454: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c1294b600 of size 8000000 next 2372
2024-12-17 10:13:53.232468: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c130ec800 of size 8000000 next 2373
2024-12-17 10:13:53.232486: I external/xla/xla/tsl/framework/bfc_allocator.cc:1109] InUse at 7f3c1388da00 of size 11773440 next 18446744073709551615
2024-12-17 10:13:53.232503: I external/xla/xla/tsl/framework/bfc_allocator.cc:1114] Summary of in-use Chunks by size:
2024-12-17 10:13:53.232531: I external/xla/xla/tsl/framework/bfc_allocator.cc:1117] 2373 Chunks of size 8000000 totalling 17.68GiB
2024-12-17 10:13:53.232552: I external/xla/xla/tsl/framework/bfc_allocator.cc:1117] 1 Chunks of size 11773440 totalling 11.23MiB
2024-12-17 10:13:53.232566: I external/xla/xla/tsl/framework/bfc_allocator.cc:1121] Sum Total of in-use chunks: 17.69GiB
2024-12-17 10:13:53.232579: I external/xla/xla/tsl/framework/bfc_allocator.cc:1123] Total bytes in pool: 18995773440 memory_limit_: 18995773440 available bytes: 0 curr_region_allocation_bytes_: 37991546880
2024-12-17 10:13:53.232603: I external/xla/xla/tsl/framework/bfc_allocator.cc:1128] Stats:
Limit: 18995773440
InUse: 18995773440
MaxInUse: 18995773440
NumAllocs: 2374
MaxAllocSize: 11773440
Reserved: 0
PeakReserved: 0
LargestFreeBlock: 0
2024-12-17 10:13:53.232693: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] ****************************************************************************************************
terminate called after throwing an instance of 'xla::XlaRuntimeError'
what(): RESOURCE_EXHAUSTED: Out of memory while trying to allocate 8000000 bytes.
[864501] signal 6 (-6): Aborted
in expression starting at REPL[13]:1
pthread_kill at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
gsignal at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x7f4913aa5ffd)
unknown function (ip: 0x7f4913abae9b)
_ZSt9terminatev at /lib/x86_64-linux-gnu/libstdc++.so.6 (unknown line)
__cxa_throw at /lib/x86_64-linux-gnu/libstdc++.so.6 (unknown line)
_ZN3xla12ValueOrThrowISt10unique_ptrINS_10PjRtBufferESt14default_deleteIS2_EEEET_N4absl12lts_202308028StatusOrIS6_EE at /home/kent/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
ArrayFromHostBuffer at /home/kent/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
ArrayFromHostBuffer at /home/kent/.julia/packages/Reactant/sIJRJ/src/XLA.jl:244 [inlined]
#ConcreteRArray#51 at /home/kent/.julia/packages/Reactant/sIJRJ/src/ConcreteRArray.jl:72 [inlined]
ConcreteRArray at /home/kent/.julia/packages/Reactant/sIJRJ/src/ConcreteRArray.jl:65
unknown function (ip: 0x7f4912f05e42)
top-level scope at ./REPL[13]:2
jl_toplevel_eval_flex at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval_user_input at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:245
repl_backend_loop at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:342
#start_repl_backend#59 at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:327
start_repl_backend at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:324
#run_repl#72 at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:483
run_repl at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:469
jfptr_run_repl_10088 at /usr/local/julia-1.11.1/share/julia/compiled/v1.11/REPL/u0gqU_GYsA8.so (unknown line)
#1139 at ./client.jl:446
jfptr_YY.1139_14649 at /usr/local/julia-1.11.1/share/julia/compiled/v1.11/REPL/u0gqU_GYsA8.so (unknown line)
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/builder-demeter6-6/julialang/julia-master/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052 [inlined]
run_main_repl at ./client.jl:430
repl_main at ./client.jl:567 [inlined]
_start at ./client.jl:541
jfptr__start_72144.1 at /usr/local/julia-1.11.1/lib/julia/sys.so (unknown line)
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
true_main at /cache/build/builder-demeter6-6/julialang/julia-master/src/jlapi.c:900
jl_repl_entrypoint at /cache/build/builder-demeter6-6/julialang/julia-master/src/jlapi.c:1059
main at julia (unknown line)
unknown function (ip: 0x7f491442a1c9)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 119792436 (Pool: 119789737; Big: 2699); GC: 763
Aborted (core dumped)
Thanks! It seems like the env var is doing nothing, so for the time being we'll keep the 75% memory limit on GPUs, but we have to check why it's not using multi-GPU automatically.
Currently, we are only passing in …
Yeah, that can explain why we are not using multiple GPUs... but what do we do then?
Does JAX do multi-GPU by default? I thought we had to do a …

EDIT: NVM, https://jax.readthedocs.io/en/latest/sharded-computation.html#key-concept-data-sharding. We could expose a sharding API?
So this should no longer segfault, but throw a nice Julia error message.
Closing now in favor of the multi-GPU meta issue [since it seems like the bad error message is resolved, and the remaining issue is allocating across multiple GPUs].