HexUctSearch.cpp
Go to the documentation of this file.00001
00002
00003
00004
00005
00006 #include "SgSystem.h"
00007
00008 #include "SgException.h"
00009 #include "SgNode.h"
00010 #include "SgMove.h"
00011 #include "SgSList.h"
00012
00013 #include "BoardUtils.hpp"
00014 #include "BitsetIterator.hpp"
00015 #include "HexSgUtil.hpp"
00016 #include "HexUctPolicy.hpp"
00017 #include "HexUctSearch.hpp"
00018 #include "HexUctState.hpp"
00019 #include "HexUctUtil.hpp"
00020 #include "PatternState.hpp"
00021
00022 using namespace benzene;
00023
00024
00025
00026 HexThreadStateFactory::HexThreadStateFactory(HexUctSharedPolicy* shared)
00027 : m_shared_policy(shared)
00028 {
00029 }
00030
00031 HexThreadStateFactory::~HexThreadStateFactory()
00032 {
00033 }
00034
00035 SgUctThreadState*
00036 HexThreadStateFactory::Create(unsigned int threadId, const SgUctSearch& search)
00037 {
00038 SgUctSearch& srch = const_cast<SgUctSearch&>(search);
00039 HexUctSearch& hexSearch = dynamic_cast<HexUctSearch&>(srch);
00040 LogInfo() << "Creating thread " << threadId << '\n';
00041 HexUctState* state = new HexUctState(threadId, hexSearch,
00042 hexSearch.TreeUpdateRadius(),
00043 hexSearch.PlayoutUpdateRadius());
00044 state->SetPolicy(new HexUctPolicy(m_shared_policy));
00045 return state;
00046 }
00047
00048
00049
00050 HexUctSearch::HexUctSearch(SgUctThreadStateFactory* factory, int maxMoves)
00051 : SgUctSearch(factory, maxMoves),
00052 m_keepGames(false),
00053 m_liveGfx(false),
00054 m_liveGfxInterval(20000),
00055 m_treeUpdateRadius(2),
00056 m_playoutUpdateRadius(1),
00057 m_brd(0),
00058 m_shared_data(),
00059 m_root(0)
00060 {
00061 SetBiasTermConstant(0.0);
00062 SetExpandThreshold(1);
00063 {
00064 std::vector<SgUctValue> thresholds;
00065 thresholds.push_back(400);
00066 SetKnowledgeThreshold(thresholds);
00067 }
00068 SetLockFree(true);
00069 SetMaxNodes(15000000);
00070 SetMoveSelect(SG_UCTMOVESELECT_COUNT);
00071 SetNumberThreads(1);
00072 SetRave(true);
00073 SetRandomizeRaveFrequency(20);
00074 SetWeightRaveUpdates(false);
00075 SetRaveWeightInitial(1.0);
00076 SetRaveWeightFinal(20000.0);
00077 }
00078
00079 HexUctSearch::~HexUctSearch()
00080 {
00081 if (m_root != 0)
00082 m_root->DeleteTree();
00083 m_root = 0;
00084 }
00085
00086
00087 void HexUctSearch::AppendGame(const std::vector<SgMove>& sequence)
00088 {
00089 HexAssert(m_root != 0);
00090 SgNode* node = m_root->RightMostSon();
00091 HexColor color = m_shared_data.root_to_play;
00092 std::vector<SgPoint>::const_iterator it = sequence.begin();
00093
00094 for (; it != sequence.end(); ++it)
00095 {
00096 if (!node->HasSon())
00097 break;
00098 bool found = false;
00099 for (SgNode* child = node->LeftMostSon(); ;
00100 child = child->RightBrother())
00101 {
00102 HexPoint move = HexSgUtil::SgPointToHexPoint(child->NodeMove(),
00103 m_brd->Height());
00104
00105 if (move == *it)
00106 {
00107 node = child;
00108 found = true;
00109 break;
00110 }
00111 if (!child->HasRightBrother())
00112 break;
00113 }
00114
00115 if (!found)
00116 break;
00117 color = !color;
00118 }
00119
00120 for (; it != sequence.end(); ++it)
00121 {
00122 SgNode* child = node->NewRightMostSon();
00123 HexSgUtil::AddMoveToNode(child, color, static_cast<HexPoint>(*it),
00124 m_brd->Height());
00125 color = !color;
00126 node = child;
00127 }
00128 }
00129
00130 void HexUctSearch::OnStartSearch()
00131 {
00132 HexAssert(m_brd);
00133 if (m_root != 0)
00134 m_root->DeleteTree();
00135 if (m_keepGames)
00136 {
00137 m_root = new SgNode();
00138 SgNode* position = m_root->NewRightMostSon();
00139 HexSgUtil::SetPositionInNode(position, m_brd->GetPosition(),
00140 m_shared_data.root_to_play);
00141 }
00142
00143 int size = m_brd->Width() * m_brd->Height();
00144 int maxGameLength = size+10;
00145 SetMaxGameLength(maxGameLength);
00146 m_lastPositionSearched = m_brd->GetPosition();
00147 m_nextLiveGfx = m_liveGfxInterval;
00148 }
00149
00150 void HexUctSearch::SaveGames(const std::string& filename) const
00151 {
00152 if (m_root == 0)
00153 throw SgException("No games to save");
00154 HexSgUtil::WriteSgf(m_root, filename.c_str(), m_brd->Height());
00155 }
00156
00157 void HexUctSearch::SaveTree(std::ostream& out, int maxDepth) const
00158 {
00159 HexUctUtil::SaveTree(Tree(), m_lastPositionSearched,
00160 m_shared_data.root_to_play, out, maxDepth);
00161 }
00162
00163 void HexUctSearch::OnSearchIteration(SgUctValue gameNumber,
00164 const unsigned int threadId,
00165 const SgUctGameInfo& info)
00166 {
00167 SgUctSearch::OnSearchIteration(gameNumber, threadId, info);
00168 if (m_liveGfx && threadId == 0 && gameNumber > m_nextLiveGfx)
00169 {
00170 m_nextLiveGfx = gameNumber + m_liveGfxInterval;
00171 std::ostringstream os;
00172 os << "gogui-gfx:\n";
00173 os << "uct\n";
00174 HexColor initial_toPlay = m_shared_data.root_to_play;
00175 HexUctUtil::GoGuiGfx(*this,
00176 HexUctUtil::ToSgBlackWhite(initial_toPlay),
00177 os);
00178 os << "\n";
00179 std::cout << os.str();
00180 std::cout.flush();
00181 LogFine() << os.str() << '\n';
00182 }
00183 if (m_root != 0)
00184 {
00185 for (std::size_t i = 0; i < LastGameInfo().m_sequence.size(); ++i)
00186 AppendGame(LastGameInfo().m_sequence[i]);
00187 }
00188 }
00189
00190 SgUctValue HexUctSearch::UnknownEval() const
00191 {
00192
00193
00194 return 0.5;
00195 }
00196
00197 SgUctValue HexUctSearch::InverseEval(SgUctValue eval) const
00198 {
00199 return (1 - eval);
00200 }
00201
00202 std::string HexUctSearch::MoveString(SgMove move) const
00203 {
00204 return HexPointUtil::ToString(static_cast<HexPoint>(move));
00205 }
00206
00207
00208