#!/usr/bin/env python
# coding: utf-8
# In[ ]:
"""Defines a task to test a Machine Learning calculator by comparing its results
to those of an existing trajectory or set of trajectories"""
# # Main Routine
# In[ ]:
# Load essential modules
import sys
import os
import string
from ase.io.trajectory import Trajectory
from esteem.trajectories import merge_traj, get_trajectory_list
class MLTestingTask:
    """Task that tests a Machine Learning calculator against existing trajectories.

    The task merges a set of trajectory files into one input trajectory,
    loads a calculator through ``self.wrapper`` and hands both to
    ``compare_calc_to_traj``, which writes the calculated results to an
    output trajectory.

    Besides ``wrapper``, ``script_settings``, ``task_command`` and
    ``train_params``, every argument defined in ``make_parser`` is available
    as an instance attribute, initialised to its parser default.
    """

    def __init__(self, **kwargs):
        """Set up the task with the default value of every parser argument.

        ``**kwargs`` is accepted but not used.
        """
        self.wrapper = None
        self.script_settings = None
        self.task_command = 'mltest'
        self.train_params = {}
        # Parsing an empty argument list yields the default of every option
        args = self.make_parser().parse_args([])
        for arg, value in vars(args).items():
            setattr(self, arg, value)

    # Main routine
    def run(self):
        """Main routine for the ML_Testing task

        Merges the selected trajectory files, loads the calculator from the
        wrapper and compares its predictions to the merged trajectory.

        Raises:
            AttributeError: if ``self.wrapper`` has no ``load`` attribute.
            Exception: if any input trajectory file is missing or empty.
        """
        # Without a loader there is no calculator to test, so fail before
        # any file is read or written rather than part-way through the run
        if not hasattr(self.wrapper, 'load'):
            raise AttributeError(
                "MLTestingTask.wrapper must provide a 'load' method to "
                f"supply the calculator (got {type(self.wrapper).__name__})")

        # Get strings for trajectory names
        if self.calc_seed is None:
            self.calc_seed = self.seed
        trajfn = self.wrapper.calc_filename(self.seed, self.target,
                                            prefix=self.calc_prefix,
                                            suffix=self.trajname_suffix)
        trajstem = self.wrapper.calc_filename(self.seed, self.target,
                                              prefix=self.calc_prefix,
                                              suffix="")
        all_trajs = get_trajectory_list(self.ntraj)
        trajname_suffix = self.trajname_suffix
        if self.which_trajs is None:
            which_trajs = all_trajs
        else:
            which_trajs = self.which_trajs
            # A string of the form "<labels>_<suffix>" overrides the
            # trajectory-name suffix for this run
            if isinstance(which_trajs, str) and "_" in which_trajs:
                which_trajs, trajname_suffix = which_trajs.split("_", 1)

        # All trajectories must exist and be non-empty to test against them
        trajnames = [trajstem + s + '_' + trajname_suffix + '.traj'
                     for s in which_trajs]
        print('# Merging trajectories: ', trajnames)
        missing = [f for f in trajnames if not os.path.isfile(f)]
        if missing:
            raise Exception('# Missing Trajectory file(s): ', missing)
        empty = [f for f in trajnames if os.path.getsize(f) == 0]
        if empty:
            raise Exception('# Empty Trajectory file(s) found: ', empty)

        # Test the ML calculator against the results in the merged trajectory
        traj_labels = ''.join(which_trajs)
        intrajfile = (trajfn + "_" + self.calc_suffix + "_" + traj_labels
                      + '_merged.traj')
        if os.path.isfile(intrajfile):
            print(f'# Merged input trajectory file {intrajfile} already exists. Overwriting!')
        else:
            print(f'# Writing merged input trajectory file {intrajfile}')
        merge_traj(trajnames, intrajfile)
        intraj = Trajectory(intrajfile)
        try:
            # Load the calculator
            print("# Reading Calculator")
            calc = self.wrapper.load(self.calc_seed, self.target,
                                     prefix=self.calc_prefix,
                                     suffix=self.calc_suffix)

            # See if an energy offset is required
            e_offset = 0.0
            if hasattr(self.wrapper, 'atom_energies'):
                e_offset = self.wrapper.atom_energies

            output_traj = self.output_traj
            if output_traj is None:
                output_traj = self.calc_suffix + "_" + traj_labels + "_test"
            outtrajfile = trajfn + "_" + output_traj + '.traj'
            if os.path.isfile(outtrajfile):
                print(f'# Warning: output trajectory file {outtrajfile} already exists. Overwriting!')
            else:
                print(f'# Writing output trajectory file {outtrajfile}')

            compare_calc_to_traj(calc, intraj, outtrajfile, self.plotfile,
                                 e_offset)
        finally:
            # Release the merged input trajectory; the merged and output
            # trajectory files themselves are left on disk
            intraj.close()

    # Generate default arguments and return as parser
    def make_parser(self):
        """Build and return the ``argparse.ArgumentParser`` for this task."""
        import argparse

        main_help = ('ML_Testing.py: Test an ML-based Calculator by comparing it to trajectories.')
        epi_help = ('')
        parser = argparse.ArgumentParser(description=main_help, epilog=epi_help)
        parser.add_argument('--seed', '-s', type=str,
                            help='Base name stem for the calculation (often the name of the molecule)')
        parser.add_argument('--calc_seed', '-Z', default=None, type=str,
                            help='Seed name for the calculator')
        parser.add_argument('--calc_suffix', '-S', default="", type=str,
                            help='Suffix for the calculator (often specifies ML hyperparameters)')
        parser.add_argument('--calc_prefix', '-P', default="", type=str,
                            help='Prefix for the calculator (often specifies directory)')
        parser.add_argument('--target', '-t', default=0, type=int,
                            help='Excited state index, zero for ground state')
        parser.add_argument('--output_traj', '-o', default=None, type=str,
                            help='Filename to which to write the calculated trajectory')
        parser.add_argument('--plotfile', '-p', default=None, nargs='?', const="TkAgg", type=str,
                            help='Image file to which to write comparison plot')
        parser.add_argument('--trajname_suffix', '-T', default='training', type=str,
                            help='Suffix for the trajectory files being tested')
        parser.add_argument('--ntraj', '-n', default=1, type=int,
                            help='How many total trajectories (A,B,C...) with this naming are present')
        parser.add_argument('--which-trajs', '-w', default=None, type=str,
                            help='Which trajectories (A,B,C...) with this naming are to be tested')
        return parser
