QMCPACK
test_DescentEngine.cpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2019 QMCPACK developers.
6 //
7 // File developed by: Leon Otis, leon_otis@berkeley.edu, University of California Berkeley
8 // Ye Luo, yeluo@anl.gov, Argonne National Laboratory
9 //
10 // File created by: Leon Otis, leon_otis@berkeley.edu, University of California Berkeley
11 //////////////////////////////////////////////////////////////////////////////////////
12 
13 #include "catch.hpp"
14 
15 #include "OhmmsData/Libxml2Doc.h"
17 #include "VariableSet.h"
18 #include "Configuration.h"
19 #include "Message/Communicate.h"
20 
21 namespace qmcplusplus
22 {
23 
26 
27 #if !defined(MIXED_PRECISION)
28 ///This provides a basic test of the descent engine's parameter update algorithm
29 TEST_CASE("DescentEngine RMSprop update", "[drivers][descent]")
30 {
32 
33 
34  const std::string engine_input("<tmp> </tmp>");
35 
37  bool okay = doc.parseFromString(engine_input);
38  REQUIRE(okay);
39 
40  xmlNodePtr fakeXML = doc.getRoot();
41 
42  std::unique_ptr<DescentEngine> descentEngineObj = std::make_unique<DescentEngine>(c, fakeXML);
43 
44  optimize::VariableSet myVars;
45 
46  //Two fake parameters are specified
47  optimize::VariableSet::real_type first_param(1.0);
48  optimize::VariableSet::real_type second_param(-2.0);
49 
50  myVars.insert("first", first_param);
51  myVars.insert("second", second_param);
52 
53  std::vector<ValueType> LDerivs;
54 
55  //Corresponding fake derivatives are specified and given to the engine
56  ValueType first_deriv = 5;
57  ValueType second_deriv = 1;
58 
59  LDerivs.push_back(first_deriv);
60  LDerivs.push_back(second_deriv);
61 
62  descentEngineObj->setDerivs(LDerivs);
63 
64  descentEngineObj->setupUpdate(myVars);
65 
66  descentEngineObj->storeDerivRecord();
67  descentEngineObj->updateParameters();
68 
69  std::vector<ValueType> results = descentEngineObj->retrieveNewParams();
70 
71  app_log() << "Descent engine test of parameter update" << std::endl;
72  app_log() << "First parameter: " << results[0] << std::endl;
73  app_log() << "Second parameter: " << results[1] << std::endl;
74 
75  //The engine should update the parameters using the generic default step size of .001 and obtain these values.
76  CHECK(std::real(results[0]) == Approx(.995));
77  CHECK(std::real(results[1]) == Approx(-2.001));
78 
79  //Provide fake data to test mpi_unbiased_ratio_of_means
80  int n = 2;
81  ValueType mean = 0;
82  ValueType variance = 0;
83  ValueType stdErr = 0;
84 
85  std::vector<ValueType> weights;
86  weights.push_back(1.0);
87  weights.push_back(1.0);
88  std::vector<ValueType> numerSamples;
89  numerSamples.push_back(-2.0);
90  numerSamples.push_back(-2.0);
91  std::vector<ValueType> denomSamples;
92  denomSamples.push_back(1.0);
93  denomSamples.push_back(1.0);
94 
95  descentEngineObj->mpi_unbiased_ratio_of_means(n, weights, numerSamples, denomSamples, mean, variance, stdErr);
96  app_log() << "Descent engine test of mpi_unbiased_ratio_of_means" << std::endl;
97  app_log() << "Mean: " << mean << std::endl;
98  app_log() << "Variance: " << variance << std::endl;
99  app_log() << "Standard Error: " << stdErr << std::endl;
100 
101  //mpi_unbiased_ratio_of_means should calculate the mean, variance, and standard error and obtain the values below
102  CHECK(std::real(mean) == Approx(-2.0));
103  CHECK(std::real(variance) == Approx(0.0));
104  CHECK(std::real(stdErr) == Approx(0.0));
105 }
106 #endif
107 } // namespace qmcplusplus
class that handles xmlDoc
Definition: Libxml2Doc.h:76
qmcplusplus::QMCTraits::FullPrecValueType FullPrecValueType
QMCTraits::RealType real
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
std::ostream & app_log()
Definition: OutputManager.h:65
TEST_CASE("complex_helper", "[type_traits]")
xmlNodePtr getRoot()
Definition: Libxml2Doc.h:88
Communicate * Controller
Global Communicator for a process.
Definition: Communicate.cpp:35
Wrapping information on parallelism.
Definition: Communicate.h:68
QTFull::ValueType FullPrecValueType
Definition: Configuration.h:67
QTBase::ValueType ValueType
Definition: Configuration.h:60
REQUIRE(std::filesystem::exists(filename))
class to handle a set of variables that can be modified during optimizations
Definition: VariableSet.h:49
bool parseFromString(const std::string_view data)
Definition: Libxml2Doc.cpp:204
CHECK(log_values[0]==ComplexApprox(std::complex< double >{ 5.603777579195571, -6.1586603331188225 }))
qmcplusplus::QMCTraits::RealType real_type
Definition: VariableSet.h:51
LatticeGaussianProduct::ValueType ValueType
ACC::value_type mean(const ACC &ac)
Definition: accumulators.h:147