Merge pull request #77 from fact-project/slurm
Updates for the new cluster at isdc
maxnoe authored Feb 12, 2019
2 parents 97daf62 + 05d7feb commit 10e1c9c
Showing 20 changed files with 358 additions and 357 deletions.
3 changes: 3 additions & 0 deletions erna.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+erna_automatic_processing_executor
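The wrapper script just launches the executor entry point. Presumably the submitter hands it to SLURM under a job name of the form erna_<id>, which is what scancel --jobname in __main__.py and SLURM_JOB_NAME in executor.py below rely on. A sketch of such a submission call; the sbatch invocation is an assumption for illustration, not part of this commit:

import subprocess as sp

job_id = 1234  # hypothetical Job.id from the processing database

# Only the erna_<id> naming convention is implied by this commit
# (see scancel --jobname and SLURM_JOB_NAME below).
sp.run(['sbatch', '--job-name=erna_{}'.format(job_id), 'erna.sh'], check=True)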
20 changes: 7 additions & 13 deletions erna/automatic_processing/__main__.py
@@ -28,9 +28,10 @@ def main(config, verbose):
        logging.getLogger('erna').setLevel(logging.DEBUG)

    stream_handler = logging.StreamHandler()
-    file_handler = logging.FileHandler(config['submitter']['logfile'])
+    file_handler = logging.FileHandler(config['submitter'].pop('logfile'))
+    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
-        '%(asctime)s|%(levelname)s|%(name)s|%(message)s'
+        '%(asctime)s|%(levelname)s|%(message)s'
    )

    for handler in (stream_handler, file_handler):
@@ -44,16 +45,7 @@ def main(config, verbose):
    database.close()

    job_monitor = JobMonitor(port=config['submitter']['port'])
-    job_submitter = JobSubmitter(
-        interval=config['submitter']['interval'],
-        max_queued_jobs=config['submitter']['max_queued_jobs'],
-        data_directory=config['submitter']['data_directory'],
-        host=config['submitter']['host'],
-        port=config['submitter']['port'],
-        group=config['submitter']['group'],
-        mail_address=config['submitter']['mail_address'],
-        mail_settings=config['submitter']['mail_settings'],
-    )
+    job_submitter = JobSubmitter(**config['submitter'])

    log.info('Starting main loop')
    try:
@@ -68,13 +60,15 @@ def main(config, verbose):
    job_submitter.terminate()
    job_submitter.join()
    log.info('Clean up running jobs')

+    database.connect()
+
    queued = ProcessingState.get(description='queued')
    running = ProcessingState.get(description='running')
    inserted = ProcessingState.get(description='inserted')

    for job in Job.select().where((Job.status == running) | (Job.status == queued)):
-        sp.run(['qdel', 'erna_{}'.format(job.id)])
+        sp.run(['scancel', '--jobname=erna_{}'.format(job.id)])
        job.status = inserted
        job.save()
    database.close()
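Passing **config['submitter'] means every remaining key in that section must match a JobSubmitter parameter exactly, which is why logfile is popped for the file handler first. A minimal sketch of the pattern with a stand-in class; the keys come from the old explicit call above, the values are made up:

class JobSubmitter:
    # Stand-in with the same keywords as the old explicit call.
    def __init__(self, interval, max_queued_jobs, data_directory,
                 host, port, group, mail_address, mail_settings):
        self.interval = interval

submitter_config = {
    'logfile': '/var/log/erna.log',  # consumed by the FileHandler, not a parameter
    'interval': 60,
    'max_queued_jobs': 200,
    'data_directory': '/fact/erna',
    'host': 'localhost',
    'port': 12700,
    'group': 'fact',
    'mail_address': 'erna@example.org',
    'mail_settings': {},
}

logfile = submitter_config.pop('logfile')         # remove non-parameter keys first
job_submitter = JobSubmitter(**submitter_config)  # remaining keys map 1:1 onto kwargs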
2 changes: 1 addition & 1 deletion erna/automatic_processing/custom_fields.py
@@ -7,7 +7,7 @@


class NightField(Field):
-    db_field = 'night'
+    db_field = 'integer'

    def db_value(self, value):
        return date_to_night_int(value)
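With the custom field-type mapping gone from database.py (see below), db_field now has to name a type the plain MySQLDatabase already knows, hence 'integer'. The field still stores dates as FACT night integers; a sketch of that convention, not necessarily the project's exact helper:

from datetime import date

def date_to_night_int(night):
    # FACT convention: the night of 2019-02-12 is stored as 20190212.
    return 10000 * night.year + 100 * night.month + night.day

assert date_to_night_int(date(2019, 2, 12)) == 20190212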
45 changes: 6 additions & 39 deletions erna/automatic_processing/database.py
@@ -2,10 +2,8 @@
    Model, CharField, IntegerField, BooleanField,
    ForeignKeyField, FixedCharField, TextField, MySQLDatabase
)
-from playhouse.shortcuts import RetryOperationalError
import os
import logging
-import wrapt

from .utils import parse_path
from .custom_fields import NightField, LongBlobField
@@ -14,7 +12,7 @@

__all__ = [
    'RawDataFile', 'DrsFile',
-    'Jar', 'XML', 'Job', 'Queue',
+    'Jar', 'XML', 'Job',
    'ProcessingState',
    'database', 'setup_database',
]
@@ -31,22 +29,8 @@
    'walltime_exceeded',
]

-WALLTIMES = {
-    'fact_short': 60 * 60,
-    'fact_medium': 6 * 60 * 60,
-    'fact_long': 7 * 24 * 60 * 60,
-}
-
-
-class RetryMySQLDatabase(RetryOperationalError, MySQLDatabase):
-    ''' Automatically reconnect when connection went down '''
-    pass
-
-
-database = RetryMySQLDatabase(None, fields={
-    'night': 'INTEGER',
-    'longblob': 'LONGBLOB',
-})
+database = MySQLDatabase(None)


def setup_database(database, drop=False):
@@ -71,9 +55,6 @@ def setup_database(database, drop=False):
    for description in PROCESSING_STATES:
        ProcessingState.get_or_create(description=description)

-    for name, walltime in WALLTIMES.items():
-        Queue.get_or_create(name=name, walltime=walltime)
-


class File(Model):
    night = NightField()
@@ -85,9 +66,9 @@ class Meta:
        database = database
        indexes = ((('night', 'run_id'), True), )  # unique index

-    def get_path(self):
+    def get_path(self, basepath='/fact/raw'):
        return os.path.join(
-            '/fact/raw',
+            basepath,
            str(self.night.year),
            '{:02d}'.format(self.night.month),
            '{:02d}'.format(self.night.day),
@@ -166,25 +147,16 @@ def __repr__(self):
        return '{}'.format(self.description)


-class Queue(Model):
-    name = CharField(unique=True)
-    walltime = IntegerField()
-
-    class Meta:
-        database = database
-        db_table = 'queues'
-
-
class Job(Model):
    raw_data_file = ForeignKeyField(RawDataFile, related_name='raw_data_file')
    drs_file = ForeignKeyField(DrsFile, related_name='drs_file')
    jar = ForeignKeyField(Jar, related_name='jar')
    result_file = CharField(null=True)
    status = ForeignKeyField(ProcessingState, related_name='status')
    priority = IntegerField(default=5)
+    walltime = IntegerField(default=180)
    xml = ForeignKeyField(XML)
    md5hash = FixedCharField(32, null=True)
-    queue = ForeignKeyField(Queue, related_name='queue')

    class Meta:
        database = database
@@ -193,10 +165,5 @@ class Meta:
            (('raw_data_file', 'jar', 'xml'), True),  # unique constraint
        )

-MODELS = [RawDataFile, DrsFile, Jar, XML, Job, ProcessingState, Queue]
-
-
-@wrapt.decorator
-def requires_database_connection(wrapped, instance, args, kwargs):
-    database.get_conn()
-    return wrapped(*args, **kwargs)
+MODELS = [RawDataFile, DrsFile, Jar, XML, Job, ProcessingState]
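MySQLDatabase(None) is peewee's deferred-initialization pattern: models can be declared against the database immediately, while the real name and credentials are supplied later via init(). The connection_context() decorator used throughout database_utils.py below then opens a connection around each call and closes it afterwards. A minimal sketch with placeholder credentials, not this project's real config:

from peewee import Model, CharField, MySQLDatabase

database = MySQLDatabase(None)  # deferred: name and credentials come later

class ProcessingState(Model):
    description = CharField(unique=True)

    class Meta:
        database = database

# Placeholder values; in erna these would come from the config file.
database.init('erna', host='localhost', user='erna', password='secret')

@database.connection_context()  # connect on entry, close on exit
def count_states():
    return ProcessingState.select().count()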
49 changes: 21 additions & 28 deletions erna/automatic_processing/database_utils.py
@@ -6,7 +6,7 @@
from .database import (
    RawDataFile, DrsFile, Job,
    ProcessingState, Jar, XML,
-    requires_database_connection
+    database
)


@@ -19,7 +19,7 @@
]


-@requires_database_connection
+@database.connection_context()
def fill_data_runs(df, database):
    if len(df) == 0:
        return
@@ -46,12 +45,11 @@ def fill_data_runs(df, database):
    database.execute_sql(sql, params=params)


-@requires_database_connection
+@database.connection_context()
def fill_drs_runs(df, database):
    if len(df) == 0:
        return
    df = df.copy()
-    print(df.columns)
    df.drop(['fRunTypeKey', 'fRunTypeName'], axis=1, inplace=True)
    df.rename(
        columns={
@@ -74,7 +73,7 @@ def fill_drs_runs(df, database):
    database.execute_sql(sql, params=params)


-@requires_database_connection
+@database.connection_context()
def get_pending_jobs(limit=None):
    runs = (
        Job
@@ -90,7 +89,7 @@ def get_pending_jobs(limit=None):
    return runs


-@requires_database_connection
+@database.connection_context()
def find_drs_file(raw_data_file, closest=True):
    '''
    Find a drs file for the given raw data file.
@@ -99,7 +98,6 @@ def find_drs_file(raw_data_file, closest=True):
    '''
    query = DrsFile.select()
    query = query.where(DrsFile.night == raw_data_file.night)
-    query = query.where(DrsFile.available)

    if raw_data_file.roi == 300:
        query = query.where((DrsFile.drs_step == 2) & (DrsFile.roi == 300))
@@ -124,12 +122,12 @@ def find_drs_file(raw_data_file, closest=True):
    return drs_file


-@requires_database_connection
+@database.connection_context()
def insert_new_job(
    raw_data_file,
    jar,
    xml,
-    queue,
+    walltime,
    priority=5,
    closest_drs_file=True,
):
@@ -144,8 +142,8 @@
        the fact-tools jar to use
    xml: XML
        the xml to use
-    queue: Queue
-        the queue to use
+    walltime: int
+        the walltime to use
    priority: int
        Priority for the Job. Lower numbers mean more important.
    closest_drs_file: bool
@@ -169,7 +167,7 @@
        raw_data_file=raw_data_file,
        drs_file=drs_file,
        jar=jar,
-        queue=queue,
+        walltime=walltime,
        status=ProcessingState.get(description='inserted'),
        priority=priority,
        xml=xml,
@@ -178,8 +176,8 @@
    job.save()


-@requires_database_connection
-def insert_new_jobs(raw_data_files, jar, xml, queue, progress=True, **kwargs):
+@database.connection_context()
+def insert_new_jobs(raw_data_files, jar, xml, walltime, progress=True, **kwargs):

    if isinstance(raw_data_files, list):
        total = len(raw_data_files)
@@ -189,7 +187,7 @@ def insert_new_jobs(raw_data_files, jar, xml, queue, progress=True, **kwargs):
    failed_files = []
    for f in tqdm(raw_data_files, total=total, disable=not progress):
        try:
-            insert_new_job(f, jar=jar, xml=xml, queue=queue, **kwargs)
+            insert_new_job(f, jar=jar, xml=xml, walltime=walltime, **kwargs)
        except peewee.IntegrityError:
            log.warning('Job already submitted: {}_{:03d}'.format(f.night, f.run_id))
        except ValueError as e:
@@ -200,7 +198,7 @@ def insert_new_jobs(raw_data_files, jar, xml, queue, progress=True, **kwargs):
    return failed_files


-@requires_database_connection
+@database.connection_context()
def count_jobs(state=None):
    query = Job.select()

Expand All @@ -211,7 +209,7 @@ def count_jobs(state=None):
return query.count()


@requires_database_connection
@database.connection_context()
def save_xml(xml_id, data_dir):
if not os.path.exists(data_dir):
os.makedirs(data_dir)
@@ -234,7 +232,7 @@ def save_xml(xml_id, data_dir):
    return xml_file


-@requires_database_connection
+@database.connection_context()
def save_jar(jar_id, data_dir):
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
@@ -256,7 +254,7 @@ def save_jar(jar_id, data_dir):
    return jar_file


-@requires_database_connection
+@database.connection_context()
def build_output_directory_name(job, output_base_dir):
    version = Jar.select(Jar.version).where(Jar.id == job.jar_id).get().version
    return os.path.join(
@@ -269,7 +267,7 @@ def build_output_directory_name(job, output_base_dir):
    )


-@requires_database_connection
+@database.connection_context()
def build_output_base_name(job):
    version = Jar.select(Jar.version).where(Jar.id == job.jar_id).get().version
    return '{night:%Y%m%d}_{run_id:03d}_{version}_{name}'.format(
@@ -280,18 +278,13 @@ def build_output_base_name(job):
    )


-@requires_database_connection
-def resubmit_walltime_exceeded(old_queue, new_queue):
+@database.connection_context()
+def resubmit_walltime_exceeded(factor=1.5):
    '''
    Resubmit jobs where walltime was exceeded.
-    Change queue from old_queue to new_queue
    '''
-    if old_queue.walltime >= new_queue.walltime:
-        raise ValueError('New queue must have longer walltime for this to make sense')
-
    return (
        Job
-        .update(queue=new_queue, status=ProcessingState.get(description='inserted'))
+        .update(walltime=factor * Job.walltime, status=ProcessingState.get(description='inserted'))
        .where(Job.status == ProcessingState.get(description='walltime_exceeded'))
-        .where(Job.queue == old_queue)
    ).execute()
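With queues gone, resubmission now scales each failed job's own walltime in place: peewee turns factor * Job.walltime into an SQL expression, so the whole update runs inside the database and returns the number of affected rows. An illustrative call:

from erna.automatic_processing.database_utils import resubmit_walltime_exceeded

# Roughly: UPDATE job SET walltime = 2.0 * walltime, status = <inserted>
#          WHERE status = <walltime_exceeded>
n_resubmitted = resubmit_walltime_exceeded(factor=2.0)
print('Resubmitted {} jobs with doubled walltime'.format(n_resubmitted))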
17 changes: 10 additions & 7 deletions erna/automatic_processing/executor.py
@@ -1,5 +1,5 @@
import time
-start_time = time.perf_counter()
+start_time = time.monotonic()

import subprocess as sp
import os
@@ -17,8 +17,11 @@
socket = context.socket(zmq.REQ)

log = logging.getLogger('erna')
-log.setLevel(logging.DEBUG)
-logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
+log.setLevel(logging.INFO)
+handler = logging.StreamHandler(sys.stdout)
+fmt = logging.Formatter(fmt='%(asctime)s [%(levelname)-8s] %(message)s')
+handler.setFormatter(fmt)
+logging.getLogger().addHandler(handler)


def main():
@@ -28,7 +31,7 @@ def main():
    port = os.environ['SUBMITTER_PORT']
    socket.connect('tcp://{}:{}'.format(host, port))

-    job_id = int(os.environ['JOB_NAME'].replace('erna_', ''))
+    job_id = int(os.environ['SLURM_JOB_NAME'].replace('erna_', ''))

    socket.send_pyobj({'job_id': job_id, 'status': 'running'})
    socket.recv()
@@ -77,8 +80,8 @@ def main():
        sp.run(['free', '-m'], check=True)
        sp.run([java, '-Xmx512m', '-version'], check=True)

-        log.info('Calling fact-tools with call: {}'.format(call))
-        timeout = walltime - (time.perf_counter() - start_time) - 300
+        log.info('Calling fact-tools with call: "{}"'.format(' '.join(call)))
+        timeout = walltime - (time.monotonic() - start_time) - 300
        log.info('Setting fact-tools timeout to %.0f', timeout)
        sp.run(call, cwd=tmp_dir, check=True, timeout=timeout)
    except sp.CalledProcessError:
@@ -88,7 +91,7 @@ def main():
        sys.exit(1)
    except sp.TimeoutExpired:
        socket.send_pyobj({'job_id': job_id, 'status': 'walltime_exceeded'})
-        log.exception('FACT Tools about to run into wall-time, terminating')
+        log.error('FACT Tools about to run into wall-time, terminating')
        socket.recv()
        sys.exit(1)

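The executor gives fact-tools whatever wall-clock budget remains after startup, minus a fixed 300 s margin so status reporting and cleanup still fit inside the SLURM walltime; time.monotonic() is the idiomatic clock for such elapsed-time measurements since it can never jump backwards. A sketch of the calculation with invented numbers; only the formula itself comes from the diff above:

import time

start_time = time.monotonic()  # taken at process start, as in the executor

def fact_tools_timeout(walltime, margin=300):
    # Seconds left for fact-tools: the requested walltime minus the time
    # already spent on setup, minus a safety margin for cleanup.
    return walltime - (time.monotonic() - start_time) - margin

# A job with a 6 h walltime that spent 90 s on setup gets about
# 21600 - 90 - 300 = 21210 seconds.
print(fact_tools_timeout(6 * 3600))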