Source code for esteem.tasks.ml_training

#!/usr/bin/env python
# coding: utf-8

# In[ ]:


"""Defines a task to train a Machine Learning calculator on a trajectory of snapshots
by calling the train() function of the MLWrapper"""


# # Main Routine

# In[ ]:


from esteem.trajectories import merge_traj, diff_traj, get_trajectory_list, targstr

import sys
import os
import string
from shutil import copyfile

class MLTrainingTask:
    """Task that trains a Machine Learning calculator on trajectory snapshots.

    The task merges the per-trajectory ``.traj`` files selected by
    ``which_trajs`` (and, optionally, ``which_trajs_valid``) into single
    training / validation files and then hands them to ``wrapper.train()``.

    Every option defined in :meth:`make_parser` is exposed as an instance
    attribute of the same name, initialised to the parser default.  The
    caller is expected to set ``wrapper`` (and usually ``seed``) before
    calling :meth:`run`.
    """

    def __init__(self, **kwargs):
        """Set all task options to their command-line defaults.

        ``**kwargs`` is accepted for call compatibility; it is not used.
        """
        self.wrapper = None
        self.script_settings = None
        self.task_command = 'mltrain'
        self.train_params = {}
        # Parsing an empty argument list yields a namespace of pure defaults
        args = self.make_parser().parse_args([])
        for arg, value in vars(args).items():
            setattr(self, arg, value)

    def get_trajnames(self, prefix=""):
        """Return the trajectory labels and matching ``.traj`` filenames.

        Args:
            prefix: if it contains ``'valid'`` the validation selection
                (``which_trajs_valid``) is used, otherwise ``which_trajs``.

        Returns:
            Tuple ``(which_trajs, trajnames)``: the selected labels (all
            labels from ``get_trajectory_list(ntraj)`` when no selection was
            made) and one filename per label.

        Raises:
            Exception: if a selected label is not among the known labels.
        """
        all_trajs = get_trajectory_list(self.ntraj)
        which_trajs = self.which_trajs_valid if 'valid' in prefix else self.which_trajs
        if which_trajs is None:
            which_trajs = all_trajs
        else:
            for trajname in which_trajs:
                if trajname not in all_trajs:
                    raise Exception(f"Invalid trajectory name: {trajname}")
        trajstem = self.wrapper.calc_filename(self.seed, self.target,
                                              prefix=self.calc_prefix, suffix="")
        trajnames = [f'{trajstem}{s}_{self.traj_suffix}.traj' for s in which_trajs]
        return which_trajs, trajnames

    def _make_diff_trajs(self, prefix=""):
        """Write energy-difference trajectories between targets 0 and 1.

        This step is disabled: :meth:`run` does not call it.  It is kept so
        the logic is not lost should training on differences be re-enabled.
        """
        itarget = 0
        jtarget = 1
        _, trajnames = self.get_trajnames(prefix)
        for traj in trajnames:
            itraj = traj.replace("diff", targstr(itarget))
            jtraj = traj.replace("diff", targstr(jtarget))
            print(f'# Calling diff_traj with {itraj} {jtraj} {traj}')
            diff_traj(itraj, jtraj, traj)

    # Main routine
    def run(self):
        """Main routine for the ML_Training task.

        Copies the atom-energy trajectory if the wrapper declares
        ``atom_energies``, merges the training (and optional validation)
        trajectories, optionally resets the wrapper's loss, then trains.

        Returns:
            Whatever ``wrapper.train()`` returns.

        Raises:
            Exception: if a required trajectory file is missing or empty, or
                if ``reset_loss`` is set but the wrapper cannot reset it.
        """
        trajfn = self.wrapper.calc_filename(self.seed, self.target,
                                            prefix=self.calc_prefix,
                                            suffix=self.traj_suffix)

        # If we need an atom trajectory, copy it from traj_suffix to calc_suffix
        if hasattr(self.wrapper, 'atom_energies'):
            atom_traj_file = f'{self.seed}_atoms_{self.traj_suffix}.traj'
            if not os.path.isfile(atom_traj_file):
                raise Exception(f'# Trajectory file {atom_traj_file} not found for atom energies')
            atom_calc_file = f'{self.seed}_atoms_{self.calc_suffix}.traj'
            print(f'# Copying from {atom_traj_file} to {atom_calc_file} for atom energies')
            copyfile(atom_traj_file, atom_calc_file)

        # A validation set is merged only when one has been requested
        prefs = [""] if self.which_trajs_valid is None else ["", "valid"]
        trajfile = None
        validfile = None
        for prefix in prefs:
            # NOTE: computing energy-difference trajectories (_make_diff_trajs)
            # is deliberately not done here; that step is disabled.
            _, trajnames = self.get_trajnames(prefix)
            print(f'# Trajectories to merge: {trajnames}', flush=True)

            # Existence is checked first so getsize() never sees a missing file
            missing = [f for f in trajnames if not os.path.isfile(f)]
            if missing:
                raise Exception(f'# Missing Trajectory files: {missing}')
            empty = [f for f in trajnames if os.path.getsize(f) == 0]
            if empty:
                raise Exception(f'# Empty Trajectory file(s) found: {empty}')

            merged = f'{trajfn}_{prefix}merged_{self.calc_suffix}.traj'
            merge_traj(trajnames, merged)
            if prefix == "valid":
                validfile = merged
            else:
                trajfile = merged

        # NOTE(review): the flattened source does not show whether this ran
        # inside the loop above; it is assumed to run once - confirm.
        if self.reset_loss:
            if not hasattr(self.wrapper, "reset_loss"):
                raise Exception("# Error: reset_loss == True, yet wrapper has no reset_loss function")
            self.wrapper.reset_loss(seed=self.seed, target=self.target,
                                    prefix=self.calc_prefix, suffix=self.calc_suffix)

        # Train the ML calculator using this training data
        calc = self.wrapper.train(seed=self.seed, trajfile=trajfile, validfile=validfile,
                                  target=self.target, prefix=self.calc_prefix,
                                  suffix=self.calc_suffix, dir_suffix=self.calc_dir_suffix,
                                  restart=self.restart, **self.train_params)
        return calc

    @staticmethod
    def _str2bool(value):
        """Convert a command-line token to bool (``bool('False')`` is True)."""
        import argparse

        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got '{value}'")

    @staticmethod
    def _str2dict(value):
        """Convert a command-line token holding a JSON object to a dict."""
        import argparse
        import json

        if isinstance(value, dict):
            return value
        if not value.strip():
            # dict('') gave {} before; keep that behaviour
            return {}
        try:
            parsed = json.loads(value)
        except ValueError as err:
            raise argparse.ArgumentTypeError(
                f"expected a JSON object, got '{value}'") from err
        if not isinstance(parsed, dict):
            raise argparse.ArgumentTypeError(f"expected a JSON object, got '{value}'")
        return parsed

    def make_parser(self):
        """Build the argument parser defining every option of this task.

        Returns:
            argparse.ArgumentParser: parser whose defaults are also used by
            ``__init__`` to initialise the instance attributes.
        """
        import argparse

        # Parse command line values
        main_help = ('ML_Training.py: Train a ML-based Calculator from QMD trajectory files.')
        epi_help = ('')
        parser = argparse.ArgumentParser(description=main_help, epilog=epi_help)
        parser.add_argument('--seed', '-s', type=str,
                            help='Base name stem for the calculation (often the name of the molecule)')
        parser.add_argument('--calc_suffix', '-S', default="", type=str,
                            help='Suffix for the calculator')
        parser.add_argument('--calc_dir_suffix', '-D', default=None, type=str,
                            help='Suffix for the calculator directory ')
        parser.add_argument('--calc_prefix', '-P', default="", type=str,
                            help='Prefix for the calculator (often specifies directory)')
        parser.add_argument('--target', '-t', default=0, type=int,
                            help='Excited state index, zero for ground state')
        parser.add_argument('--traj_prefix', '-Q', default='training', type=str,
                            help='Prefix for the trajectory on which to train the calculator')
        parser.add_argument('--traj_suffix', '-T', default='training', type=str,
                            help='Suffix for the trajectory on which to train the calculator')
        parser.add_argument('--geom_prefix', default='gs_PBE0/is_opt_{solv}', nargs='?', type=str,
                            help='Prefix for the path at which to find the input geometry')
        parser.add_argument('--ntraj', '-n', default=1, type=int,
                            help='How many total trajectories (A,B,C...) with this naming are present for training')
        # type=bool would turn any non-empty string (even "False") into True
        parser.add_argument('--restart', '-r', default=False, nargs='?', const=True,
                            type=self._str2bool,
                            help='Whether to load a pre-existing calculator and resume training')
        parser.add_argument('--reset_loss', '-R', default=False, nargs='?', const=True,
                            type=self._str2bool,
                            help='Whether to reset the loss function due to new training data being added')
        parser.add_argument('--which_trajs', '-w', default=None, type=str,
                            help='Which trajectories (A,B,C...) with this naming are to be trained against')
        parser.add_argument('--which_trajs_valid', '-v', default=None, type=str,
                            help='Which trajectories (A,B,C...) with this naming are to be validated against')
        parser.add_argument('--which_trajs_test', '-u', default=None, type=str,
                            help='Which trajectories (A,B,C...) with this naming are to be tested against')
        # type=dict cannot convert a command-line string; accept a JSON object
        parser.add_argument('--traj_links', '-L', default=None, type=self._str2dict,
                            help='Targets for links to create for training trajectories')
        parser.add_argument('--traj_links_valid', '-V', default=None, type=self._str2dict,
                            help='Targets for links to create for validation trajectories')
        parser.add_argument('--traj_links_test', '-U', default=None, type=self._str2dict,
                            help='Targets for links to create for testing trajectories')
        parser.add_argument('--cutoff', '-d', default=6.5, type=float,
                            help='Gaussian descriptor cutoff')
        # Legacy options, disabled.  Several of their short flags (-u, -v, -L)
        # clash with options defined above, so they cannot simply be re-enabled.
        # parser.add_argument('--cores','-c',default=1,type=int,help='Number of parallel cores on which to run the training')
        # parser.add_argument('--steps','-A',default=None,type=int,help='Annealer steps')
        # parser.add_argument('--Tmax','-u',default=800.0,type=float,help='Annealer starting temperature')
        # parser.add_argument('--Tmin','-v',default=0.01,type=float,help='Annealer final temperature')
        # parser.add_argument('--energy_rmse','-E',default=0.02,type=float,help='RMS Energy deviation for convergence')
        # parser.add_argument('--force_rmse','-F',default=0.02,type=float,help='RMS Force deviation for convergence')
        # parser.add_argument('--energy_maxresid','-G',default=None,type=float,help='Maximum energy deviation for convergence')
        # parser.add_argument('--force_maxresid','-H',default=None,type=float,help='Maximum force deviation for convergence')
        # parser.add_argument('--hiddenlayers','-L',default=(10,10,10),nargs='*',type=int,help='Hidden Layer structure ')
        # parser.add_argument('--force_coefficient','-f',default=0.04,type=float,help='Weighting of forces (compared to energy) in training')
        # parser.add_argument('--overfit','-o',default=0.00,type=float,help='Weighting of forces (compared to energy) in training')
        return parser

    @staticmethod
    def validate_args(args):
        """Check that every attribute of ``args`` is a known task option.

        NOTE(review): assumed to belong to the class; the flattened source
        does not show its indentation - confirm.

        Args:
            args: an ``argparse.Namespace`` (or any object usable with
                ``vars()``) whose attribute names are to be checked.

        Raises:
            Exception: if ``args`` has an attribute the parser does not define.
        """
        # '--seed' has no default, so supply one to build a reference namespace
        default_args = vars(MLTrainingTask().make_parser().parse_args(['--seed', 'a']))
        for arg in vars(args):
            if arg not in default_args:
                raise Exception(f"Unrecognised argument '{arg}'")
# # Command-line driver

# In[ ]:


def get_parser():
    """Return the argument parser of a freshly constructed MLTrainingTask."""
    return MLTrainingTask().make_parser()


if __name__ == '__main__':

    from esteem import wrappers

    mltrain = MLTrainingTask()

    # Parse command line values and copy each one onto the task instance
    args = mltrain.make_parser().parse_args()
    for arg, value in vars(args).items():
        setattr(mltrain, arg, value)
    print('#', args)

    # The command-line driver always trains through the AMP wrapper
    mltrain.wrapper = wrappers.amp.AMPWrapper()

    # Run main program
    mltrain.run()