00001 //---------------------------------------------------------------------------- 00002 /** @file HexAbSearch.cpp 00003 */ 00004 //---------------------------------------------------------------------------- 00005 00006 #include "SgSystem.h" 00007 00008 #include "Hex.hpp" 00009 #include "Time.hpp" 00010 #include "HexAbSearch.hpp" 00011 #include "HexBoard.hpp" 00012 #include "EndgameUtils.hpp" 00013 #include "SequenceHash.hpp" 00014 00015 using namespace benzene; 00016 00017 //---------------------------------------------------------------------------- 00018 00019 /** Local utilities. */ 00020 namespace 00021 { 00022 00023 /** Dump state info so the gui can display progress. */ 00024 void DumpGuiFx(std::vector<HexMoveValue> finished, int num_to_explore, 00025 std::vector<HexPoint> pv, HexColor color) 00026 { 00027 std::ostringstream os; 00028 os << "gogui-gfx:\n"; 00029 os << "ab\n"; 00030 os << "VAR"; 00031 for (std::size_t i=0; i<pv.size(); ++i) 00032 { 00033 os << " " << ((color == BLACK) ? "B" : "W"); 00034 os << " " << pv[i]; 00035 color = !color; 00036 } 00037 os << "\n"; 00038 os << "LABEL"; 00039 for (std::size_t i=0; i<finished.size(); ++i) 00040 { 00041 os << " " << finished[i].point(); 00042 HexEval value = finished[i].value(); 00043 if (HexEvalUtil::IsWin(value)) 00044 os << " W"; 00045 else if (HexEvalUtil::IsLoss(value)) 00046 os << " L"; 00047 else 00048 os << " " << std::fixed << std::setprecision(2) << value; 00049 } 00050 os << "\n"; 00051 os << "TEXT"; 00052 os << " " << finished.size() << "/" << num_to_explore; 00053 os << "\n"; 00054 os << "\n"; 00055 std::cout << os.str(); 00056 std::cout.flush(); 00057 } 00058 00059 std::string DumpPV(HexEval value, const std::vector<HexPoint>& pv) 00060 { 00061 std::ostringstream os; 00062 os << "PV: [" << std::fixed << std::setprecision(4) << value << "]"; 00063 for (std::size_t i=0; i<pv.size(); ++i) 00064 os << " " << pv[i]; 00065 return os.str(); 00066 } 00067 00068 //---------------------------------------------------------------------------- 00069 00070 } // anonymous namespace 00071 00072 //---------------------------------------------------------------------------- 00073 00074 HexAbSearch::HexAbSearch() 00075 : m_brd(0), 00076 m_tt(0), 00077 m_use_guifx(false) 00078 { 00079 } 00080 00081 HexAbSearch::~HexAbSearch() 00082 { 00083 } 00084 00085 //---------------------------------------------------------------------------- 00086 00087 void HexAbSearch::EnteredNewState() {} 00088 00089 void HexAbSearch::OnStartSearch() {} 00090 00091 void HexAbSearch::OnSearchComplete() {} 00092 00093 void HexAbSearch::AfterStateSearched() {} 00094 00095 HexEval HexAbSearch::Search(HexBoard& brd, HexColor color, 00096 const std::vector<int>& plywidth, 00097 const std::vector<int>& depths_to_search, 00098 int timelimit, 00099 std::vector<HexPoint>& outPV) 00100 { 00101 UNUSED(timelimit); 00102 00103 double start = Time::Get(); 00104 00105 m_brd = &brd; 00106 m_toplay = color; 00107 m_statistics = Statistics(); 00108 00109 OnStartSearch(); 00110 00111 std::vector<HexMoveValue> outEval; 00112 double outValue = -EVAL_INFINITY; 00113 outPV.clear(); 00114 outPV.push_back(INVALID_POINT); 00115 00116 m_aborted = false; 00117 for (std::size_t d=0; !m_aborted && d < depths_to_search.size(); ++d) 00118 { 00119 int depth = depths_to_search[d]; 00120 LogInfo() << "---- Depth " << depth << " ----" << '\n'; 00121 00122 double beganAt = Time::Get(); 00123 00124 m_eval.clear(); 00125 m_current_depth = 0; 00126 m_sequence.clear(); 00127 std::vector<HexPoint> thisPV; 00128 00129 double thisValue = SearchState(plywidth, depth, IMMEDIATE_LOSS, 00130 IMMEDIATE_WIN, thisPV); 00131 00132 double finishedAt = Time::Get(); 00133 00134 // copy result only if search was not aborted 00135 if (!m_aborted) 00136 { 00137 outPV = thisPV; 00138 outValue = thisValue; 00139 outEval = m_eval; 00140 00141 m_statistics.value = thisValue; 00142 m_statistics.pv = thisPV; 00143 00144 LogInfo() 00145 << DumpPV(thisValue, thisPV) << '\n' 00146 << "Time: " << std::fixed << std::setprecision(4) 00147 << (finishedAt - beganAt) << '\n'; 00148 } 00149 else 00150 { 00151 LogInfo() << "Throwing away current iteration..." 00152 << '\n'; 00153 } 00154 } 00155 00156 OnSearchComplete(); 00157 00158 double end = Time::Get(); 00159 m_statistics.elapsed_time = end - start; 00160 00161 // Copy the root evaluations back into m_eval; these will be printed 00162 // when DumpStats() is called. 00163 m_eval = outEval; 00164 00165 return outValue; 00166 } 00167 00168 //---------------------------------------------------------------------------- 00169 00170 HexEval HexAbSearch::CheckTerminalState() 00171 { 00172 if (EndgameUtils::IsWonGame(*m_brd, m_toplay)) 00173 return IMMEDIATE_WIN - m_current_depth; 00174 00175 if (EndgameUtils::IsLostGame(*m_brd, m_toplay)) 00176 return IMMEDIATE_LOSS + m_current_depth; 00177 00178 return 0; 00179 } 00180 00181 bool HexAbSearch::CheckAbort() 00182 { 00183 if (SgUserAbort()) 00184 { 00185 LogInfo() << "HexAbSearch::CheckAbort(): Abort flag!" << '\n'; 00186 m_aborted = true; 00187 return true; 00188 } 00189 00190 // @todo CHECK TIMELIMIT 00191 return false; 00192 } 00193 00194 HexEval HexAbSearch::SearchState(const std::vector<int>& plywidth, 00195 int depth, HexEval alpha, HexEval beta, 00196 std::vector<HexPoint>& pv) 00197 { 00198 HexAssert(m_current_depth + depth <= (int)plywidth.size()); 00199 00200 if (CheckAbort()) 00201 return -EVAL_INFINITY; 00202 00203 m_statistics.numstates++; 00204 pv.clear(); 00205 00206 // modify beta so that we abort on an immediate win 00207 beta = std::min(beta, IMMEDIATE_WIN - (m_current_depth+1)); 00208 00209 HexEval old_alpha = alpha; 00210 HexEval old_beta = beta; 00211 00212 EnteredNewState(); 00213 00214 // 00215 // Check for terminal states 00216 // 00217 { 00218 HexEval value = CheckTerminalState(); 00219 if (value != 0) { 00220 m_statistics.numterminal++; 00221 LogFine() << "Terminal: " << value << '\n'; 00222 return value; 00223 } 00224 } 00225 00226 // 00227 // Evaluate if a leaf 00228 // 00229 if (depth == 0) { 00230 m_statistics.numleafs++; 00231 HexEval value = Evaluate(); 00232 return value; 00233 } 00234 00235 // 00236 // Check for transposition 00237 // 00238 std::string space(3*m_current_depth, ' '); 00239 00240 m_tt_info_available = false; 00241 m_tt_bestmove = INVALID_POINT; 00242 if (m_tt) 00243 { 00244 SearchedState state; 00245 if (m_tt->Get(m_brd->GetPosition().Hash(), state)) 00246 { 00247 m_tt_info_available = true; 00248 m_tt_bestmove = state.move; 00249 00250 if (state.depth >= depth) 00251 { 00252 m_statistics.tt_hits++; 00253 00254 LogFine() << space << "--- TT HIT ---" << '\n'; 00255 00256 if (state.bound == SearchedState::LOWER_BOUND) 00257 { 00258 LogFine() << "Lower Bound" << '\n'; 00259 alpha = std::max(alpha, state.score); 00260 } 00261 else if (state.bound == SearchedState::UPPER_BOUND) 00262 { 00263 LogFine() << "Upper Bound" << '\n'; 00264 beta = std::min(beta, state.score); 00265 } 00266 else if (state.bound == SearchedState::ACCURATE) 00267 { 00268 LogFine() << "Accurate" << '\n'; 00269 alpha = beta = state.score; 00270 } 00271 00272 LogFine() << "new (alpha, beta): (" << alpha 00273 << ", " << beta << ")" << '\n'; 00274 00275 if (alpha >= beta) 00276 { 00277 m_statistics.tt_cuts++; 00278 00279 pv.clear(); 00280 pv.push_back(state.move); 00281 00282 return state.score; 00283 } 00284 } 00285 } 00286 } 00287 00288 m_statistics.numinternal++; 00289 00290 std::vector<HexPoint> moves; 00291 GenerateMoves(moves); 00292 HexAssert(moves.size()); 00293 00294 int curwidth = std::min(plywidth[m_current_depth], (int)moves.size()); 00295 m_statistics.mustplay_branches += (int)moves.size(); 00296 m_statistics.total_branches += curwidth; 00297 00298 HexPoint bestmove = INVALID_POINT; 00299 HexEval bestvalue = -EVAL_INFINITY; 00300 00301 for (int m = 0; !m_aborted && m < curwidth; ++m) 00302 { 00303 m_statistics.visited_branches++; 00304 LogFine() << space 00305 << (m+1) << "/" 00306 << curwidth << ": (" 00307 << m_toplay << ", " << moves[m] << ")" 00308 << ", (" << alpha << ", " << beta << ")" 00309 << '\n'; 00310 00311 ExecuteMove(moves[m]); 00312 m_current_depth++; 00313 m_sequence.push_back(moves[m]); 00314 m_toplay = !m_toplay; 00315 00316 std::vector<HexPoint> cv; 00317 HexEval value = -SearchState(plywidth, depth-1, -beta, -alpha, cv); 00318 00319 m_toplay = !m_toplay; 00320 m_sequence.pop_back(); 00321 m_current_depth--; 00322 UndoMove(moves[m]); 00323 00324 if (value > bestvalue) 00325 { 00326 bestmove = moves[m]; 00327 bestvalue = value; 00328 00329 // compute new principal variation 00330 pv.clear(); 00331 pv.push_back(bestmove); 00332 pv.insert(pv.end(), cv.begin(), cv.end()); 00333 00334 LogFine() << space << "--- New best: " << value 00335 << " PV: " << HexPointUtil::ToString(pv) << " ---\n"; 00336 } 00337 00338 // store root move evaluations and output progress to gui 00339 if (m_current_depth == 0) 00340 { 00341 m_eval.push_back(HexMoveValue(moves[m], value)); 00342 if (m_use_guifx) 00343 DumpGuiFx(m_eval, curwidth, pv, m_toplay); 00344 } 00345 00346 if (value >= alpha) 00347 alpha = value; 00348 00349 if (alpha >= beta) 00350 { 00351 LogFine() << space << "--- Cutoff ---" << '\n'; 00352 m_statistics.cuts++; 00353 break; 00354 } 00355 } 00356 00357 if (m_aborted) 00358 return -EVAL_INFINITY; 00359 00360 // 00361 // Store in tt 00362 // 00363 HexAssert(bestmove != INVALID_POINT); 00364 if (m_tt) 00365 { 00366 SearchedState::Bound bound = SearchedState::ACCURATE; 00367 if (bestvalue <= old_alpha) bound = SearchedState::UPPER_BOUND; 00368 if (bestvalue >= old_beta) bound = SearchedState::LOWER_BOUND; 00369 SearchedState ss(m_brd->GetPosition().Hash(), depth, bound, 00370 bestvalue, bestmove); 00371 m_tt->Put(m_brd->GetPosition().Hash(), ss); 00372 } 00373 00374 AfterStateSearched(); 00375 00376 return bestvalue; 00377 } 00378 00379 std::string HexAbSearch::DumpStats() 00380 { 00381 std::ostringstream os; 00382 os << m_statistics.Dump() << '\n'; 00383 00384 std::vector<HexMoveValue> root_evals(m_eval); 00385 stable_sort(root_evals.begin(), root_evals.end(), 00386 std::greater<HexMoveValue>()); 00387 00388 os << '\n'; 00389 std::size_t num = 10; 00390 for (std::size_t i=0; i<num && i<root_evals.size(); ++i) { 00391 if (i && i%5==0) 00392 os << '\n'; 00393 os << "(" 00394 << root_evals[i].point() << ", " 00395 << std::fixed << std::setprecision(3) << root_evals[i].value() 00396 << ") "; 00397 } 00398 os << '\n'; 00399 00400 return os.str(); 00401 } 00402 00403 std::string HexAbSearch::Statistics::Dump() const 00404 { 00405 std::ostringstream os; 00406 os << '\n' 00407 << " Leaf Nodes: " << numleafs << '\n' 00408 << " Terminal Nodes: " << numterminal << '\n' 00409 << " Internal Nodes: " << numinternal << '\n' 00410 << " Total Nodes: " << numstates << '\n' 00411 << " TT Hits: " << tt_hits << '\n' 00412 << " TT Cuts: " << tt_cuts << '\n' 00413 << "Avg. Mustplay Size: " << std::setprecision(4) 00414 << (double)mustplay_branches / numinternal << '\n' 00415 << "Avg. Branch Factor: " << std::setprecision(4) 00416 << (double)total_branches / numinternal << '\n' 00417 << " Avg. To Cut: " << std::setprecision(4) 00418 << (double)visited_branches / numinternal << '\n' 00419 << " Nodes/Sec: " << std::setprecision(4) 00420 << (numstates/elapsed_time) << '\n' 00421 << " Elapsed Time: " << std::setprecision(4) 00422 << elapsed_time << "s" << '\n' 00423 << '\n' 00424 << DumpPV(value, pv); 00425 return os.str(); 00426 } 00427 00428 //----------------------------------------------------------------------------