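"""gpu_helper.py

Helpers for monitoring GPU usage on remote servers: each configured server is
polled over SSH for `nvidia-smi -q -x` output, per-GPU and per-user usage
statistics are accumulated in an in-memory cache, and the cache is pickled to
disk periodically. Expects a `config` module providing `servers`, `user`,
`password`, `key`, `process_filter`, `update_interval` and `cache_file`.
"""
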
import config
import xml.etree.ElementTree as ET
import pwd
import logging
from datetime import datetime, timedelta
import threading
import pickle
import os
import traceback
from multiprocessing.dummy import Pool as ThreadPool
import paramiko
import socket

logger = logging.getLogger()

# init cache
cache = {
    'users': {},
    'since': datetime.now()
}

# init template
user_template = {'name': '', 'inc': 0, 'time': timedelta(seconds=0.), 'cum_energy': 0., 'cum_util': 0.}


def get_nvidiasmi(ssh):
    # run nvidia-smi on the remote host and return the parsed XML root (False on parse errors)
    _, ssh_stdout, _ = ssh.exec_command('nvidia-smi -q -x')
    try:
        return ET.fromstring(''.join(ssh_stdout.readlines()))
    except ET.ParseError:
        return False
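
# Rough shape of the nvidia-smi XML that get_gpu_infos() below relies on (the
# exact layout may vary with the driver version; only the fields read in this
# module are sketched):
#
#   <nvidia_smi_log>
#     <gpu>
#       <product_name>...</product_name>
#       <fb_memory_usage><total>...</total><used>...</used></fb_memory_usage>
#       <utilization><gpu_util>...</gpu_util></utilization>
#       <power_readings><power_draw>...</power_draw></power_readings>
#       <processes>
#         <process_info><pid>...</pid><process_name>...</process_name></process_info>
#       </processes>
#     </gpu>
#   </nvidia_smi_log>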


def get_ps(ssh, pids):
    # return the `pid user` listing for the given PIDs on the remote host
    pid_cmd = 'ps -o pid= -o ruser= -p {}'.format(','.join(pids))
    _, ssh_stdout, _ = ssh.exec_command(pid_cmd)
    return ''.join(ssh_stdout.readlines())


def get_users_by_pid(ps_output):
    # parse ps output into a mapping from PID to owning user
    users_by_pid = {}
    if not ps_output:
        return users_by_pid
    for line in ps_output.strip().split('\n'):
        pid, user = line.split()
        users_by_pid[pid] = user
    return users_by_pid


def update_users(info):
    # update the cached per-user usage statistics for one GPU
    for user, real_user in zip(info['users'], info['real_users']):
        if user not in cache['users']:
            # first time this user shows up: initialise from the template
            cache['users'][user] = dict(user_template)
        cache['users'][user]['name'] = real_user
        cache['users'][user]['inc'] += 1
        cache['users'][user]['time'] += config.update_interval
        # power_draw is reported like '123.45 W'; strip the unit
        pwr_float = float(info['power_draw'][:-2])
        # accumulate energy in Wh: power [W] * interval [h]
        cache['users'][user]['cum_energy'] += pwr_float * config.update_interval.total_seconds() / 3600
        # gpu_util is reported like '42 %'; strip the unit
        gpu_util = float(info['gpu_util'][:-2])
        cache['users'][user]['cum_util'] += gpu_util


def get_gpu_infos(ssh):
    # collect GPU usage information over one ssh connection
    nvidiasmi_output = get_nvidiasmi(ssh)
    if nvidiasmi_output is False:
        return False
    gpus = nvidiasmi_output.findall('gpu')
    gpu_infos = []
    for idx, gpu in enumerate(gpus):
        model = gpu.find('product_name').text
        power_draw = gpu.find('power_readings').find('power_draw').text
        processes = gpu.findall('processes')[0]
        pids = [process.find('pid').text for process in processes
                if config.process_filter.search(process.find('process_name').text)]
        mem = gpu.find('fb_memory_usage').find('total').text
        gpu_util = gpu.find('utilization').find('gpu_util').text
        used_mem = gpu.find('fb_memory_usage').find('used').text
        free = (len(pids) == 0)
        info = {
            'idx': idx,
            'model': model,
            'pids': pids,
            'power_draw': power_draw,
            'free': free,
            'mem': mem,
            'gpu_util': gpu_util,
            'used_mem': used_mem
        }
        if free:
            users = []
            real_users = []
        else:
            ps_output = get_ps(ssh, pids)
            users_by_pid = get_users_by_pid(ps_output)
            # a process may have exited between the nvidia-smi and ps queries
            users = set(users_by_pid[pid] for pid in pids if pid in users_by_pid)
            # resolve login names to full names via the passwd database
            real_users = [pwd.getpwnam(user).pw_gecos.split(',')[0] for user in users]
        info['users'] = users
        info['real_users'] = real_users
        update_users(info)
        gpu_infos.append(info)
    return gpu_infos


def get_remote_info(server):
    # returns gpu information from cache
    tstring = cache['servers'][server]['time'].strftime('%d.%m.%Y %H:%M:%S')
    logger.info(f'Using cache for {server} from {tstring}')
    return cache['servers'][server]


def get_new_server_info(server):
    # connect to a server via ssh and collect fresh GPU information
    server_info = {}
    try:
        ssh = paramiko.SSHClient()
        # be careful: change this policy if you do not want unknown host keys to be added automatically
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        logger.info(f'loading server {server}')
        ssh.connect(server, username=config.user, password=config.password, key_filename=config.key)
        try:
            gpu_infos = get_gpu_infos(ssh)
            if not gpu_infos:
                server_info['smi_error'] = True
            else:
                server_info['info'] = gpu_infos
                server_info['smi_error'] = False
            server_info['time'] = datetime.now()
            logger.info(f'finished loading server {server}')
        finally:
            ssh.close()
            del ssh
    except Exception:
        logger.error(f'Had an issue while updating cache for {server}: {traceback.format_exc()}')
    return server, server_info


def update_cache(server, interval):
    # fetch fresh info for one server and reschedule itself after `interval`
    logger.info('updating cache')
    server, result = get_new_server_info(server)
    if result:
        cache['servers'][server].update(result)
    logger.info('restarting timer to update cache')
    threading.Timer(interval.total_seconds(), update_cache, (server, interval, )).start()


def write_cache(interval):
    # periodically pickle the cache to disk so statistics survive restarts
    with open(config.cache_file, 'wb') as f:
        pickle.dump(cache, f)
    threading.Timer(interval.total_seconds(), write_cache, (interval, )).start()


def start_async(interval):
    # asynchronously update the cache: one updater thread per server plus one writer thread
    for server in config.servers:
        threading.Thread(target=update_cache, args=(server, interval, )).start()
    threading.Thread(target=write_cache, args=(interval, )).start()


def setup():
    # create an empty cache entry for every configured server
    cache['servers'] = {server: {'time': datetime.fromtimestamp(0.), 'info': []}
                        for server in config.servers}
    # restore a previously written cache if one exists
    if os.path.isfile(config.cache_file):
        with open(config.cache_file, 'rb') as f:
            cache.update(pickle.load(f))
    # start async updates of cache
    start_async(config.update_interval)
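

# Minimal usage sketch (illustrative only, assuming `config` defines `servers`
# and `update_interval` as used above): setup() starts the background update
# threads, after which callers (for example a web frontend) can read per-server
# results from the cache via get_remote_info().
if __name__ == '__main__':
    import time

    logging.basicConfig(level=logging.INFO)
    setup()
    while True:
        for server in config.servers:
            remote = get_remote_info(server)
            for gpu in remote.get('info', []):
                status = 'free' if gpu['free'] else 'used by ' + ', '.join(gpu['real_users'])
                print(f"{server} GPU {gpu['idx']} ({gpu['model']}): {status}")
        time.sleep(config.update_interval.total_seconds())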