#!/usr/bin/env python
# coding: utf-8
# In[ ]:
"""Defines a task to test a Machine Learning calculator by comparing its results
to those of an existing trajectory or set of trajectories"""
# # Main Routine
# In[ ]:
# Load essential modules
import sys
import os
import string
from ase.io.trajectory import Trajectory
from esteem.trajectories import merge_traj, get_trajectory_list
class MLTestingTask:
    """Task that tests a Machine Learning calculator against trajectories.

    The calculator's results are compared to those stored in an existing
    trajectory or set of trajectories. Every command-line option defined in
    ``make_parser`` is mirrored as an instance attribute of the same name.
    """

    def __init__(self, **kwargs):
        """Set every task option to its command-line default.

        ``**kwargs`` is accepted for call compatibility but is not used here.
        """
        self.wrapper = None
        self.script_settings = None
        self.task_command = 'mltest'
        self.train_params = {}
        # Parsing an empty argument list yields the default of every option
        args = self.make_parser().parse_args([])
        for arg in vars(args):
            setattr(self, arg, getattr(args, arg))

    def get_trajnames(self):
        """Return the selected trajectory labels and their file names.

        Returns:
            A tuple ``(which_trajs, trajnames)``: the labels chosen through
            ``self.which_trajs`` (all labels when it is None; a dict is
            flattened into a list of its values' entries) and one ``.traj``
            file name per label.

        Raises:
            Exception: if a requested label is not among those returned by
                ``get_trajectory_list(self.ntraj)``.
        """
        all_trajs = get_trajectory_list(self.ntraj)
        which_trajs = all_trajs if self.which_trajs is None else self.which_trajs
        if isinstance(which_trajs, dict):
            # Flatten {key: [labels, ...]} into a single list of labels
            which_trajs = [label
                           for key in which_trajs
                           for label in which_trajs[key]]
        for trajname in which_trajs:
            if trajname not in all_trajs:
                raise Exception("Invalid trajectory name: ", trajname)
        trajstem = self.wrapper.calc_filename(self.seed, self.target,
                                              prefix=self.calc_prefix, suffix="")
        trajnames = [trajstem + s + '_' + self.traj_suffix + '.traj'
                     for s in which_trajs]
        return which_trajs, trajnames

    # Main routine
    def run(self):
        """Main routine for the ML_Testing task.

        Merges the selected trajectories into one file, evaluates the
        calculator on every frame (``compare_wrapper_to_traj``), optionally
        subtracts reference energies from both trajectories when
        ``self.ref_mol_dir`` is set, plots the comparison
        (``compare_traj_to_traj``) and, if ``self.cleanup`` is true, removes
        the last input and output trajectory files that were written.

        Raises:
            Exception: if any trajectory file to be merged is missing or empty.
        """
        # Get strings for trajectory names
        if self.calc_seed is None:
            self.calc_seed = self.seed
        traj_prefix = self.traj_prefix
        # Assume path is relative to base directory, unless blank
        if traj_prefix != "":
            traj_prefix = "../" + traj_prefix
        trajfn = self.wrapper.calc_filename(self.calc_seed, self.target,
                                            prefix="", suffix=self.traj_suffix)
        trajstem = self.wrapper.calc_filename(self.calc_seed, self.target,
                                              prefix=traj_prefix, suffix="")
        all_trajs = get_trajectory_list(self.ntraj)
        traj_suffix = self.traj_suffix
        if self.which_trajs is None:
            which_trajs = all_trajs
        else:
            which_trajs = self.which_trajs
            # A string of the form "<labels>_<suffix>" overrides the suffix.
            # Only strings can be split; a list is used as given.
            if isinstance(which_trajs, str) and "_" in which_trajs:
                which_trajs, traj_suffix = which_trajs.split("_", 1)

        # If all trajectories exist, test against them
        trajnames = [trajstem + s + '_' + traj_suffix + '.traj'
                     for s in which_trajs]
        print('# Merging trajectories: ',trajnames)
        missing = [f for f in trajnames if not os.path.isfile(f)]
        if missing:
            raise Exception('# Missing Trajectory file(s): ', missing)
        empty = [f for f in trajnames if os.path.getsize(f) == 0]
        if empty:
            raise Exception('# Empty Trajectory file(s) found: ', empty)

        # Test the ML calculator against the results in the merged trajectory
        # File names carry at most the first 20 characters of the joined labels
        which_traj_str = ''.join(which_trajs)[:20]
        if isinstance(self.calc_suffix, dict):
            calc_suffix = os.path.commonprefix(list(self.calc_suffix.keys()))
        else:
            calc_suffix = self.calc_suffix
        intrajfile = trajfn + "_" + calc_suffix + "_" + which_traj_str + '_merged.traj'
        if os.path.isfile(intrajfile):
            print(f'# Merged input trajectory file {intrajfile} already exists. Overwriting!')
        else:
            print(f'# Writing merged input trajectory file {intrajfile}')
        merge_traj(trajnames, intrajfile)

        intraj = Trajectory(intrajfile)
        outtraj = None
        try:
            # Load the calculator
            print("# Loading Calculator")
            calc_params = {'calc_seed': self.calc_seed,
                           'calc_suffix': self.calc_suffix,
                           'calc_dir_suffix': self.calc_dir_suffix,
                           'calc_prefix': f'../{self.calc_prefix}',  # Testing will be run from subdirectory
                           'target': self.target}
            if hasattr(self.wrapper, 'update_atom_e'):
                self.wrapper.update_atom_e = True
            output_traj = self.output_traj
            if output_traj is None:
                output_traj = calc_suffix + "_" + which_traj_str + "_test"
            outtrajfile = trajfn + "_" + output_traj + '.traj'
            if os.path.isfile(outtrajfile):
                print(f'# Warning: output trajectory file {outtrajfile} already exists. Overwriting!')
            else:
                print(f'# Writing output trajectory file {outtrajfile}')
            compare_wrapper_to_traj(self.wrapper, calc_params, intraj, outtrajfile)

            # Open temporary trajectory for reading
            outtraj = Trajectory(outtrajfile)

            if self.ref_mol_dir is not None:
                # Replace both trajectories by reference-subtracted copies
                intrajfile = trajfn + "_" + calc_suffix + "_" + which_traj_str + '_refsub.traj'
                self.subtract_reference_energies(intraj, intrajfile)
                intraj.close()
                intraj = Trajectory(intrajfile)
                outtrajfile = trajfn + "_" + output_traj + '_refsub.traj'
                self.subtract_reference_energies(outtraj, outtrajfile)
                outtraj.close()
                outtraj = Trajectory(outtrajfile)

            # Finally, plot comparison
            clabel = 'RMS Force component deviation (eV/Ang)'
            xlabel = 'Trajectory Energy (eV)'
            ylabel = 'Calculator Energy (eV)'
            compare_traj_to_traj(intraj, outtraj, self.plotfunc, self.plotfile,
                                 xlabel, ylabel, clabel)
        finally:
            # Release the file handles before any file is deleted below
            intraj.close()
            if outtraj is not None:
                outtraj.close()

        if self.cleanup:  # optional cleanup
            os.remove(intrajfile)
            os.remove(outtrajfile)

    def subtract_reference_energies(self, trajin, trajout_file):
        """Subtract reference energies from a trajectory to just get energy above reference zero.

        Each frame of ``trajin`` is written to ``trajout_file`` with the
        solute reference energy, plus a multiple of the solvent reference
        energy, removed from its potential energy.

        Args:
            trajin: iterable of frames exposing ``get_potential_energy()`` and
                ``calc.results``.
            trajout_file: name of the trajectory file to write.
        """
        from esteem.drivers import get_solu_solv_names
        from esteem.trajectories import targstr
        from esteem.tasks.clusters import get_ref_mol_energy

        ref_solu, ref_solv = get_solu_solv_names(self.seed)
        if ref_solv == "NO_SOLVENT_FOUND":
            ref_solv = None
        trajout = Trajectory(trajout_file, 'w')
        try:
            # The reference state index is fixed to 0 here
            targ = 0
            ref_solu_dir = '../' + self.ref_mol_dir.replace("{targ}", targstr(targ))
            ref_solv_dir = '../' + self.ref_mol_dir.replace("{targ}", targstr(0))
            calc_params = {'calc_seed': self.calc_seed,
                           'calc_suffix': self.calc_suffix,
                           'calc_dir_suffix': self.calc_dir_suffix,
                           'calc_prefix': f'../{self.calc_prefix}',
                           'target': self.target}

            # Read in Reference E, f, p
            solv_model = None  # stays None when there is no solvent
            if ref_solv is not None:
                ref_mol_xyz = f'{ref_solv_dir}/is_opt_{ref_solv}/{ref_solv}.xyz'
                solv_energy, solv_model = get_ref_mol_energy(
                    self.wrapper, ref_solv, ref_solv, calc_params,
                    ref_mol_xyz, ref_solv_dir)
                if isinstance(solv_energy, np.ndarray):
                    solv_energy = np.mean(solv_energy)
                print('# Solvent reference energy: ',solv_energy)
                ref_mol_xyz = f'{ref_solu_dir}/is_opt_{ref_solv}/{ref_solu}.xyz'
            else:
                ref_mol_xyz = f'{ref_solu_dir}/opt/{ref_solu}.xyz'
                solv_energy = 0.0
            solu_energy, solu_model = get_ref_mol_energy(
                self.wrapper, ref_solu, ref_solv, calc_params,
                ref_mol_xyz, ref_solu_dir)
            if isinstance(solu_energy, np.ndarray):
                solu_energy = np.mean(solu_energy)
            print('# Solute reference energy: ',solu_energy)

            for frame in trajin:
                # Atoms beyond the solute model are counted as solvent molecules
                n = len(frame) - len(solu_model)
                # Without a solvent model solv_energy is 0.0, so the value of
                # n does not affect the result and no division is needed
                if n > 0 and solv_model is not None:
                    n = int(n / len(solv_model))
                e = frame.get_potential_energy()
                e = e - solu_energy - n * solv_energy
                frame.calc.results["energy"] = e
                trajout.write(frame, **frame.calc.results)
        finally:
            trajout.close()

    @staticmethod
    def _parse_bool(value):
        """Convert a command-line string such as 'true' or '0' to a bool."""
        import argparse
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string is kept falsy, as it was with bool('')
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got '{value}'")

    # Generate default arguments and return as parser
    def make_parser(self):
        """Build the argument parser that defines all options and their defaults."""
        import argparse
        main_help = ('ML_Testing.py: Test an ML-based Calculator by comparing it to trajectories.')
        epi_help = ('')
        parser = argparse.ArgumentParser(description=main_help, epilog=epi_help)
        add = parser.add_argument
        add('--seed','-s',type=str,help='Base name stem for the calculation (often the name of the molecule)')
        add('--calc_seed','-Z',default=None,type=str,help='Seed name for the calculator')
        add('--calc_suffix','-S',default="",type=str,help='Suffix for the calculator (often specifies ML hyperparameters)')
        add('--calc_dir_suffix','-D',default=None,type=str,help='Suffix for the calculator (often specifies ML hyperparameters)')
        add('--calc_prefix','-P',default="",type=str,help='Prefix for the calculator (often specifies directory)')
        add('--target','-t',default=0,type=int,help='Excited state index, zero for ground state')
        add('--output_traj','-o',default=None,type=str,help='Filename to which to write the calculated trajectory')
        add('--plotfile','-p',default=None,nargs='?',const="TkAgg",type=str,help='Image file to which to write comparison plot')
        add('--plotfunc','-F',default=None,help='Function for plotting')
        # type=bool would turn every non-empty string (even "False") into True
        add('--cleanup','-C',default=True,type=self._parse_bool,help='Remove reference-corrected test and merged trajectory files after finishing')
        add('--traj_prefix','-Q',default="",type=str,help='Prefix for the trajectory files being tested')
        add('--traj_suffix','-T',default="training",type=str,help='Suffix for the trajectory files being tested')
        add('--ntraj','-n',default=1,type=int,help='How many total trajectories (A,B,C...) with this naming are present')
        add('--which_trajs','-w',default=None,type=str,help='Which trajectories (A,B,C...) with this naming are to be trained against')
        add('--which_trajs_valid','-v',default=None,type=str,help='Which trajectories (A,B,C...) with this naming are to be validated against')
        add('--which_trajs_test','-u',default=None,type=str,help='Which trajectories (A,B,C...) with this naming are to be tested against')
        # NOTE(review): type=dict cannot convert a command-line string; these
        # options are presumably assigned as attributes by calling scripts
        add('--traj_links','-L',default=None,type=dict,help='Targets for links to create for training trajectories')
        add('--traj_links_valid','-V',default=None,type=dict,help='Targets for links to create for validation trajectories')
        add('--traj_links_test','-U',default=None,type=dict,help='Targets for links to create for testing trajectories')
        add('--ref_mol_seed_dict','-z',default={},type=dict,help='Dictionary of seeds, for trajectory sets with varying seed names')
        add('--ref_mol_dir','-r',default=None,type=str,help='Location of output of solutes run from which to find reference energies')
        return parser
