res_reg_lmnt_awikner.slurm-launch

  1# slurm-launch.py
  2# Usage:
  3# python slurm-launch.py --exp-name test \
  4#     --command "rllib train --run PPO --env CartPole-v0"
  5
  6import argparse
  7import subprocess
  8import sys
  9import time
 10
 11from pathlib import Path
 12
 13template_file = "/lustre/awikner1/res-noise-stabilization/slurm-template.sh"
 14JOB_NAME = "${JOB_NAME}"
 15NUM_NODES = "${NUM_NODES}"
 16NUM_GPUS_PER_NODE = "${NUM_GPUS_PER_NODE}"
 17PARTITION_OPTION = "${PARTITION_OPTION}"
 18COMMAND_PLACEHOLDER = "${COMMAND_PLACEHOLDER}"
 19GIVEN_NODE = "${GIVEN_NODE}"
 20LOAD_ENV = "${LOAD_ENV}"
 21RUNTIME = "${RUNTIME}"
 22ACCOUNT = "${ACCOUNT}"
 23MEMORY  = "${MEMORY}"
 24CPUS    = "${CPUS}"
 25PARTITION = "${PARTITION}"
 26SCRATCH = "${SCRATCH}"
 27IFRAY   = "${IFRAY}"
 28
 29if __name__ == "__main__":
 30    parser = argparse.ArgumentParser()
 31    parser.add_argument(
 32        "--ifray",
 33        "-r",
 34        type=str,
 35        default="true",
 36        help="Flag whether or not to use ray for parallelization.")
 37    parser.add_argument(
 38        "--runtime",
 39        "-t",
 40        type=str,
 41        default="15:00",
 42        help="Total runtime for the function.")
 43    parser.add_argument(
 44        "--account",
 45        "-A",
 46        type=str,
 47        default="physics-hi",
 48        help="Account to charge cluster time to.")
 49    parser.add_argument(
 50        "--exp-name",
 51        type=str,
 52        required=True,
 53        help="The job name and path to logging file (exp_name.log).")
 54    parser.add_argument(
 55        "--num-nodes",
 56        "-n",
 57        type=int,
 58        default=1,
 59        help="Number of nodes to use.")
 60    parser.add_argument(
 61        "--node",
 62        "-w",
 63        type=str,
 64        help="The specified nodes to use. Same format as the "
 65        "return of 'sinfo'. Default: ''.")
 66    parser.add_argument(
 67        "--num-gpus",
 68        type=int,
 69        default=0,
 70        help="Number of GPUs to use in each node. (Default: 0)")
 71    parser.add_argument(
 72        "--partition",
 73        "-p",
 74        type=str,
 75    )
 76    parser.add_argument(
 77        "--load-env",
 78        type=str,
 79        help="The script to load your environment ('module load cuda/10.1')")
 80    parser.add_argument(
 81        "--command",
 82        type=str,
 83        required=True,
 84        help="The command you wish to execute. For example: "
 85        " --command 'python test.py'. "
 86        "Note that the command must be a string.")
 87    parser.add_argument(
 88        "--cpus-per-node",
 89        type = int,
 90        default = 20,
 91        help = "The minimum number of CPUs per node.")
 92    parser.add_argument(
 93        "--tmp",
 94        type = int,
 95        default = 10240,
 96        help = "The minimum amount of scratch space per node.")
 97    args = parser.parse_args()
 98
 99    if args.node:
100        # assert args.num_nodes == 1
101        node_info = "#SBATCH -w {}".format(args.node)
102    else:
103        node_info = ""
104
105    job_name = "{}_{}".format(args.exp_name,
106                              time.strftime("%m%d-%H%M%S", time.localtime()))
107    if str(args.ifray) == "false":
108        args.num_nodes = 1
109        args.cpus_per_node = 1
110        print('Number of nodes and cpus set to 1 due to ifray=false')
111
112    partition_option = "#SBATCH --partition={}".format(
113        args.partition) if args.partition else ""
114
115    memory = str(128000)
116    # ===== Modified the template script =====
117    with open(template_file, "r") as f:
118        text = f.read()
119    text = text.replace(JOB_NAME, job_name)
120    text = text.replace(NUM_NODES, str(args.num_nodes))
121    text = text.replace(NUM_GPUS_PER_NODE, str(args.num_gpus))
122    text = text.replace(PARTITION_OPTION, partition_option)
123    text = text.replace(COMMAND_PLACEHOLDER, str(args.command))
124    text = text.replace(LOAD_ENV, str(args.load_env))
125    text = text.replace(GIVEN_NODE, node_info)
126    text = text.replace(RUNTIME, args.runtime)
127    text = text.replace(ACCOUNT, args.account)
128    text = text.replace(MEMORY, memory)
129    text = text.replace(CPUS, str(args.cpus_per_node))
130    text = text.replace(SCRATCH, str(args.tmp))
131    text = text.replace(IFRAY, str(args.ifray))
132    text = text.replace(
133        "# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO "
134        "PRODUCTION!",
135        "# THIS FILE IS MODIFIED AUTOMATICALLY FROM TEMPLATE AND SHOULD BE "
136        "RUNNABLE!")
137
138    # ===== Save the script =====
139    script_file = "bash_scripts/{}.sh".format(job_name)
140    with open(script_file, "w") as f:
141        f.write(text)
142
143    # ===== Submit the job =====
144    print("Starting to submit job!")
145    subprocess.Popen(["sbatch", script_file])
146    print(
147        "Job submitted! Script file is at: <{}>. Log file is at: <{}>".format(
148            script_file, "log_files/{}.log".format(job_name)))
149    sys.exit(0)