def validate_args(args):
    """Check that every attribute of ``args`` is a recognised task argument.

    Args:
        args: object whose attributes (as listed by ``vars``) are argument
            names, e.g. an ``argparse.Namespace``.

    Raises:
        Exception: for the first attribute that is not an argument defined
            by ``MLTestingTask.make_parser``.
    """
    # make_parser is a method of MLTestingTask, so a task instance is needed
    # to obtain the set of known arguments ('--seed a' is a placeholder value)
    default_args = MLTestingTask().make_parser().parse_args(['--seed', 'a'])
    for arg in vars(args):
        if arg not in default_args:
            raise Exception(f"Unrecognised argument '{arg}'")
# # Comparison of results from trajectory to calculator
# In[1]:
import numpy as np
from esteem.trajectories import compare_traj_to_traj,atom_energy
def compare_calc_to_traj(calc,trajin,trajout_file,plot_file=None,e_offset=0.0):
    """Compare the energy and force predictions of a calculator to results in an existing trajectory

    For each frame of ``trajin`` the stored energy and forces are read, the
    frame is re-evaluated with ``calc``, one line of energy and force errors
    is printed, and the calculated properties are written to
    ``trajout_file``. Finally the written trajectory is re-opened and passed,
    with ``trajin``, to ``compare_traj_to_traj``.

    Args:
        calc: calculator attached to every frame.
        trajin: iterable of frames holding reference energies and forces.
        trajout_file: filename of the trajectory to write.
        plot_file: passed through to ``compare_traj_to_traj``.
        e_offset: a number added to every calculated energy, or a dict that
            is passed to ``atom_energy(frame, e_offset)`` to obtain a
            per-frame offset. Any other type applies no offset.

    Raises:
        Exception: ``'Keyboard Interrupt'`` if the force calculation is
            interrupted from the keyboard.
    """
    trajout = Trajectory(trajout_file, 'w')
    try:
        # Loop over the frames in the input trajectory
        for i, frame in enumerate(trajin):
            # Read in total energy and forces from trajectory, before the
            # reference calculator is replaced
            e_traj = frame.get_potential_energy()
            f_traj = frame.get_forces()

            frame.calc = calc
            # May need to set up a box, for some calculators (TODO: when?)
            frame.center(20)

            # Calculate total energy from calculator
            e_calc = frame.get_potential_energy()
            # If an array has been returned, convert to a single float64
            if isinstance(e_calc, np.ndarray):
                e_calc = np.float64(e_calc.ravel()[0])

            # Apply offset energy
            if isinstance(e_offset, dict):
                e_calc += atom_energy(frame, e_offset)
            elif isinstance(e_offset, (int, float, np.number)):
                e_calc += e_offset

            # Attempt to store the offset energy in the results array
            try:
                frame.calc.results["energy"] = e_calc
            except Exception:
                frame.calc._last_energy = e_calc

            # Calculate forces from calculator
            try:
                f_calc = frame.get_forces()
            except KeyboardInterrupt:
                raise Exception('Keyboard Interrupt')
            except Exception:
                print('# Error calculating force for frame %4d' % i)
                # NaNs of the reference shape mark the failure in both the
                # printed errors and the output trajectory
                f_calc = np.full(np.shape(f_traj), np.nan)

            # Calculate RMS and Max force errors
            f_diff = f_traj - f_calc
            rms_fd = np.sqrt(np.mean(f_diff**2))
            max_fd = np.max(np.abs(f_diff))
            print('%4d %12.5f %12.5f %12.8f %8.5f %8.5f' %
                  (i, e_traj, e_calc, e_traj-e_calc, rms_fd, max_fd))

            # Calculate charges and dipole moment; reset per frame so that a
            # failure never reuses values from an earlier frame
            q_calc = None
            d_calc = None
            try:
                q_calc = frame.get_charges()
                d_calc = frame.get_dipole_moment()
            except Exception:
                print('# Error calculating charges and dipole moment for frame %4d' % i)

            # Assemble dictionary of properties to write to the output
            # trajectory, leaving out anything that could not be calculated
            kw = {'energy': e_calc,
                  'forces': f_calc}
            if q_calc is not None:
                kw['charges'] = q_calc
            if d_calc is not None:
                kw['dipole'] = d_calc

            # Write to trajectory
            trajout.write(frame, **kw)
    finally:
        trajout.close()

    # Re-open the written trajectory for reading
    tmptraj = Trajectory(trajout_file)
    try:
        compare_traj_to_traj(trajin, tmptraj, plot_file,
                             'Trajectory Energy (eV)', 'Calculator Energy (eV)')
    finally:
        tmptraj.close()
# # Command-line driver
# In[ ]:
if __name__ == '__main__':
    # Command-line entry point: build the task, override its defaults with
    # the command-line values, attach a wrapper and run
    from esteem import wrappers

    mltest = MLTestingTask()

    # Parse command line values and copy them onto the task
    args = mltest.make_parser().parse_args()
    for arg in vars(args):
        setattr(mltest, arg, getattr(args, arg))
    print('#', args)
    mltest.wrapper = wrappers.amp.AMPWrapper()

    # Run main program
    mltest.run()