27 #if !defined(MIXED_PRECISION) 29 TEST_CASE(
"DescentEngine RMSprop update",
"[drivers][descent]")
34 const std::string engine_input(
"<tmp> </tmp>");
42 std::unique_ptr<DescentEngine> descentEngineObj = std::make_unique<DescentEngine>(c, fakeXML);
50 myVars.insert(
"first", first_param);
51 myVars.insert(
"second", second_param);
53 std::vector<ValueType> LDerivs;
59 LDerivs.push_back(first_deriv);
60 LDerivs.push_back(second_deriv);
62 descentEngineObj->setDerivs(LDerivs);
64 descentEngineObj->setupUpdate(myVars);
66 descentEngineObj->storeDerivRecord();
67 descentEngineObj->updateParameters();
69 std::vector<ValueType> results = descentEngineObj->retrieveNewParams();
71 app_log() <<
"Descent engine test of parameter update" << std::endl;
72 app_log() <<
"First parameter: " << results[0] << std::endl;
73 app_log() <<
"Second parameter: " << results[1] << std::endl;
85 std::vector<ValueType> weights;
86 weights.push_back(1.0);
87 weights.push_back(1.0);
88 std::vector<ValueType> numerSamples;
89 numerSamples.push_back(-2.0);
90 numerSamples.push_back(-2.0);
91 std::vector<ValueType> denomSamples;
92 denomSamples.push_back(1.0);
93 denomSamples.push_back(1.0);
95 descentEngineObj->mpi_unbiased_ratio_of_means(
n, weights, numerSamples, denomSamples,
mean, variance, stdErr);
96 app_log() <<
"Descent engine test of mpi_unbiased_ratio_of_means" << std::endl;
98 app_log() <<
"Variance: " << variance << std::endl;
99 app_log() <<
"Standard Error: " << stdErr << std::endl;
class that handles xmlDoc
qmcplusplus::QMCTraits::FullPrecValueType FullPrecValueType
helper functions for EinsplineSetBuilder
TEST_CASE("complex_helper", "[type_traits]")
Communicate * Controller
Global Communicator for a process.
Wrapping information on parallelism.
QTFull::ValueType FullPrecValueType
QTBase::ValueType ValueType
REQUIRE(std::filesystem::exists(filename))
class to handle a set of variables that can be modified during optimizations
bool parseFromString(const std::string_view data)
CHECK(log_values[0]==ComplexApprox(std::complex< double >{ 5.603777579195571, -6.1586603331188225 }))
qmcplusplus::QMCTraits::RealType real_type
LatticeGaussianProduct::ValueType ValueType
ACC::value_type mean(const ACC &ac)