def validate_args(args):
    """Check that every attribute of ``args`` is a known task option.

    Args:
        args: object whose attributes (as returned by ``vars``) are checked
            against the options defined by the task's argument parser.

    Raises:
        Exception: if ``args`` carries an attribute the parser does not define.
    """
    # The parser lives on the task class; a bare make_parser() does not exist
    # at module scope, so obtain it through the module-level get_parser()
    default_args = get_parser().parse_args(['--seed', 'a'])
    for arg in vars(args):
        if arg not in default_args:
            raise Exception(f"Unrecognised argument '{arg}'")
# # Comparison of results from trajectory to calculator
# In[ ]:
import numpy as np
from esteem.trajectories import compare_traj_to_traj,atom_energy
def compare_wrapper_to_traj(wrapper, calc_params, trajin, trajout_file):
    """Compare the energy and force predictions of a calculator to results in an existing trajectory.

    For every frame of ``trajin`` the stored energy, forces and dipole are
    read, ``wrapper.singlepoint`` is called on the frame, a line of error
    statistics is printed, and the frame is written to ``trajout_file`` with
    the calculator's energy, forces and dipole.

    Args:
        wrapper: object providing ``singlepoint(frame, seed, calc_params,
            forces=True, dipole=True)`` returning energy, forces, dipole and
            a fourth value that is not used here.
        calc_params: dict; the keys ``'calc_seed'`` and ``'calc_suffix'`` are
            read here and the whole dict is passed on to the wrapper.
        trajin: iterable of frames with stored energy, forces and dipole.
        trajout_file: name of the trajectory file to write.
    """
    trajout = Trajectory(trajout_file, 'w')
    try:
        # Loop over the frames in the input trajectory
        for i, frame in enumerate(trajin):
            # Read in total energy and forces from trajectory
            e_traj = frame.get_potential_energy()
            f_traj = frame.get_forces()
            d_traj = frame.get_dipole_moment()
            # NOTE(review): ':4d' pads the index with spaces, not zeros, so
            # the seed contains blanks for i < 1000 - confirm this is intended
            seed = f"{calc_params['calc_seed']}{i:4d}"
            e_calc, f_calc, d_calc, calc_ml = wrapper.singlepoint(
                frame, seed, calc_params, forces=True, dipole=True)

            # Calculate RMS and Max force errors
            f_diff = f_traj - f_calc
            rms_fd = np.sqrt(np.mean(f_diff**2))
            max_fd = np.max(np.sqrt(f_diff**2))
            rms_dd = np.sqrt(np.mean((d_traj - d_calc)**2))

            # Print header for columns
            if i == 0:
                print('#Idx  E_traj (eV)  E_calc (eV)  E_diff (eV)   RMS_fd   MAX_fd   RMS_dd')
            if isinstance(calc_params["calc_suffix"], dict):
                # With a dict of suffixes, e_calc is reduced with mean/std
                # before printing (presumably one energy per suffix)
                print(f'{i:4d} {e_traj:12.5f} {np.mean(e_calc):12.5f} {np.std(e_calc):8.5f} {np.mean(e_traj-e_calc):8.5f} {np.mean(np.abs(e_traj-e_calc)):8.5f} {rms_fd:8.5f} {max_fd:8.5f}')
            else:
                print('%4d %12.5f %12.5f %12.8f %8.5f %8.5f %8.5f' %
                      (i, e_traj, e_calc, e_traj-e_calc, rms_fd, max_fd, rms_dd))

            # Assemble dictionary of properties to write to the output trajectory
            kw = {'dipole': d_calc,
                  'energy': e_calc,
                  'forces': f_calc}
            # Write to trajectory
            trajout.write(frame, **kw)
    finally:
        # Close the output file even if a frame fails part-way through
        trajout.close()
# # Command-line driver
# In[ ]:
def get_parser():
    """Return the argument parser of a newly constructed MLTestingTask."""
    return MLTestingTask().make_parser()
if __name__ == '__main__':
    from esteem import wrappers

    mltest = MLTestingTask()
    # Parse command line values and copy each one onto the task.
    # (The task object is named mltest; 'mltrain' is not defined in this file.)
    args = mltest.make_parser().parse_args()
    for arg in vars(args):
        setattr(mltest, arg, getattr(args, arg))
    print('#',args)
    mltest.wrapper = wrappers.amp.AMPWrapper()
    # Run main program
    mltest.run()