# Source code for esteem.tasks.ml_testing

#!/usr/bin/env python
# coding: utf-8

# In[ ]:


"""Defines a task to test a Machine Learning calculator by comparing its results
to those of an existing trajectory or set of trajectories"""


# # Main Routine

# In[ ]:


# Load essential modules
import sys
import os
import string
from ase.io.trajectory import Trajectory

from esteem.trajectories import merge_traj, get_trajectory_list

class MLTestingTask:
    """Task that tests an ML calculator against existing trajectories.

    Settings are stored as instance attributes; their names and defaults are
    those of the options defined in :meth:`make_parser`.
    """

    # Attributes set by __init__ that are not command-line arguments; these
    # are ignored by validate_args so a task instance can be validated.
    _NON_CLI_ATTRIBUTES = ('wrapper', 'script_settings', 'task_command',
                           'train_params')

    def __init__(self, **kwargs):
        """Create a task populated with the default command-line settings.

        Args:
            **kwargs: attribute values (e.g. ``seed='mol'``) that override
                the defaults from :meth:`make_parser`.
        """
        self.wrapper = None
        self.script_settings = None
        self.task_command = 'mltest'
        self.train_params = {}
        # Give every command-line option its default value
        args = self.make_parser().parse_args([])
        for arg, value in vars(args).items():
            setattr(self, arg, value)
        # Explicit keyword arguments take precedence over the defaults
        for key, value in kwargs.items():
            setattr(self, key, value)

    # Main routine
    def run(self):
        """Main routine for the ML_Testing task.

        Builds the input trajectory filenames from the task settings, merges
        them into a single ``*_merged.traj`` file, loads the ML calculator via
        ``self.wrapper.load`` and calls :func:`compare_calc_to_traj` to write
        an output trajectory (and optional plot) comparing calculator and
        trajectory results. The merged input and the output trajectory files
        are left on disk.

        Side effect: if ``self.calc_seed`` is None it is set to ``self.seed``.

        Raises:
            Exception: if any input trajectory file is missing or empty.
            AttributeError: if ``self.wrapper`` has no ``load`` method.
        """
        if self.calc_seed is None:
            self.calc_seed = self.seed

        # trajfn names the merged/output files; trajstem names the inputs
        trajfn = self.wrapper.calc_filename(self.seed, self.target,
                                            prefix=self.calc_prefix,
                                            suffix=self.trajname_suffix)
        trajstem = self.wrapper.calc_filename(self.seed, self.target,
                                              prefix=self.calc_prefix,
                                              suffix="")
        all_trajs = get_trajectory_list(self.ntraj)
        trajname_suffix = self.trajname_suffix
        if self.which_trajs is None:
            which_trajs = all_trajs
        else:
            which_trajs = self.which_trajs
            # A value such as "AB_suffix" selects trajectories A and B and
            # overrides the trajectory filename suffix
            if "_" in which_trajs:
                which_trajs, trajname_suffix = which_trajs.split("_", 1)

        # Every requested trajectory must exist and be non-empty
        trajnames = [trajstem + s + '_' + trajname_suffix + '.traj'
                     for s in which_trajs]
        print('# Merging trajectories: ', trajnames)
        missing = [f for f in trajnames if not os.path.isfile(f)]
        if missing:
            raise Exception('# Missing Trajectory file(s): ', missing)
        empty = [f for f in trajnames if os.path.getsize(f) == 0]
        if empty:
            raise Exception('# Empty Trajectory file(s) found: ', empty)

        # Merge the inputs into the trajectory the calculator is tested on
        which_str = ''.join(which_trajs)
        intrajfile = trajfn + "_" + self.calc_suffix + "_" + which_str + '_merged.traj'
        if os.path.isfile(intrajfile):
            print(f'# Merged input trajectory file {intrajfile} already exists. Overwriting!')
        else:
            print(f'# Writing merged input trajectory file {intrajfile}')
        merge_traj(trajnames, intrajfile)

        # Load the calculator
        print("# Reading Calculator")
        if not hasattr(self.wrapper, 'load'):
            # Without this check `calc` would be unbound further down
            raise AttributeError(
                f"Wrapper {type(self.wrapper).__name__} has no 'load' method; "
                "cannot load the calculator to test")
        calc = self.wrapper.load(self.calc_seed, self.target,
                                 prefix=self.calc_prefix,
                                 suffix=self.calc_suffix)

        # Energy offset (constant or per-element) supplied by the wrapper
        e_offset = getattr(self.wrapper, 'atom_energies', 0.0)

        output_traj = self.output_traj
        if output_traj is None:
            output_traj = self.calc_suffix + "_" + which_str + "_test"
        outtrajfile = trajfn + "_" + output_traj + '.traj'
        if os.path.isfile(outtrajfile):
            print(f'# Warning: output trajectory file {outtrajfile} already exists. Overwriting!')
        else:
            print(f'# Writing output trajectory file {outtrajfile}')

        # Test the ML calculator against the results in the merged trajectory
        intraj = Trajectory(intrajfile)
        try:
            compare_calc_to_traj(calc, intraj, outtrajfile, self.plotfile, e_offset)
        finally:
            intraj.close()

    # Generate default arguments and return as parser
    def make_parser(self):
        """Return the argparse parser defining this task's options and defaults."""
        import argparse

        main_help = ('ML_Testing.py: Test an ML-based Calculator by comparing it to trajectories.')
        epi_help = ('')
        parser = argparse.ArgumentParser(description=main_help, epilog=epi_help)
        parser.add_argument('--seed', '-s', type=str,
                            help='Base name stem for the calculation (often the name of the molecule)')
        parser.add_argument('--calc_seed', '-Z', default=None, type=str,
                            help='Seed name for the calculator')
        parser.add_argument('--calc_suffix', '-S', default="", type=str,
                            help='Suffix for the calculator (often specifies ML hyperparameters)')
        parser.add_argument('--calc_prefix', '-P', default="", type=str,
                            help='Prefix for the calculator (often specifies directory)')
        parser.add_argument('--target', '-t', default=0, type=int,
                            help='Excited state index, zero for ground state')
        parser.add_argument('--output_traj', '-o', default=None, type=str,
                            help='Filename to which to write the calculated trajectory')
        parser.add_argument('--plotfile', '-p', default=None, nargs='?', const="TkAgg", type=str,
                            help='Image file to which to write comparison plot')
        parser.add_argument('--trajname_suffix', '-T', default='training', type=str,
                            help='Suffix for the trajectory files being tested')
        parser.add_argument('--ntraj', '-n', default=1, type=int,
                            help='How many total trajectories (A,B,C...) with this naming are present')
        parser.add_argument('--which-trajs', '-w', default=None, type=str,
                            help='Which trajectories (A,B,C...) with this naming are to be tested')
        return parser

    def validate_args(args):
        """Raise if ``args`` carries an attribute that is not a known option.

        Written without ``self``: call as ``MLTestingTask.validate_args(ns)``
        with an argparse Namespace, or as ``task.validate_args()`` (which binds
        the task itself as ``args``). Attributes listed in
        ``_NON_CLI_ATTRIBUTES`` are accepted.

        Raises:
            Exception: naming the first unrecognised attribute.
        """
        # make_parser ignores its `self`, so any object can be passed for it
        default_args = MLTestingTask.make_parser(args).parse_args(['--seed', 'a'])
        for arg in vars(args):
            if arg in MLTestingTask._NON_CLI_ATTRIBUTES:
                continue
            if arg not in default_args:
                raise Exception(f"Unrecognised argument '{arg}'")
