#!/usr/bin/env python
# coding: utf-8
# In[ ]:
"""Defines a task to train a Machine Learning calculator on a trajectory of snapshots
by calling the train() function of the MLWrapper"""
# # Main Routine
# In[ ]:
from esteem.trajectories import merge_traj, get_trajectory_list
import sys
import os
import string
from shutil import copyfile
class MLTrainingTask:
    """Task that trains a Machine Learning calculator on trajectory snapshots.

    The task merges a set of trajectory files into one training trajectory and
    passes it to the ``train()`` method of ``self.wrapper``.

    Every command-line option declared in :meth:`make_parser` is also an
    attribute of the task (``seed``, ``target``, ``calc_suffix``, ...),
    initialised to the option's default value.

    Other attributes:
        wrapper: object providing ``calc_filename()`` and ``train()``; must be
            assigned by the caller before :meth:`run` is invoked.
        script_settings: initialised to ``None``; not used by this class.
        task_command: the string ``'mltrain'``.
        train_params: dict of extra keyword arguments forwarded to
            ``wrapper.train()``.
    """

    def __init__(self, **kwargs):
        """Create a task whose option attributes hold the parser defaults.

        ``**kwargs`` is accepted for interface compatibility but is not used.
        """
        self.wrapper = None
        self.script_settings = None
        self.task_command = 'mltrain'
        self.train_params = {}
        # Parse an empty argument list so that each option yields its default
        defaults = self.make_parser().parse_args([])
        for name, value in vars(defaults).items():
            setattr(self, name, value)

    @staticmethod
    def _parse_bool(value):
        """Convert a command-line string such as 'False' or 'yes' to a bool.

        ``type=bool`` cannot be used with argparse because ``bool('False')``
        is ``True`` (any non-empty string is truthy).
        """
        import argparse
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', 'on', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got '{value}'")

    # Main routine
    def run(self):
        """Main routine for the ML_Training task.

        Merges the selected trajectory files into a single file and trains the
        calculator on it via ``self.wrapper.train()``.

        Returns:
            Whatever ``self.wrapper.train()`` returns.

        Raises:
            Exception: if ``which_trajs`` contains a name that is not in
                ``get_trajectory_list(self.ntraj)``, if the wrapper has an
                ``atom_energies`` attribute but the atom trajectory file is
                absent, or if any trajectory file to merge is missing or empty.
        """
        trajfn = self.wrapper.calc_filename(self.seed, self.target,
                                            prefix=self.calc_prefix,
                                            suffix=self.trajname_suffix)
        trajstem = self.wrapper.calc_filename(self.seed, self.target,
                                              prefix=self.calc_prefix, suffix="")

        # Default to every available trajectory; otherwise validate the selection
        all_trajs = get_trajectory_list(self.ntraj)
        if self.which_trajs is None:
            which_trajs = all_trajs
        else:
            which_trajs = self.which_trajs
            for trajname in which_trajs:
                if trajname not in all_trajs:
                    raise Exception("Invalid trajectory name: ", trajname)

        # If we need an atom trajectory, copy it from trajname_suffix to calc_suffix
        if hasattr(self.wrapper, 'atom_energies'):
            atom_traj_file = f'{self.calc_prefix}{self.seed}_atoms_{self.trajname_suffix}.traj'
            if not os.path.isfile(atom_traj_file):
                raise Exception(f'Trajectory file {atom_traj_file} not found for atom energies')
            atom_calc_file = f'{self.calc_prefix}{self.seed}_atoms_{self.calc_suffix}.traj'
            print(f'Copying from {atom_traj_file} to {atom_calc_file} for atom energies')
            copyfile(atom_traj_file, atom_calc_file)

        # Every trajectory must exist and be non-empty before merging
        trajnames = [f'{trajstem}{s}_{self.trajname_suffix}.traj' for s in which_trajs]
        print(f'Trajectories to merge: {trajnames}', flush=True)
        missing = [f for f in trajnames if not os.path.isfile(f)]
        if missing:
            raise Exception('Missing Trajectory files: ', missing)
        empty = [f for f in trajnames if os.path.getsize(f) == 0]
        if empty:
            raise Exception('Empty Trajectory file(s) found: ', empty)

        trajfile = f'{trajfn}_merged_{self.calc_suffix}.traj'
        merge_traj(trajnames, trajfile)

        # Train the ML calculator using this training data
        calc = self.wrapper.train(seed=self.seed, trajfile=trajfile, target=self.target,
                                  suffix=self.calc_suffix, restart=self.restart,
                                  **self.train_params)
        return calc

    def make_parser(self):
        """Build and return the ``argparse.ArgumentParser`` for this task."""
        import argparse

        # Parse command line values
        main_help = ('ML_Training.py: Train a ML-based Calculator from QMD trajectory files.')
        epi_help = ('')
        parser = argparse.ArgumentParser(description=main_help, epilog=epi_help)
        add = parser.add_argument
        add('--seed', '-s', type=str,
            help='Base name stem for the calculation (often the name of the molecule)')
        add('--calc_suffix', '-S', default="", type=str,
            help='Suffix for the calculator (often specifies ML hyperparameters)')
        add('--calc_prefix', '-P', default="", type=str,
            help='Prefix for the calculator (often specifies directory)')
        add('--target', '-t', default=0, type=int,
            help='Excited state index, zero for ground state')
        add('--trajname_suffix', '-T', default='training', type=str,
            help='Suffix for the trajectory on which to train the calculator')
        add('--energy_rmse', '-E', default=0.02, type=float,
            help='RMS Energy deviation for convergence')
        add('--force_rmse', '-F', default=0.02, type=float,
            help='RMS Force deviation for convergence')
        add('--energy_maxresid', '-G', default=None, type=float,
            help='Maximum energy deviation for convergence')
        add('--force_maxresid', '-H', default=None, type=float,
            help='Maximum force deviation for convergence')
        add('--hiddenlayers', '-L', default=(10, 10, 10), nargs='*', type=int,
            help='Hidden Layer structure ')
        add('--force_coefficient', '-f', default=0.04, type=float,
            help='Weighting of forces (compared to energy) in training')
        add('--overfit', '-o', default=0.00, type=float,
            help='Overfitting penalty weight (multiplier of the weights-norm term in the loss function)')
        add('--ntraj', '-n', default=1, type=int,
            help='How many total trajectories (A,B,C...) with this naming are present for training')
        # A bare --restart gives const=True; an explicit value goes through _parse_bool
        add('--restart', '-r', default=False, nargs='?', const=True, type=self._parse_bool,
            help='Whether to load a pre-existing calculator and resume training')
        add('--cores', '-c', default=1, type=int,
            help='Number of parallel cores on which to run the training')
        # argparse stores this option under the attribute name which_trajs
        add('--which-trajs', '-w', default=None, type=str,
            help='Which trajectories (A,B,C...) with this naming are to be trained against')
        add('--steps', '-A', default=None, type=int, help='Annealer steps')
        add('--Tmax', '-u', default=800.0, type=float, help='Annealer starting temperature')
        add('--Tmin', '-v', default=0.01, type=float, help='Annealer final temperature')
        add('--cutoff', '-d', default=6.5, type=float, help='Gaussian descriptor cutoff')
        return parser
def validate_args(args):
    """Check that every attribute of ``args`` is a recognised task option.

    Args:
        args: an ``argparse.Namespace`` (or any object supporting ``vars()``)
            whose attribute names are compared against the options declared
            by ``MLTrainingTask.make_parser()``.

    Raises:
        Exception: if ``args`` carries an attribute that is not a known option.
    """
    # make_parser is a method of MLTrainingTask, not a module-level function,
    # so the reference parser has to be obtained through a task instance.
    default_args = MLTrainingTask().make_parser().parse_args(['--seed', 'a'])
    for arg in vars(args):
        if arg not in default_args:
            raise Exception(f"Unrecognised argument '{arg}'")
# # Command-line driver
# In[ ]:
if __name__ == '__main__':
    from esteem import wrappers

    # Create the task and read the command-line values with its own parser
    task = MLTrainingTask()
    args = task.make_parser().parse_args()

    # Copy every parsed option onto the task, replacing the defaults
    for option, value in vars(args).items():
        setattr(task, option, value)
    print('#',args)

    # Attach the wrapper used for training, then run the main program
    task.wrapper = wrappers.amp.AMPWrapper()
    task.run()