Make Partition into dedicated dataclass
mr0re1 committed Jan 20, 2025
1 parent 8b711f6 commit f047711
Showing 7 changed files with 89 additions and 66 deletions.
@@ -0,0 +1,47 @@
# Copyright 2024 "Google LLC"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Dict

from dataclasses import dataclass, field

@dataclass(frozen=True)
class Partition:
name: str
enable_job_exclusive: bool = False
conf: Dict[str, Any] = field(default_factory=dict)

nodesets: List[str] = field(default_factory=list)
nodesets_dyn: List[str] = field(default_factory=list)
nodesets_tpu: List[str] = field(default_factory=list)

@property
def is_tpu(self) -> bool:
return len(self.nodesets_tpu) > 0

@property
def any_dynamic(self) -> bool:
return len(self.nodesets_dyn) > 0

@classmethod
def from_json(cls, jo: dict) -> "Partition":
return cls(
name=jo["partition_name"],
enable_job_exclusive=jo["enable_job_exclusive"],
conf=jo.get("partition_conf", {}),

nodesets=jo.get("nodesets", []),
nodesets_dyn=jo.get("nodesets_dyn", []),
nodesets_tpu=jo.get("nodesets_tpu", []),
)
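
For orientation, a minimal usage sketch of the new dataclass (the config values here are hypothetical):

from base import Partition

# "partition_name" and "enable_job_exclusive" are mandatory keys in
# from_json (bracket access); the rest fall back to empty defaults.
jo = {
    "partition_name": "debug",
    "enable_job_exclusive": True,
    "partition_conf": {"MaxTime": "30:00"},
    "nodesets": ["ns1"],
}

p = Partition.from_json(jo)
assert p.name == "debug"
assert not p.is_tpu        # no "nodesets_tpu" entries
assert not p.any_dynamic   # no "nodesets_dyn" entries
# frozen=True: assigning to p.name would raise dataclasses.FrozenInstanceError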
@@ -22,6 +22,7 @@
import util
from util import dirs, slurmdirs
import tpu
from base import Partition

FILE_PREAMBLE = """
# Warning:
@@ -76,13 +77,9 @@ def get(key, default):
for nodeset in lkp.cfg.nodeset.values()
)

any_tpu = any(
tpu_nodeset is not None
for part in lkp.cfg.partitions.values()
for tpu_nodeset in part.partition_nodeset_tpu
)
any_tpu = any(p.is_tpu for p in lkp.partitions)
any_dynamic = any(p.any_dynamic for p in lkp.partitions)

any_dynamic = any(bool(p.partition_feature) for p in lkp.cfg.partitions.values())
comma_params = {
"LaunchParameters": [
"enable_nss_slurm",
@@ -180,7 +177,7 @@ def nodeset_dyn_lines(nodeset):
)


def partitionlines(partition, lkp: util.Lookup) -> str:
def partitionlines(partition: Partition, lkp: util.Lookup) -> str:
"""Make a partition line for the slurm.conf"""
MIN_MEM_PER_CPU = 100

@@ -192,32 +189,23 @@ def defmempercpu(nodeset_name: str) -> int:
return max(MIN_MEM_PER_CPU, (machine.memory - mem_spec_limit) // machine.cpus)

    defmem = min(
        map(defmempercpu, partition.partition_nodeset), default=MIN_MEM_PER_CPU
        map(defmempercpu, partition.nodesets), default=MIN_MEM_PER_CPU
    )

    nodesets = list(
        chain(
            partition.partition_nodeset,
            partition.partition_nodeset_dyn,
            partition.partition_nodeset_tpu,
        )
    )

    is_tpu = len(partition.partition_nodeset_tpu) > 0
    is_dyn = len(partition.partition_nodeset_dyn) > 0
    nodesets = list(chain(partition.nodesets, partition.nodesets_dyn, partition.nodesets_tpu))

oversub_exlusive = partition.enable_job_exclusive or is_tpu
power_down_on_idle = partition.enable_job_exclusive and not is_dyn
oversub_exlusive = partition.enable_job_exclusive or partition.is_tpu
power_down_on_idle = partition.enable_job_exclusive and not partition.any_dynamic

line_elements = {
"PartitionName": partition.partition_name,
"PartitionName": partition.name,
"Nodes": ",".join(nodesets),
"State": "UP",
"DefMemPerCPU": defmem,
"SuspendTime": 300,
"Oversubscribe": "Exclusive" if oversub_exlusive else None,
"PowerDownOnIdle": "YES" if power_down_on_idle else None,
**partition.partition_conf,
**partition.conf,
}

return dict_to_conf(line_elements)
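
dict_to_conf is defined elsewhere in conf.py and is not part of this diff; assuming it drops None values, comma-joins sequences, and emits space-separated KEY=value tokens, the hypothetical "debug" partition from the sketch above would render roughly like this:

def render_conf_line(elems: dict) -> str:
    # Assumed behavior of dict_to_conf, for illustration only.
    parts = []
    for key, val in elems.items():
        if val is None:
            continue  # e.g. Oversubscribe when the partition is not exclusive
        if isinstance(val, (list, tuple)):
            val = ",".join(map(str, val))
        parts.append(f"{key}={val}")
    return " ".join(parts)

# -> "PartitionName=debug Nodes=ns1 State=UP DefMemPerCPU=7552
#     SuspendTime=300 Oversubscribe=Exclusive PowerDownOnIdle=YES MaxTime=30:00"
# (DefMemPerCPU value is hypothetical; MaxTime comes from **partition.conf.)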
@@ -231,12 +219,8 @@ def suspend_exc_lines(lkp: util.Lookup) -> Iterable[str]:
static_nodelists.append(nodelist)
suspend_exc_nodes = {"SuspendExcNodes": static_nodelists}

dyn_parts = [
p.partition_name
for p in lkp.cfg.partitions.values()
if len(p.partition_nodeset_dyn) > 0
]
suspend_exc_parts = {"SuspendExcParts": [*dyn_parts]}
dyn_parts = [p.name for p in lkp.partitions if p.any_dynamic]
suspend_exc_parts = {"SuspendExcParts": dyn_parts}

return filter(
None,
@@ -255,7 +239,7 @@ def make_cloud_conf(lkp: util.Lookup) -> str:
*(nodeset_lines(n, lkp) for n in lkp.cfg.nodeset.values()),
*(nodeset_dyn_lines(n) for n in lkp.cfg.nodeset_dyn.values()),
*(nodeset_tpu_lines(n, lkp) for n in lkp.cfg.nodeset_tpu.values()),
*(partitionlines(p, lkp) for p in lkp.cfg.partitions.values()),
*(partitionlines(p, lkp) for p in lkp.partitions),
*(suspend_exc_lines(lkp)),
]
return "\n\n".join(filter(None, lines))
@@ -341,11 +325,7 @@ def install_cgroup_conf(lkp: util.Lookup) -> None:

def install_jobsubmit_lua(lkp: util.Lookup) -> None:
"""install job_submit.lua if there are tpu nodes in the cluster"""
if not any(
tpu_nodeset is not None
for part in lkp.cfg.partitions.values()
for tpu_nodeset in part.partition_nodeset_tpu
):
if not any(p.is_tpu for p in lkp.partitions):
return # No TPU partitions, no need for job_submit.lua

scripts_dir = lkp.cfg.slurm_scripts_dir or dirs.scripts
@@ -17,12 +17,12 @@
import argparse
import util
import tpu
from base import Partition


def get_vmcount_of_tpu_part(part):
def get_vmcount_of_part(part: Partition):
res = 0
lkp = util.lookup()
for ns in lkp.cfg.partitions[part].partition_nodeset_tpu:
for ns in part.nodesets_tpu:
tpu_obj = tpu.TPU.make(ns, lkp)
if res == 0:
res = tpu_obj.vmcount
@@ -54,19 +54,19 @@ def get_vmcount_of_tpu_part(part):
vmcounts = []
    # valid == 0 means we are OK; otherwise it is set to one of the previously defined exit codes
valid = 0
for part in args.partitions.split(","):
if part not in util.lookup().cfg.partitions:
for part_name in args.partitions.split(","):
try:
part = util.lookup().partition(part_name)
        except KeyError:
valid = PART_INVALID
break
else:
if util.lookup().partition_is_tpu(part):
vmcount = get_vmcount_of_tpu_part(part)
if vmcount == -1:
valid = DIFF_VMCOUNTS_SAME_PART
break
vmcounts.append(vmcount)
else:
vmcounts.append(0)
vmcount = get_vmcount_of_part(part)
if vmcount == -1:
valid = DIFF_VMCOUNTS_SAME_PART
break
vmcounts.append(vmcount)

# this means that there are different vmcounts for these partitions
if valid == 0 and len(set(vmcounts)) != 1:
valid = DIFF_PART_DIFFERENT_VMCOUNTS
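Note the behavioral equivalence here: the old code appended an explicit 0 for non-TPU partitions, while get_vmcount_of_part now returns 0 for them implicitly, because their nodesets_tpu list is empty and the loop body never runs:

from base import Partition

p = Partition(name="cpu_only", enable_job_exclusive=False)
assert p.nodesets_tpu == []   # default_factory=list
# => get_vmcount_of_part(p) leaves res at 0 and returns it,
#    matching the old vmcounts.append(0) branch.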
@@ -240,7 +240,7 @@ def group_nodes_bulk(nodes: List[str], resume_data: Optional[ResumeData], lkp: u

# expand all exclusive job nodelists
for job in resume_data.jobs:
if not lkp.cfg.partitions[job.partition].enable_job_exclusive:
if not lkp.partition(job.partition).enable_job_exclusive:
continue

groups[job.job_id] = []
@@ -40,13 +40,6 @@ class TstNodeset:
enable_placement: bool = True
placement_max_distance: Optional[int] = None

@dataclass
class TstPartition:
partition_name: str = "euler"
partition_nodeset: list[str] = field(default_factory=list)
partition_nodeset_tpu: list[str] = field(default_factory=list)
enable_job_exclusive: bool = False

@dataclass
class TstCfg:
slurm_cluster_name: str = "m22"
@@ -20,7 +20,7 @@
import unittest
import tempfile

from common import TstCfg, TstNodeset, TstPartition, TstTPU # needed to import util
from common import TstCfg, TstNodeset, TstTPU # needed to import util
import util
import resume
from resume import ResumeData, ResumeJobData, BulkChunk, PlacementAndNodes
@@ -74,11 +74,11 @@ def test_group_nodes_bulk(mock_create_placements, mock_tpu):
"t": TstNodeset(nodeset_name="t"),
},
partitions={
"p1": TstPartition(
"p1": dict(
partition_name="p1",
enable_job_exclusive=True,
),
"p2": TstPartition(
"p2": dict(
partition_name="p2",
partition_nodeset_tpu=["t"],
enable_job_exclusive=True,
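With TstPartition gone, the test fixtures hold plain dicts, so the tests exercise the same deserialization path as production code; conceptually:

from base import Partition

# The "p2" fixture above, deserialized the way util.Lookup.partition() does:
p2 = Partition.from_json(
    dict(partition_name="p2", nodesets_tpu=["t"], enable_job_exclusive=True)
)
assert p2.is_tpu and p2.enable_job_exclusive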
@@ -61,7 +61,7 @@

import yaml # noqa: E402
from addict import Dict as NSDict # noqa: E402

from base import Partition

USER_AGENT = "Slurm_GCP_Scripts/1.5 (GPN:SchedMD)"
ENV_CONFIG_YAML = os.getenv("SLURM_CONFIG_YAML")
@@ -536,8 +536,7 @@ def _assemble_config(
# add partition configs
for p_yaml in partitions:
p_cfg = NSDict(p_yaml)
assert p_cfg.get("partition_name"), "partition_name is required"
p_name = p_cfg.partition_name
        p_name = Partition.from_json(p_cfg).name  # also serves as a de-serialization check
assert p_name not in cfg.partitions, f"partition {p_name} already defined"
cfg.partitions[p_name] = p_cfg

@@ -1314,6 +1313,15 @@ def hostname_fqdn(self):
def zone(self):
return instance_metadata("zone")


@property
def partitions(self) -> List[Partition]:
        return [Partition.from_json(jo) for jo in self.cfg.partitions.values()]

def partition(self, name: str) -> Partition:
return Partition.from_json(self.cfg.partitions[name])


node_desc_regex = re.compile(
r"^(?P<prefix>(?P<cluster>[^\s\-]+)-(?P<nodeset>\S+))-(?P<node>(?P<suffix>\w+)|(?P<range>\[[\d,-]+\]))$"
)
@@ -1351,11 +1359,6 @@ def node_nodeset(self, node_name=None):

return self.cfg.nodeset[nodeset_name]

def partition_is_tpu(self, part: str) -> bool:
"""check if partition with name part contains a nodeset of type tpu"""
return len(self.cfg.partitions[part].partition_nodeset_tpu) > 0


def node_is_tpu(self, node_name=None):
nodeset_name = self.node_nodeset_name(node_name)
return self.cfg.nodeset_tpu.get(nodeset_name) is not None
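Taken together, the two new accessors replace the ad-hoc walks over lkp.cfg.partitions seen in the removed code; a sketch of the call sites this commit enables ("debug" is a hypothetical partition name):

import util

lkp = util.lookup()

# All partitions as typed objects; the property rebuilds them from cfg
# on each access, so treat the result as a throwaway snapshot.
for p in lkp.partitions:
    print(p.name, p.is_tpu, p.any_dynamic)

# A single partition by name:
p = lkp.partition("debug")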
