Main   Namespaces   Classes   Hierarchy   Annotated   Files   Compound   Global   Pages  

HexAbSearch.cpp

Go to the documentation of this file.
00001 //----------------------------------------------------------------------------
00002 /** @file HexAbSearch.cpp
00003  */
00004 //----------------------------------------------------------------------------
00005 
00006 #include "SgSystem.h"
00007 
00008 #include "Hex.hpp"
00009 #include "Time.hpp"
00010 #include "HexAbSearch.hpp"
00011 #include "HexBoard.hpp"
00012 #include "EndgameUtils.hpp"
00013 #include "SequenceHash.hpp"
00014 
00015 using namespace benzene;
00016 
00017 //----------------------------------------------------------------------------
00018 
00019 /** Local utilities. */
00020 namespace
00021 {
00022 
00023 /** Dump state info so the gui can display progress. */
00024 void DumpGuiFx(std::vector<HexMoveValue> finished, int num_to_explore,
00025                std::vector<HexPoint> pv, HexColor color)
00026 {
00027     std::ostringstream os;
00028     os << "gogui-gfx:\n";
00029     os << "ab\n";
00030     os << "VAR";
00031     for (std::size_t i=0; i<pv.size(); ++i) 
00032     {
00033         os << " " << ((color == BLACK) ? "B" : "W");
00034         os << " " << pv[i];
00035         color = !color;
00036     }
00037     os << "\n";
00038     os << "LABEL";
00039     for (std::size_t i=0; i<finished.size(); ++i) 
00040     {
00041         os << " " << finished[i].point();
00042         HexEval value = finished[i].value();
00043         if (HexEvalUtil::IsWin(value))
00044             os << " W";
00045         else if (HexEvalUtil::IsLoss(value))
00046             os << " L";
00047         else 
00048             os << " " << std::fixed << std::setprecision(2) << value;
00049     }
00050     os << "\n";
00051     os << "TEXT";
00052     os << " " << finished.size() << "/" << num_to_explore;
00053     os << "\n";
00054     os << "\n";
00055     std::cout << os.str();
00056     std::cout.flush();
00057 }
00058 
00059 std::string DumpPV(HexEval value, const std::vector<HexPoint>& pv)
00060 {
00061     std::ostringstream os;
00062     os << "PV: [" << std::fixed << std::setprecision(4) << value << "]";
00063     for (std::size_t i=0; i<pv.size(); ++i)
00064         os << " " << pv[i];
00065     return os.str();
00066 }
00067 
00068 //----------------------------------------------------------------------------
00069 
00070 } // anonymous namespace
00071 
00072 //----------------------------------------------------------------------------
00073 
00074 HexAbSearch::HexAbSearch()
00075     : m_brd(0),
00076       m_tt(0),
00077       m_use_guifx(false)
00078 {
00079 }
00080 
00081 HexAbSearch::~HexAbSearch()
00082 {
00083 }
00084 
00085 //----------------------------------------------------------------------------
00086 
00087 void HexAbSearch::EnteredNewState() {}
00088 
00089 void HexAbSearch::OnStartSearch() {}
00090     
00091 void HexAbSearch::OnSearchComplete() {}
00092 
00093 void HexAbSearch::AfterStateSearched() {}
00094 
00095 HexEval HexAbSearch::Search(HexBoard& brd, HexColor color,
00096                             const std::vector<int>& plywidth, 
00097                             const std::vector<int>& depths_to_search, 
00098                             int timelimit,
00099                             std::vector<HexPoint>& outPV)
00100 {
00101     UNUSED(timelimit);
00102 
00103     double start = Time::Get();
00104 
00105     m_brd = &brd;
00106     m_toplay = color;
00107     m_statistics = Statistics();
00108 
00109     OnStartSearch();
00110 
00111     std::vector<HexMoveValue> outEval;
00112     double outValue = -EVAL_INFINITY;
00113     outPV.clear();
00114     outPV.push_back(INVALID_POINT);
00115 
00116     m_aborted = false;
00117     for (std::size_t d=0; !m_aborted && d < depths_to_search.size(); ++d) 
00118     {
00119         int depth = depths_to_search[d];
00120         LogInfo() << "---- Depth " << depth << " ----" << '\n';
00121 
00122         double beganAt = Time::Get();
00123 
00124         m_eval.clear();
00125         m_current_depth = 0;
00126         m_sequence.clear();
00127         std::vector<HexPoint> thisPV;
00128 
00129         double thisValue = SearchState(plywidth, depth, IMMEDIATE_LOSS, 
00130                                        IMMEDIATE_WIN, thisPV);
00131 
00132         double finishedAt = Time::Get();
00133 
00134         // copy result only if search was not aborted
00135         if (!m_aborted)
00136         {
00137             outPV = thisPV;
00138             outValue = thisValue;
00139             outEval = m_eval;
00140 
00141             m_statistics.value = thisValue;
00142             m_statistics.pv = thisPV;
00143 
00144             LogInfo() 
00145                      << DumpPV(thisValue, thisPV) << '\n'
00146                      << "Time: " << std::fixed << std::setprecision(4) 
00147                      << (finishedAt - beganAt) << '\n';
00148         }
00149         else
00150         {
00151             LogInfo() << "Throwing away current iteration..."
00152                      << '\n';
00153         }
00154     }
00155     
00156     OnSearchComplete();
00157 
00158     double end = Time::Get();
00159     m_statistics.elapsed_time = end - start;
00160     
00161     // Copy the root evaluations back into m_eval; these will be printed
00162     // when DumpStats() is called.
00163     m_eval = outEval;
00164 
00165     return outValue;
00166 }
00167 
00168 //----------------------------------------------------------------------------
00169 
00170 HexEval HexAbSearch::CheckTerminalState()
00171 {
00172     if (EndgameUtils::IsWonGame(*m_brd, m_toplay))
00173         return IMMEDIATE_WIN - m_current_depth;
00174     
00175     if (EndgameUtils::IsLostGame(*m_brd, m_toplay))
00176         return IMMEDIATE_LOSS + m_current_depth;
00177 
00178     return 0;
00179 }
00180 
00181 bool HexAbSearch::CheckAbort()
00182 {
00183     if (SgUserAbort())
00184     {
00185         LogInfo() << "HexAbSearch::CheckAbort(): Abort flag!" << '\n';
00186         m_aborted = true;
00187         return true;
00188     }
00189 
00190     // @todo CHECK TIMELIMIT
00191     return false;
00192 }
00193 
00194 HexEval HexAbSearch::SearchState(const std::vector<int>& plywidth,
00195                                  int depth, HexEval alpha, HexEval beta,
00196                                  std::vector<HexPoint>& pv)
00197 {
00198     HexAssert(m_current_depth + depth <= (int)plywidth.size());
00199 
00200     if (CheckAbort()) 
00201         return -EVAL_INFINITY;
00202 
00203     m_statistics.numstates++;
00204     pv.clear();
00205 
00206     // modify beta so that we abort on an immediate win
00207     beta = std::min(beta, IMMEDIATE_WIN - (m_current_depth+1));
00208 
00209     HexEval old_alpha = alpha;
00210     HexEval old_beta = beta;
00211 
00212     EnteredNewState();
00213     
00214     //
00215     // Check for terminal states
00216     //
00217     {
00218         HexEval value = CheckTerminalState();
00219         if (value != 0) {
00220             m_statistics.numterminal++;
00221             LogFine() << "Terminal: " << value << '\n';
00222             return value;
00223         }
00224     }
00225 
00226     //
00227     // Evaluate if a leaf
00228     //
00229     if (depth == 0) {
00230         m_statistics.numleafs++;
00231         HexEval value = Evaluate();
00232         return value;
00233     }
00234 
00235     //
00236     // Check for transposition
00237     //
00238     std::string space(3*m_current_depth, ' ');
00239 
00240     m_tt_info_available = false;
00241     m_tt_bestmove = INVALID_POINT;
00242     if (m_tt) 
00243     {
00244         SearchedState state;
00245         if (m_tt->Get(m_brd->GetPosition().Hash(), state)) 
00246         {
00247             m_tt_info_available = true;
00248             m_tt_bestmove = state.move;
00249 
00250             if (state.depth >= depth) 
00251             {
00252                 m_statistics.tt_hits++;
00253 
00254                 LogFine() << space << "--- TT HIT ---" << '\n';
00255 
00256                 if (state.bound == SearchedState::LOWER_BOUND) 
00257                 {
00258                     LogFine() << "Lower Bound" << '\n';
00259                     alpha = std::max(alpha, state.score);
00260                 } 
00261                 else if (state.bound == SearchedState::UPPER_BOUND) 
00262                 {
00263                     LogFine() << "Upper Bound" << '\n';
00264                     beta = std::min(beta, state.score);
00265                 } 
00266                 else if (state.bound == SearchedState::ACCURATE) 
00267                 {
00268                     LogFine() << "Accurate" << '\n';
00269                     alpha = beta = state.score;
00270                 }
00271                 
00272                 LogFine() << "new (alpha, beta): (" << alpha
00273               << ", " << beta << ")" << '\n';
00274                 
00275                 if (alpha >= beta) 
00276                 {
00277                     m_statistics.tt_cuts++;
00278                 
00279                     pv.clear();
00280                     pv.push_back(state.move);
00281                     
00282                     return state.score;
00283                 }
00284             }
00285         }
00286     }
00287 
00288     m_statistics.numinternal++;
00289 
00290     std::vector<HexPoint> moves;
00291     GenerateMoves(moves);
00292     HexAssert(moves.size());
00293     
00294     int curwidth = std::min(plywidth[m_current_depth], (int)moves.size());
00295     m_statistics.mustplay_branches += (int)moves.size();
00296     m_statistics.total_branches += curwidth;
00297 
00298     HexPoint bestmove = INVALID_POINT;
00299     HexEval bestvalue = -EVAL_INFINITY;
00300 
00301     for (int m = 0; !m_aborted && m < curwidth; ++m) 
00302     {
00303         m_statistics.visited_branches++;
00304         LogFine() << space 
00305                  << (m+1) << "/" 
00306                  << curwidth << ": ("
00307                  << m_toplay << ", " << moves[m] << ")"
00308                  << ", (" << alpha << ", " << beta << ")" 
00309                  << '\n';
00310         
00311         ExecuteMove(moves[m]);
00312         m_current_depth++;
00313         m_sequence.push_back(moves[m]);
00314         m_toplay = !m_toplay;
00315 
00316         std::vector<HexPoint> cv;
00317         HexEval value = -SearchState(plywidth, depth-1, -beta, -alpha, cv);
00318 
00319         m_toplay = !m_toplay;
00320         m_sequence.pop_back();
00321         m_current_depth--;
00322         UndoMove(moves[m]);
00323 
00324         if (value > bestvalue) 
00325         {
00326             bestmove = moves[m];
00327             bestvalue = value;
00328 
00329             // compute new principal variation
00330             pv.clear();
00331             pv.push_back(bestmove);
00332             pv.insert(pv.end(), cv.begin(), cv.end());
00333 
00334             LogFine() << space << "--- New best: " << value 
00335                       << " PV: " << HexPointUtil::ToString(pv) << " ---\n";
00336         }
00337 
00338         // store root move evaluations and output progress to gui
00339         if (m_current_depth == 0) 
00340         { 
00341             m_eval.push_back(HexMoveValue(moves[m], value));
00342             if (m_use_guifx)
00343                 DumpGuiFx(m_eval, curwidth, pv, m_toplay);
00344         }
00345 
00346         if (value >= alpha)
00347             alpha = value;
00348 
00349         if (alpha >= beta) 
00350         {
00351             LogFine() << space << "--- Cutoff ---" << '\n';
00352             m_statistics.cuts++;
00353             break;
00354         }
00355     }
00356 
00357     if (m_aborted)
00358         return -EVAL_INFINITY;
00359 
00360     //
00361     // Store in tt
00362     //
00363     HexAssert(bestmove != INVALID_POINT);
00364     if (m_tt) 
00365     {
00366         SearchedState::Bound bound = SearchedState::ACCURATE;
00367         if (bestvalue <= old_alpha) bound = SearchedState::UPPER_BOUND;
00368         if (bestvalue >= old_beta) bound = SearchedState::LOWER_BOUND;
00369         SearchedState ss(m_brd->GetPosition().Hash(), depth, bound, 
00370                          bestvalue, bestmove);
00371         m_tt->Put(m_brd->GetPosition().Hash(), ss);
00372     }
00373 
00374     AfterStateSearched();
00375 
00376     return bestvalue;
00377 }
00378 
00379 std::string HexAbSearch::DumpStats()
00380 {
00381     std::ostringstream os;
00382     os << m_statistics.Dump() << '\n';
00383 
00384     std::vector<HexMoveValue> root_evals(m_eval);
00385     stable_sort(root_evals.begin(), root_evals.end(),
00386                 std::greater<HexMoveValue>());
00387 
00388     os << '\n';
00389     std::size_t num = 10;
00390     for (std::size_t i=0; i<num && i<root_evals.size(); ++i) {
00391         if (i && i%5==0)
00392             os << '\n';
00393         os << "(" 
00394            << root_evals[i].point() << ", " 
00395            << std::fixed << std::setprecision(3) << root_evals[i].value() 
00396            << ") ";
00397     }
00398     os << '\n';
00399 
00400     return os.str();
00401 }
00402 
00403 std::string HexAbSearch::Statistics::Dump() const
00404 {
00405     std::ostringstream os;
00406     os << '\n'
00407        << "        Leaf Nodes: " << numleafs << '\n'
00408        << "    Terminal Nodes: " << numterminal << '\n'
00409        << "    Internal Nodes: " << numinternal << '\n'
00410        << "       Total Nodes: " << numstates << '\n'
00411        << "           TT Hits: " << tt_hits << '\n'
00412        << "           TT Cuts: " << tt_cuts << '\n'
00413        << "Avg. Mustplay Size: " << std::setprecision(4) 
00414        << (double)mustplay_branches / numinternal << '\n'
00415        << "Avg. Branch Factor: " << std::setprecision(4) 
00416        << (double)total_branches / numinternal << '\n'
00417        << "       Avg. To Cut: " << std::setprecision(4) 
00418        << (double)visited_branches / numinternal << '\n'
00419        << "         Nodes/Sec: " << std::setprecision(4) 
00420        << (numstates/elapsed_time) << '\n'
00421        << "      Elapsed Time: " << std::setprecision(4) 
00422        << elapsed_time << "s" << '\n'
00423        << '\n'
00424        << DumpPV(value, pv);
00425     return os.str();
00426 }
00427 
00428 //----------------------------------------------------------------------------


6 Jan 2011 Doxygen 1.6.3