# slurm-launch.py
# Fill in the Slurm batch template, save the result under bash_scripts/ and
# submit it with sbatch.
# Usage:
# python slurm-launch.py --exp-name test \
#     --command "rllib train --run PPO --env CartPole-v0"
5
6import argparse
7import subprocess
8import sys
9import time
10
11from pathlib import Path
12
# Path of the sbatch template that this launcher fills in.
template_file = "/lustre/awikner1/res-noise-stabilization/slurm-template.sh"

# Placeholder tokens searched for verbatim in the template text. Every
# constant's value is its own name wrapped in "${...}", so the tokens are
# generated from the names (listed in the same order on both sides).
(JOB_NAME, NUM_NODES, NUM_GPUS_PER_NODE, PARTITION_OPTION, COMMAND_PLACEHOLDER,
 GIVEN_NODE, LOAD_ENV, RUNTIME, ACCOUNT, MEMORY, CPUS, PARTITION, SCRATCH,
 IFRAY) = ("${" + _token + "}" for _token in (
     "JOB_NAME", "NUM_NODES", "NUM_GPUS_PER_NODE", "PARTITION_OPTION",
     "COMMAND_PLACEHOLDER", "GIVEN_NODE", "LOAD_ENV", "RUNTIME", "ACCOUNT",
     "MEMORY", "CPUS", "PARTITION", "SCRATCH", "IFRAY"))
28
def _build_parser():
    """Build the command-line parser for the launcher.

    Returns:
        argparse.ArgumentParser: parser for every launcher option.
    """
    parser = argparse.ArgumentParser(
        description="Fill in the Slurm template and submit it with sbatch.")
    parser.add_argument(
        "--ifray",
        "-r",
        type=str,
        default="true",
        help="Flag whether or not to use ray for parallelization.")
    parser.add_argument(
        "--runtime",
        "-t",
        type=str,
        default="15:00",
        help="Total runtime for the function.")
    parser.add_argument(
        "--account",
        "-A",
        type=str,
        default="physics-hi",
        help="Account to charge cluster time to.")
    parser.add_argument(
        "--exp-name",
        type=str,
        required=True,
        help="The job name and path to logging file (exp_name.log).")
    parser.add_argument(
        "--num-nodes",
        "-n",
        type=int,
        default=1,
        help="Number of nodes to use.")
    parser.add_argument(
        "--node",
        "-w",
        type=str,
        help="The specified nodes to use. Same format as the "
        "return of 'sinfo'. Default: ''.")
    parser.add_argument(
        "--num-gpus",
        type=int,
        default=0,
        help="Number of GPUs to use in each node. (Default: 0)")
    parser.add_argument(
        "--partition",
        "-p",
        type=str,
        help="Slurm partition to submit to. Default: no partition line.")
    parser.add_argument(
        "--load-env",
        type=str,
        help="The script to load your environment ('module load cuda/10.1')")
    parser.add_argument(
        "--command",
        type=str,
        required=True,
        help="The command you wish to execute. For example: "
        " --command 'python test.py'. "
        "Note that the command must be a string.")
    parser.add_argument(
        "--cpus-per-node",
        type=int,
        default=20,
        help="The minimum number of CPUs per node.")
    parser.add_argument(
        "--tmp",
        type=int,
        default=10240,
        help="The minimum amount of scratch space per node.")
    parser.add_argument(
        "--memory",
        type=int,
        default=128000,
        help="Value substituted for the template's ${MEMORY} placeholder. "
        "(Default: 128000)")
    parser.add_argument(
        "--template",
        type=str,
        default=None,
        help="Path to the sbatch template. Default: the module-level "
        "template_file path.")
    return parser


def _fill_template(text, replacements):
    """Return *text* with each ``(placeholder, value)`` pair substituted.

    Pairs are applied one after another in the given order, so a value that
    itself contains a later placeholder is rewritten by that later pair.

    Args:
        text: template contents.
        replacements: iterable of ``(placeholder, value)`` string pairs.

    Returns:
        str: the substituted text.
    """
    for placeholder, value in replacements:
        text = text.replace(placeholder, value)
    return text


def main(argv=None):
    """Render the sbatch template for one experiment, save it and submit it.

    Args:
        argv: argument list to parse; ``None`` parses ``sys.argv[1:]``.

    Returns:
        int: 0 when sbatch accepted the job; otherwise sbatch's non-zero
        return code, or 1 if sbatch could not be launched at all.
    """
    args = _build_parser().parse_args(argv)

    node_info = "#SBATCH -w {}".format(args.node) if args.node else ""

    job_name = "{}_{}".format(args.exp_name,
                              time.strftime("%m%d-%H%M%S", time.localtime()))
    if args.ifray == "false":
        # Without ray, force a single-node, single-CPU allocation.
        args.num_nodes = 1
        args.cpus_per_node = 1
        print('Number of nodes and cpus set to 1 due to ifray=false')

    partition_option = ("#SBATCH --partition={}".format(args.partition)
                        if args.partition else "")

    # ===== Fill in the template script =====
    with open(args.template or template_file, "r") as f:
        text = f.read()
    # NOTE(review): PARTITION ("${PARTITION}") is defined at module level but
    # never substituted here -- confirm the template does not rely on it.
    text = _fill_template(text, [
        (JOB_NAME, job_name),
        (NUM_NODES, str(args.num_nodes)),
        (NUM_GPUS_PER_NODE, str(args.num_gpus)),
        (PARTITION_OPTION, partition_option),
        (COMMAND_PLACEHOLDER, str(args.command)),
        # With no --load-env, insert nothing instead of the word "None"
        # (which the generated script would presumably try to execute).
        (LOAD_ENV, args.load_env or ""),
        (GIVEN_NODE, node_info),
        (RUNTIME, args.runtime),
        (ACCOUNT, args.account),
        (MEMORY, str(args.memory)),
        (CPUS, str(args.cpus_per_node)),
        (SCRATCH, str(args.tmp)),
        (IFRAY, str(args.ifray)),
        ("# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO "
         "PRODUCTION!",
         "# THIS FILE IS MODIFIED AUTOMATICALLY FROM TEMPLATE AND SHOULD BE "
         "RUNNABLE!"),
    ])

    # ===== Save the script =====
    script_file = "bash_scripts/{}.sh".format(job_name)
    # The output directory may not exist yet (e.g. a fresh checkout).
    Path(script_file).parent.mkdir(parents=True, exist_ok=True)
    with open(script_file, "w") as f:
        f.write(text)

    # ===== Submit the job =====
    print("Starting to submit job!")
    try:
        # run() waits for sbatch so its exit status can be checked; sbatch
        # returns as soon as the job has been handed to the controller.
        result = subprocess.run(["sbatch", script_file])
    except FileNotFoundError:
        print("Could not run sbatch (not found on PATH). Script file is at: "
              "<{}>".format(script_file), file=sys.stderr)
        return 1
    if result.returncode != 0:
        print("sbatch failed with exit code {}. Script file is at: <{}>".format(
            result.returncode, script_file), file=sys.stderr)
        return result.returncode
    print(
        "Job submitted! Script file is at: <{}>. Log file is at: <{}>".format(
            script_file, "log_files/{}.log".format(job_name)))
    return 0


if __name__ == "__main__":
    sys.exit(main())