00001 //---------------------------------------------------------------------------- 00002 /** @file HexUctSearch.cpp 00003 */ 00004 //---------------------------------------------------------------------------- 00005 00006 #include "SgSystem.h" 00007 00008 #include "SgException.h" 00009 #include "SgNode.h" 00010 #include "SgMove.h" 00011 #include "SgSList.h" 00012 00013 #include "BoardUtils.hpp" 00014 #include "BitsetIterator.hpp" 00015 #include "HexSgUtil.hpp" 00016 #include "HexUctPolicy.hpp" 00017 #include "HexUctSearch.hpp" 00018 #include "HexUctState.hpp" 00019 #include "HexUctUtil.hpp" 00020 #include "PatternState.hpp" 00021 00022 using namespace benzene; 00023 00024 //---------------------------------------------------------------------------- 00025 00026 HexThreadStateFactory::HexThreadStateFactory(HexUctSharedPolicy* shared) 00027 : m_shared_policy(shared) 00028 { 00029 } 00030 00031 HexThreadStateFactory::~HexThreadStateFactory() 00032 { 00033 } 00034 00035 SgUctThreadState* 00036 HexThreadStateFactory::Create(unsigned int threadId, const SgUctSearch& search) 00037 { 00038 SgUctSearch& srch = const_cast<SgUctSearch&>(search); 00039 HexUctSearch& hexSearch = dynamic_cast<HexUctSearch&>(srch); 00040 LogInfo() << "Creating thread " << threadId << '\n'; 00041 HexUctState* state = new HexUctState(threadId, hexSearch, 00042 hexSearch.TreeUpdateRadius(), 00043 hexSearch.PlayoutUpdateRadius()); 00044 state->SetPolicy(new HexUctPolicy(m_shared_policy)); 00045 return state; 00046 } 00047 00048 //---------------------------------------------------------------------------- 00049 00050 HexUctSearch::HexUctSearch(SgUctThreadStateFactory* factory, int maxMoves) 00051 : SgUctSearch(factory, maxMoves), 00052 m_keepGames(false), 00053 m_liveGfx(false), 00054 m_liveGfxInterval(20000), 00055 m_treeUpdateRadius(2), 00056 m_playoutUpdateRadius(1), 00057 m_brd(0), 00058 m_shared_data(), 00059 m_root(0) 00060 { 00061 SetBiasTermConstant(0.0); 00062 SetExpandThreshold(1); 00063 { 00064 std::vector<SgUctValue> thresholds; 00065 thresholds.push_back(400); 00066 SetKnowledgeThreshold(thresholds); 00067 } 00068 SetLockFree(true); 00069 SetMaxNodes(15000000); 00070 SetMoveSelect(SG_UCTMOVESELECT_COUNT); 00071 SetNumberThreads(1); 00072 SetRave(true); 00073 SetRandomizeRaveFrequency(20); 00074 SetWeightRaveUpdates(false); 00075 SetRaveWeightInitial(1.0); 00076 SetRaveWeightFinal(20000.0); 00077 } 00078 00079 HexUctSearch::~HexUctSearch() 00080 { 00081 if (m_root != 0) 00082 m_root->DeleteTree(); 00083 m_root = 0; 00084 } 00085 00086 /** Merges last game into the tree of games. */ 00087 void HexUctSearch::AppendGame(const std::vector<SgMove>& sequence) 00088 { 00089 HexAssert(m_root != 0); 00090 SgNode* node = m_root->RightMostSon(); 00091 HexColor color = m_shared_data.root_to_play; 00092 std::vector<SgPoint>::const_iterator it = sequence.begin(); 00093 // Find first move that starts a new variation 00094 for (; it != sequence.end(); ++it) 00095 { 00096 if (!node->HasSon()) 00097 break; 00098 bool found = false; 00099 for (SgNode* child = node->LeftMostSon(); ; 00100 child = child->RightBrother()) 00101 { 00102 HexPoint move = HexSgUtil::SgPointToHexPoint(child->NodeMove(), 00103 m_brd->Height()); 00104 // Found it! Recurse down this branch 00105 if (move == *it) 00106 { 00107 node = child; 00108 found = true; 00109 break; 00110 } 00111 if (!child->HasRightBrother()) 00112 break; 00113 } 00114 // Abort if we need to start a new variation 00115 if (!found) 00116 break; 00117 color = !color; 00118 } 00119 // Add the remainder of the sequence to this node 00120 for (; it != sequence.end(); ++it) 00121 { 00122 SgNode* child = node->NewRightMostSon(); 00123 HexSgUtil::AddMoveToNode(child, color, static_cast<HexPoint>(*it), 00124 m_brd->Height()); 00125 color = !color; 00126 node = child; 00127 } 00128 } 00129 00130 void HexUctSearch::OnStartSearch() 00131 { 00132 HexAssert(m_brd); 00133 if (m_root != 0) 00134 m_root->DeleteTree(); 00135 if (m_keepGames) 00136 { 00137 m_root = new SgNode(); 00138 SgNode* position = m_root->NewRightMostSon(); 00139 HexSgUtil::SetPositionInNode(position, m_brd->GetPosition(), 00140 m_shared_data.root_to_play); 00141 } 00142 // Limit to avoid very long games (no need in Hex) 00143 int size = m_brd->Width() * m_brd->Height(); 00144 int maxGameLength = size+10; 00145 SetMaxGameLength(maxGameLength); 00146 m_lastPositionSearched = m_brd->GetPosition(); 00147 m_nextLiveGfx = m_liveGfxInterval; 00148 } 00149 00150 void HexUctSearch::SaveGames(const std::string& filename) const 00151 { 00152 if (m_root == 0) 00153 throw SgException("No games to save"); 00154 HexSgUtil::WriteSgf(m_root, filename.c_str(), m_brd->Height()); 00155 } 00156 00157 void HexUctSearch::SaveTree(std::ostream& out, int maxDepth) const 00158 { 00159 HexUctUtil::SaveTree(Tree(), m_lastPositionSearched, 00160 m_shared_data.root_to_play, out, maxDepth); 00161 } 00162 00163 void HexUctSearch::OnSearchIteration(SgUctValue gameNumber, 00164 const unsigned int threadId, 00165 const SgUctGameInfo& info) 00166 { 00167 SgUctSearch::OnSearchIteration(gameNumber, threadId, info); 00168 if (m_liveGfx && threadId == 0 && gameNumber > m_nextLiveGfx) 00169 { 00170 m_nextLiveGfx = gameNumber + m_liveGfxInterval; 00171 std::ostringstream os; 00172 os << "gogui-gfx:\n"; 00173 os << "uct\n"; 00174 HexColor initial_toPlay = m_shared_data.root_to_play; 00175 HexUctUtil::GoGuiGfx(*this, 00176 HexUctUtil::ToSgBlackWhite(initial_toPlay), 00177 os); 00178 os << "\n"; 00179 std::cout << os.str(); 00180 std::cout.flush(); 00181 LogFine() << os.str() << '\n'; 00182 } 00183 if (m_root != 0) 00184 { 00185 for (std::size_t i = 0; i < LastGameInfo().m_sequence.size(); ++i) 00186 AppendGame(LastGameInfo().m_sequence[i]); 00187 } 00188 } 00189 00190 SgUctValue HexUctSearch::UnknownEval() const 00191 { 00192 // Note: 0.5 is not a possible value for a Bernoulli variable, better 00193 // use 0? 00194 return 0.5; 00195 } 00196 00197 SgUctValue HexUctSearch::InverseEval(SgUctValue eval) const 00198 { 00199 return (1 - eval); 00200 } 00201 00202 std::string HexUctSearch::MoveString(SgMove move) const 00203 { 00204 return HexPointUtil::ToString(static_cast<HexPoint>(move)); 00205 } 00206 00207 //---------------------------------------------------------------------------- 00208