-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathproject_setup.py
96 lines (82 loc) · 2.47 KB
/
project_setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import importlib
import mlrun
def assert_build():
    """
    Check that every package the demo depends on can be imported.

    Each module is imported by name and its `__version__` is printed, one per line, in the order listed. This
    function is registered as the handler of the "image-builder" job in `setup`.
    """
    required_modules = (
        "torch",
        "transformers",
        "datasets",
        "accelerate",
        "evaluate",
        "deepspeed",
        "mpi4py",
    )
    for required_module in required_modules:
        # Import and print one module at a time so the versions of the packages that did load are already
        # printed if a later import fails:
        imported = importlib.import_module(required_module)
        print(imported.__version__)
def setup(
    project: mlrun.projects.MlrunProject
) -> mlrun.projects.MlrunProject:
    """
    Creating the project for this demo.

    The following project parameters are read:

    * "source" - the project's git source, pulled at runtime.
    * "default_image" - the default image of the project. If not given, an image is built on top of
      "mlrun/ml-models-gpu" with the demo's requirements and used as the default image.
    * "num_gpus_per_replica", "num_cpus_per_replica", "memory_per_replica" - the resource limits of the
      training function. Falling back to 4 GPUs, 48 CPUs and "192Gi" when not given (or falsy).

    :param project: The MLRun project to set up.

    :returns: a fully prepared project for this demo.

    :raises RuntimeError: If the demo's image needs to be built and the build reports failure.
    """
    source = project.get_param("source")
    print(source)

    # Set or build the default image:
    default_image = project.get_param("default_image")
    if default_image is None:
        print("Building image for the demo:")
        image_builder = project.set_function(
            "project_setup.py",
            name="image-builder",
            handler="assert_build",
            kind="job",
            image="mlrun/ml-models-gpu",
            requirements=[
                "torch",
                "transformers[deepspeed]",
                "datasets",
                "accelerate",
                "evaluate",
                "mpi4py",
            ],
        )
        # The build is triggered outside of an `assert` statement so it still runs (and is still verified)
        # when Python is executed with `-O`, which strips assertions:
        if not image_builder.deploy():
            raise RuntimeError("Failed to build the default image for the demo.")
        # Use the image that was just built (previously the `None` parameter was passed on instead):
        default_image = image_builder.spec.image
    project.set_default_image(default_image)

    # Set the project git source:
    project.set_source(source, pull_at_runtime=True)

    # Set the data collection function:
    data_collection_function = project.set_function(
        "src/data_collection.py",
        name="data-collecting",
        image="mlrun/mlrun",
        kind="job",
    )
    data_collection_function.apply(mlrun.auto_mount())
    data_collection_function.save()

    # Set the data preprocessing function:
    project.set_function(
        "src/data_preprocess.py",
        name="data-preparing",
        kind="job",
    )

    # Set the training function:
    train_function = project.set_function(
        "src/trainer.py",
        name="training",
        kind="job",
    )
    # NOTE: `or` is kept on purpose to preserve the existing behavior - any falsy parameter value (including
    # 0) falls back to the default limit.
    train_function.with_limits(
        gpus=project.get_param("num_gpus_per_replica") or 4,
        cpu=project.get_param("num_cpus_per_replica") or 48,
        mem=project.get_param("memory_per_replica") or "192Gi",
    )
    train_function.save()

    # Set the serving function:
    project.set_function(
        "src/serving.py",
        name="serving",
        kind="serving",
    )

    # Set the training workflow:
    project.set_workflow("training_workflow", "src/training_workflow.py")

    # Save and return the project:
    project.save()
    return project