00001
00002
00003
00004
00005
00006 #include "SgSystem.h"
00007
00008 #include "Hex.hpp"
00009 #include "Time.hpp"
00010 #include "HexAbSearch.hpp"
00011 #include "HexBoard.hpp"
00012 #include "EndgameUtils.hpp"
00013 #include "SequenceHash.hpp"
00014
00015 using namespace benzene;
00016
00017
00018
00019
00020 namespace
00021 {
00022
00023
00024 void DumpGuiFx(std::vector<HexMoveValue> finished, int num_to_explore,
00025 std::vector<HexPoint> pv, HexColor color)
00026 {
00027 std::ostringstream os;
00028 os << "gogui-gfx:\n";
00029 os << "ab\n";
00030 os << "VAR";
00031 for (std::size_t i=0; i<pv.size(); ++i)
00032 {
00033 os << " " << ((color == BLACK) ? "B" : "W");
00034 os << " " << pv[i];
00035 color = !color;
00036 }
00037 os << "\n";
00038 os << "LABEL";
00039 for (std::size_t i=0; i<finished.size(); ++i)
00040 {
00041 os << " " << finished[i].point();
00042 HexEval value = finished[i].value();
00043 if (HexEvalUtil::IsWin(value))
00044 os << " W";
00045 else if (HexEvalUtil::IsLoss(value))
00046 os << " L";
00047 else
00048 os << " " << std::fixed << std::setprecision(2) << value;
00049 }
00050 os << "\n";
00051 os << "TEXT";
00052 os << " " << finished.size() << "/" << num_to_explore;
00053 os << "\n";
00054 os << "\n";
00055 std::cout << os.str();
00056 std::cout.flush();
00057 }
00058
00059 std::string DumpPV(HexEval value, const std::vector<HexPoint>& pv)
00060 {
00061 std::ostringstream os;
00062 os << "PV: [" << std::fixed << std::setprecision(4) << value << "]";
00063 for (std::size_t i=0; i<pv.size(); ++i)
00064 os << " " << pv[i];
00065 return os.str();
00066 }
00067
00068
00069
00070 }
00071
00072
00073
00074 HexAbSearch::HexAbSearch()
00075 : m_brd(0),
00076 m_tt(0),
00077 m_use_guifx(false)
00078 {
00079 }
00080
00081 HexAbSearch::~HexAbSearch()
00082 {
00083 }
00084
00085
00086
00087 void HexAbSearch::EnteredNewState() {}
00088
00089 void HexAbSearch::OnStartSearch() {}
00090
00091 void HexAbSearch::OnSearchComplete() {}
00092
00093 void HexAbSearch::AfterStateSearched() {}
00094
00095 HexEval HexAbSearch::Search(HexBoard& brd, HexColor color,
00096 const std::vector<int>& plywidth,
00097 const std::vector<int>& depths_to_search,
00098 int timelimit,
00099 std::vector<HexPoint>& outPV)
00100 {
00101 UNUSED(timelimit);
00102
00103 double start = Time::Get();
00104
00105 m_brd = &brd;
00106 m_toplay = color;
00107 m_statistics = Statistics();
00108
00109 OnStartSearch();
00110
00111 std::vector<HexMoveValue> outEval;
00112 double outValue = -EVAL_INFINITY;
00113 outPV.clear();
00114 outPV.push_back(INVALID_POINT);
00115
00116 m_aborted = false;
00117 for (std::size_t d=0; !m_aborted && d < depths_to_search.size(); ++d)
00118 {
00119 int depth = depths_to_search[d];
00120 LogInfo() << "---- Depth " << depth << " ----" << '\n';
00121
00122 double beganAt = Time::Get();
00123
00124 m_eval.clear();
00125 m_current_depth = 0;
00126 m_sequence.clear();
00127 std::vector<HexPoint> thisPV;
00128
00129 double thisValue = SearchState(plywidth, depth, IMMEDIATE_LOSS,
00130 IMMEDIATE_WIN, thisPV);
00131
00132 double finishedAt = Time::Get();
00133
00134
00135 if (!m_aborted)
00136 {
00137 outPV = thisPV;
00138 outValue = thisValue;
00139 outEval = m_eval;
00140
00141 m_statistics.value = thisValue;
00142 m_statistics.pv = thisPV;
00143
00144 LogInfo()
00145 << DumpPV(thisValue, thisPV) << '\n'
00146 << "Time: " << std::fixed << std::setprecision(4)
00147 << (finishedAt - beganAt) << '\n';
00148 }
00149 else
00150 {
00151 LogInfo() << "Throwing away current iteration..."
00152 << '\n';
00153 }
00154 }
00155
00156 OnSearchComplete();
00157
00158 double end = Time::Get();
00159 m_statistics.elapsed_time = end - start;
00160
00161
00162
00163 m_eval = outEval;
00164
00165 return outValue;
00166 }
00167
00168
00169
00170 HexEval HexAbSearch::CheckTerminalState()
00171 {
00172 if (EndgameUtils::IsWonGame(*m_brd, m_toplay))
00173 return IMMEDIATE_WIN - m_current_depth;
00174
00175 if (EndgameUtils::IsLostGame(*m_brd, m_toplay))
00176 return IMMEDIATE_LOSS + m_current_depth;
00177
00178 return 0;
00179 }
00180
00181 bool HexAbSearch::CheckAbort()
00182 {
00183 if (SgUserAbort())
00184 {
00185 LogInfo() << "HexAbSearch::CheckAbort(): Abort flag!" << '\n';
00186 m_aborted = true;
00187 return true;
00188 }
00189
00190
00191 return false;
00192 }
00193
00194 HexEval HexAbSearch::SearchState(const std::vector<int>& plywidth,
00195 int depth, HexEval alpha, HexEval beta,
00196 std::vector<HexPoint>& pv)
00197 {
00198 HexAssert(m_current_depth + depth <= (int)plywidth.size());
00199
00200 if (CheckAbort())
00201 return -EVAL_INFINITY;
00202
00203 m_statistics.numstates++;
00204 pv.clear();
00205
00206
00207 beta = std::min(beta, IMMEDIATE_WIN - (m_current_depth+1));
00208
00209 HexEval old_alpha = alpha;
00210 HexEval old_beta = beta;
00211
00212 EnteredNewState();
00213
00214
00215
00216
00217 {
00218 HexEval value = CheckTerminalState();
00219 if (value != 0) {
00220 m_statistics.numterminal++;
00221 LogFine() << "Terminal: " << value << '\n';
00222 return value;
00223 }
00224 }
00225
00226
00227
00228
00229 if (depth == 0) {
00230 m_statistics.numleafs++;
00231 HexEval value = Evaluate();
00232 return value;
00233 }
00234
00235
00236
00237
00238 std::string space(3*m_current_depth, ' ');
00239
00240 m_tt_info_available = false;
00241 m_tt_bestmove = INVALID_POINT;
00242 if (m_tt)
00243 {
00244 SearchedState state;
00245 if (m_tt->Get(m_brd->GetPosition().Hash(), state))
00246 {
00247 m_tt_info_available = true;
00248 m_tt_bestmove = state.move;
00249
00250 if (state.depth >= depth)
00251 {
00252 m_statistics.tt_hits++;
00253
00254 LogFine() << space << "--- TT HIT ---" << '\n';
00255
00256 if (state.bound == SearchedState::LOWER_BOUND)
00257 {
00258 LogFine() << "Lower Bound" << '\n';
00259 alpha = std::max(alpha, state.score);
00260 }
00261 else if (state.bound == SearchedState::UPPER_BOUND)
00262 {
00263 LogFine() << "Upper Bound" << '\n';
00264 beta = std::min(beta, state.score);
00265 }
00266 else if (state.bound == SearchedState::ACCURATE)
00267 {
00268 LogFine() << "Accurate" << '\n';
00269 alpha = beta = state.score;
00270 }
00271
00272 LogFine() << "new (alpha, beta): (" << alpha
00273 << ", " << beta << ")" << '\n';
00274
00275 if (alpha >= beta)
00276 {
00277 m_statistics.tt_cuts++;
00278
00279 pv.clear();
00280 pv.push_back(state.move);
00281
00282 return state.score;
00283 }
00284 }
00285 }
00286 }
00287
00288 m_statistics.numinternal++;
00289
00290 std::vector<HexPoint> moves;
00291 GenerateMoves(moves);
00292 HexAssert(moves.size());
00293
00294 int curwidth = std::min(plywidth[m_current_depth], (int)moves.size());
00295 m_statistics.mustplay_branches += (int)moves.size();
00296 m_statistics.total_branches += curwidth;
00297
00298 HexPoint bestmove = INVALID_POINT;
00299 HexEval bestvalue = -EVAL_INFINITY;
00300
00301 for (int m = 0; !m_aborted && m < curwidth; ++m)
00302 {
00303 m_statistics.visited_branches++;
00304 LogFine() << space
00305 << (m+1) << "/"
00306 << curwidth << ": ("
00307 << m_toplay << ", " << moves[m] << ")"
00308 << ", (" << alpha << ", " << beta << ")"
00309 << '\n';
00310
00311 ExecuteMove(moves[m]);
00312 m_current_depth++;
00313 m_sequence.push_back(moves[m]);
00314 m_toplay = !m_toplay;
00315
00316 std::vector<HexPoint> cv;
00317 HexEval value = -SearchState(plywidth, depth-1, -beta, -alpha, cv);
00318
00319 m_toplay = !m_toplay;
00320 m_sequence.pop_back();
00321 m_current_depth--;
00322 UndoMove(moves[m]);
00323
00324 if (value > bestvalue)
00325 {
00326 bestmove = moves[m];
00327 bestvalue = value;
00328
00329
00330 pv.clear();
00331 pv.push_back(bestmove);
00332 pv.insert(pv.end(), cv.begin(), cv.end());
00333
00334 LogFine() << space << "--- New best: " << value
00335 << " PV: " << HexPointUtil::ToString(pv) << " ---\n";
00336 }
00337
00338
00339 if (m_current_depth == 0)
00340 {
00341 m_eval.push_back(HexMoveValue(moves[m], value));
00342 if (m_use_guifx)
00343 DumpGuiFx(m_eval, curwidth, pv, m_toplay);
00344 }
00345
00346 if (value >= alpha)
00347 alpha = value;
00348
00349 if (alpha >= beta)
00350 {
00351 LogFine() << space << "--- Cutoff ---" << '\n';
00352 m_statistics.cuts++;
00353 break;
00354 }
00355 }
00356
00357 if (m_aborted)
00358 return -EVAL_INFINITY;
00359
00360
00361
00362
00363 HexAssert(bestmove != INVALID_POINT);
00364 if (m_tt)
00365 {
00366 SearchedState::Bound bound = SearchedState::ACCURATE;
00367 if (bestvalue <= old_alpha) bound = SearchedState::UPPER_BOUND;
00368 if (bestvalue >= old_beta) bound = SearchedState::LOWER_BOUND;
00369 SearchedState ss(m_brd->GetPosition().Hash(), depth, bound,
00370 bestvalue, bestmove);
00371 m_tt->Put(m_brd->GetPosition().Hash(), ss);
00372 }
00373
00374 AfterStateSearched();
00375
00376 return bestvalue;
00377 }
00378
00379 std::string HexAbSearch::DumpStats()
00380 {
00381 std::ostringstream os;
00382 os << m_statistics.Dump() << '\n';
00383
00384 std::vector<HexMoveValue> root_evals(m_eval);
00385 stable_sort(root_evals.begin(), root_evals.end(),
00386 std::greater<HexMoveValue>());
00387
00388 os << '\n';
00389 std::size_t num = 10;
00390 for (std::size_t i=0; i<num && i<root_evals.size(); ++i) {
00391 if (i && i%5==0)
00392 os << '\n';
00393 os << "("
00394 << root_evals[i].point() << ", "
00395 << std::fixed << std::setprecision(3) << root_evals[i].value()
00396 << ") ";
00397 }
00398 os << '\n';
00399
00400 return os.str();
00401 }
00402
00403 std::string HexAbSearch::Statistics::Dump() const
00404 {
00405 std::ostringstream os;
00406 os << '\n'
00407 << " Leaf Nodes: " << numleafs << '\n'
00408 << " Terminal Nodes: " << numterminal << '\n'
00409 << " Internal Nodes: " << numinternal << '\n'
00410 << " Total Nodes: " << numstates << '\n'
00411 << " TT Hits: " << tt_hits << '\n'
00412 << " TT Cuts: " << tt_cuts << '\n'
00413 << "Avg. Mustplay Size: " << std::setprecision(4)
00414 << (double)mustplay_branches / numinternal << '\n'
00415 << "Avg. Branch Factor: " << std::setprecision(4)
00416 << (double)total_branches / numinternal << '\n'
00417 << " Avg. To Cut: " << std::setprecision(4)
00418 << (double)visited_branches / numinternal << '\n'
00419 << " Nodes/Sec: " << std::setprecision(4)
00420 << (numstates/elapsed_time) << '\n'
00421 << " Elapsed Time: " << std::setprecision(4)
00422 << elapsed_time << "s" << '\n'
00423 << '\n'
00424 << DumpPV(value, pv);
00425 return os.str();
00426 }
00427
00428