# # Comparison of results from trajectory to calculator # In[1]: import numpy as np from esteem.trajectories import compare_traj_to_traj,atom_energy
def compare_calc_to_traj(calc, trajin, trajout_file, plot_file=None, e_offset=0.0):
    """Compare the energy and force predictions of a calculator to results
    in an existing trajectory.

    Each frame of ``trajin`` is re-evaluated with ``calc`` after
    ``frame.center(20)``. One line per frame is printed: frame index,
    trajectory energy, calculator energy, their difference, the RMS force
    error and the largest absolute force-component error. The re-evaluated
    frames, with calculator energy, forces, charges and dipole as available,
    are written to ``trajout_file``. That file is then compared with
    ``trajin`` via ``compare_traj_to_traj``, optionally plotting to
    ``plot_file``.

    Args:
        calc: calculator to test (attached to each frame as ``frame.calc``).
        trajin: input trajectory; it is iterated once here and passed on to
            ``compare_traj_to_traj``.
        trajout_file: path of the trajectory file to write.
        plot_file: passed through to ``compare_traj_to_traj``.
        e_offset: energy added to every calculator energy. Either a number,
            or a dict passed to ``atom_energy(frame, e_offset)`` to obtain a
            per-frame offset. Any other type applies no offset.

    Failures to compute forces, or charges and dipole, for a frame are
    reported on stdout. The missing quantities are then left out of the
    written frame, and the force errors are printed as ``nan``.
    KeyboardInterrupt is not caught.
    """
    trajout = Trajectory(trajout_file, 'w')
    try:
        for i, frame in enumerate(trajin):
            # Reference energy and forces stored in the input trajectory
            e_traj = frame.get_potential_energy()
            f_traj = frame.get_forces()

            frame.calc = calc
            # May need to set up a box, for some calculators (TODO: when?)
            frame.center(20)

            # Calculate total energy from calculator
            e_calc = frame.get_potential_energy()
            # If an array (e.g. of float32s) has been returned, convert to a
            # single float64
            if isinstance(e_calc, np.ndarray):
                e_calc = np.float64(np.ravel(e_calc)[0])

            # Apply offset energy: per-element references or a constant
            if isinstance(e_offset, dict):
                e_calc += atom_energy(frame, e_offset)
            elif isinstance(e_offset, (int, float, np.floating)):
                e_calc += e_offset

            # Attempt to store the offset energy in the calculator's results
            try:
                frame.calc.results["energy"] = e_calc
            except Exception:
                # Fallback for calculators without a usable ``results`` dict
                frame.calc._last_energy = e_calc

            # Calculate forces from calculator; None marks a failure
            try:
                f_calc = frame.get_forces()
            except Exception:
                print('# Error calculating force for frame %4d' % i)
                f_calc = None

            # Calculate RMS and max force errors (NaN if forces unavailable)
            if f_calc is None:
                rms_fd = max_fd = np.nan
            else:
                f_diff = f_traj - f_calc
                rms_fd = np.sqrt(np.mean(f_diff**2))
                max_fd = np.max(np.abs(f_diff))
            print('%4d %12.5f %12.5f %12.8f %8.5f %8.5f'
                  % (i, e_traj, e_calc, e_traj - e_calc, rms_fd, max_fd))

            # Reset each frame so a failure never reuses another frame's values
            q_calc = d_calc = None
            try:
                q_calc = frame.get_charges()
                d_calc = frame.get_dipole_moment()
            except Exception:
                print('# Error calculating charges and dipole moment for frame %4d' % i)

            # Write only the properties that were actually computed
            results = {'dipole': d_calc, 'charges': q_calc,
                       'energy': e_calc, 'forces': f_calc}
            kw = {key: value for key, value in results.items() if value is not None}
            trajout.write(frame, **kw)
    finally:
        trajout.close()

    # Re-open the written trajectory and compare it with the input one
    tmptraj = Trajectory(trajout_file)
    try:
        compare_traj_to_traj(trajin, tmptraj, plot_file,
                             'Trajectory Energy (eV)', 'Calculator Energy (eV)')
    finally:
        tmptraj.close()
# # Command-line driver

# In[ ]:


if __name__ == '__main__':
    # Import the submodule explicitly: `wrappers.amp` is only reachable as an
    # attribute if the package __init__ happens to import it
    from esteem.wrappers import amp

    mltest = MLTestingTask()
    # Parse command line values and copy them onto the task
    args = mltest.make_parser().parse_args()
    for arg in vars(args):
        setattr(mltest, arg, getattr(args, arg))
    print('#', args)
    mltest.wrapper = amp.AMPWrapper()
    # Run main program
    mltest.run()