00001 //---------------------------------------------------------------------------- 00002 /** @file HexUctPolicy.cpp 00003 00004 @todo Pattern statistics are collected for each thread. Add 00005 functionality to combine the stats from each thread before 00006 displaying them. Only do this if pattern statistics are actually 00007 required, obviously. 00008 */ 00009 //---------------------------------------------------------------------------- 00010 00011 #include "Hex.hpp" 00012 #include "Misc.hpp" 00013 #include "PatternState.hpp" 00014 #include "HexUctPolicy.hpp" 00015 00016 #include <boost/filesystem/path.hpp> 00017 00018 using namespace benzene; 00019 00020 //---------------------------------------------------------------------------- 00021 00022 namespace 00023 { 00024 00025 /** Shuffle a vector with the given random number generator. 00026 @todo Refactor this out somewhere. 00027 */ 00028 template<typename T> 00029 void ShuffleVector(std::vector<T>& v, SgRandom& random) 00030 { 00031 for (int i = static_cast<int>(v.size() - 1); i > 0; --i) 00032 { 00033 int j = random.Int(i+1); 00034 std::swap(v[i], v[j]); 00035 } 00036 } 00037 00038 /** Returns true 'percent' of the time. */ 00039 bool PercentChance(int percent, SgRandom& random) 00040 { 00041 if (percent >= 100) 00042 return true; 00043 unsigned int threshold = random.PercentageThreshold(percent); 00044 return random.RandomEvent(threshold); 00045 } 00046 00047 } // annonymous namespace 00048 00049 //---------------------------------------------------------------------------- 00050 00051 HexUctPolicyConfig::HexUctPolicyConfig() 00052 : patternHeuristic(true), 00053 responseHeuristic(false), 00054 pattern_update_radius(1), 00055 pattern_check_percent(100), 00056 response_threshold(100) 00057 { 00058 } 00059 00060 //---------------------------------------------------------------------------- 00061 00062 HexUctSharedPolicy::HexUctSharedPolicy() 00063 : m_config() 00064 { 00065 LogFine() << "--- HexUctSharedPolicy\n"; 00066 LoadPatterns(); 00067 } 00068 00069 HexUctSharedPolicy::~HexUctSharedPolicy() 00070 { 00071 } 00072 00073 void HexUctSharedPolicy::LoadPatterns() 00074 { 00075 using namespace boost::filesystem; 00076 path p = path(ABS_TOP_SRCDIR) / "share" / "mohex-patterns.txt"; 00077 p.normalize(); 00078 LoadPlayPatterns(p.native_file_string()); 00079 } 00080 00081 void HexUctSharedPolicy::LoadPlayPatterns(const std::string& filename) 00082 { 00083 std::vector<Pattern> patterns; 00084 Pattern::LoadPatternsFromFile(filename.c_str(), patterns); 00085 LogInfo() << "HexUctSharedPolicy: Read " << patterns.size() 00086 << " patterns from '" << filename << "'.\n"; 00087 00088 // can only load patterns once! 00089 HexAssert(m_patterns[BLACK].empty()); 00090 00091 for (std::size_t i = 0; i < patterns.size(); ++i) { 00092 Pattern p = patterns[i]; 00093 switch(p.getType()) { 00094 case Pattern::MOHEX: 00095 m_patterns[BLACK].push_back(p); 00096 p.flipColors(); 00097 m_patterns[WHITE].push_back(p); 00098 break; 00099 default: 00100 LogWarning() << "Pattern type = " << p.getType() << '\n'; 00101 HexAssert(false); 00102 } 00103 } 00104 // create the hashed pattern sets for fast checking 00105 for (BWIterator color; color; ++color) 00106 m_hash_patterns[*color].hash(m_patterns[*color]); 00107 } 00108 00109 //---------------------------------------------------------------------------- 00110 00111 HexUctPolicy::HexUctPolicy(const HexUctSharedPolicy* shared) 00112 : m_shared(shared) 00113 #if COLLECT_PATTERN_STATISTICS 00114 , m_statistics() 00115 #endif 00116 { 00117 } 00118 00119 HexUctPolicy::~HexUctPolicy() 00120 { 00121 } 00122 00123 //---------------------------------------------------------------------------- 00124 00125 /** @todo Pass initialial tree and initialize off of that? */ 00126 void HexUctPolicy::InitializeForSearch() 00127 { 00128 for (int i = 0; i < BITSETSIZE; ++i) 00129 { 00130 m_response[BLACK][i].clear(); 00131 m_response[WHITE][i].clear(); 00132 } 00133 } 00134 00135 void HexUctPolicy::InitializeForRollout(const StoneBoard& brd) 00136 { 00137 BitsetUtil::BitsetToVector(brd.GetEmpty(), m_moves); 00138 ShuffleVector(m_moves, m_random); 00139 } 00140 00141 HexPoint HexUctPolicy::GenerateMove(PatternState& pastate, 00142 HexColor toPlay, 00143 HexPoint lastMove) 00144 { 00145 HexPoint move = INVALID_POINT; 00146 bool pattern_move = false; 00147 const HexUctPolicyConfig& config = m_shared->Config(); 00148 #if COLLECT_PATTERN_STATISTICS 00149 HexUctPolicyStatistics& stats = m_statistics; 00150 #endif 00151 00152 // patterns applied probabilistically (if heuristic is turned on) 00153 if (config.patternHeuristic 00154 && PercentChance(config.pattern_check_percent, m_random)) 00155 { 00156 move = GeneratePatternMove(pastate, toPlay, lastMove); 00157 } 00158 00159 if (move == INVALID_POINT 00160 && config.responseHeuristic) 00161 { 00162 move = GenerateResponseMove(toPlay, lastMove, pastate.Board()); 00163 } 00164 00165 // select random move if invalid point from pattern heuristic 00166 if (move == INVALID_POINT) 00167 { 00168 #if COLLECT_PATTERN_STATISTICS 00169 stats.random_moves++; 00170 #endif 00171 move = GenerateRandomMove(pastate.Board()); 00172 } 00173 else 00174 { 00175 pattern_move = true; 00176 #if COLLECT_PATTERN_STATISTICS 00177 stats.pattern_moves++; 00178 #endif 00179 } 00180 00181 HexAssert(pastate.Board().IsEmpty(move)); 00182 #if COLLECT_PATTERN_STATISTICS 00183 stats.total_moves++; 00184 #endif 00185 return move; 00186 } 00187 00188 #if COLLECT_PATTERN_STATISTICS 00189 std::string HexUctPolicy::DumpStatistics() 00190 { 00191 std::ostringstream os; 00192 00193 os << std::endl; 00194 os << "Pattern statistics:" << std::endl; 00195 os << std::setw(12) << "Name" << " " 00196 << std::setw(10) << "Black" << " " 00197 << std::setw(10) << "White" << " " 00198 << std::setw(10) << "Black" << " " 00199 << std::setw(10) << "White" << std::endl; 00200 00201 os << " ------------------------------------------------------" 00202 << std::endl; 00203 00204 HexUctPolicyStatistics& stats = Statistics(); 00205 for (unsigned i=0; i<m_patterns[BLACK].size(); ++i) { 00206 os << std::setw(12) << m_patterns[BLACK][i].getName() << ": " 00207 << std::setw(10) << stats.pattern_counts[BLACK] 00208 [&m_patterns[BLACK][i]] << " " 00209 << std::setw(10) << stats.pattern_counts[WHITE] 00210 [&m_patterns[WHITE][i]] << " " 00211 << std::setw(10) << stats.pattern_picked[BLACK] 00212 [&m_patterns[BLACK][i]] << " " 00213 << std::setw(10) << stats.pattern_picked[WHITE] 00214 [&m_patterns[WHITE][i]] 00215 << std::endl; 00216 } 00217 00218 os << " ------------------------------------------------------" 00219 << std::endl; 00220 00221 os << std::endl; 00222 os << std::setw(12) << "Pattern" << ": " 00223 << std::setw(10) << stats.pattern_moves << " " 00224 << std::setw(10) << std::setprecision(3) << 00225 stats.pattern_moves*100.0/stats.total_moves << "%" 00226 << std::endl; 00227 os << std::setw(12) << "Random" << ": " 00228 << std::setw(10) << stats.random_moves << " " 00229 << std::setw(10) << std::setprecision(3) << 00230 stats.random_moves*100.0/stats.total_moves << "%" 00231 << std::endl; 00232 os << std::setw(12) << "Total" << ": " 00233 << std::setw(10) << stats.total_moves << std::endl; 00234 00235 os << std::endl; 00236 00237 return os.str(); 00238 } 00239 #endif 00240 00241 //-------------------------------------------------------------------------- 00242 00243 HexPoint HexUctPolicy::GenerateResponseMove(HexColor toPlay, HexPoint lastMove, 00244 const StoneBoard& brd) 00245 { 00246 std::size_t num = m_response[toPlay][lastMove].size(); 00247 if (num > m_shared->Config().response_threshold) 00248 { 00249 HexPoint move = m_response[toPlay][lastMove][m_random.Int(num)]; 00250 if (brd.IsEmpty(move)) 00251 return move; 00252 } 00253 return INVALID_POINT; 00254 } 00255 00256 /** Selects random move among the empty cells on the board. */ 00257 HexPoint HexUctPolicy::GenerateRandomMove(const StoneBoard& brd) 00258 { 00259 HexPoint ret = INVALID_POINT; 00260 while (true) 00261 { 00262 HexAssert(!m_moves.empty()); 00263 ret = m_moves.back(); 00264 m_moves.pop_back(); 00265 if (brd.IsEmpty(ret)) 00266 break; 00267 } 00268 return ret; 00269 } 00270 00271 /** Randomly picks a pattern move from the set of patterns that hit 00272 the last move, weighted by the pattern's weight. 00273 If no pattern matches, returns INVALID_POINT. */ 00274 HexPoint HexUctPolicy::PickRandomPatternMove(const PatternState& pastate, 00275 const HashedPatternSet& patterns, 00276 HexColor toPlay, 00277 HexPoint lastMove) 00278 { 00279 UNUSED(toPlay); 00280 00281 if (lastMove == INVALID_POINT) 00282 return INVALID_POINT; 00283 00284 int num = 0; 00285 int patternIndex[MAX_VOTES]; 00286 HexPoint patternMoves[MAX_VOTES]; 00287 00288 PatternHits hits; 00289 pastate.MatchOnCell(patterns, lastMove, PatternState::MATCH_ALL, hits); 00290 00291 for (unsigned i = 0; i < hits.size(); ++i) 00292 { 00293 #if COLLECT_PATTERN_STATISTICS 00294 // record that this pattern hit 00295 m_shared->Statistics().pattern_counts[toPlay][hits[i].pattern()]++; 00296 #endif 00297 00298 // number of entries added to array is equal to the pattern's weight 00299 for (int j = 0; j < hits[i].pattern()->getWeight(); ++j) 00300 { 00301 patternIndex[num] = i; 00302 patternMoves[num] = hits[i].moves1()[0]; 00303 num++; 00304 HexAssert(num < MAX_VOTES); 00305 } 00306 } 00307 00308 // abort if no pattern hit 00309 if (num == 0) 00310 return INVALID_POINT; 00311 00312 // select move at random (biased according to number of entries) 00313 int i = m_random.Int(num); 00314 00315 #if COLLECT_PATTERN_STATISTICS 00316 m_shared->Statistics().pattern_picked 00317 [toPlay][hits[patternIndex[i]].pattern()]++; 00318 #endif 00319 00320 return patternMoves[i]; 00321 } 00322 00323 /** Uses PickRandomPatternMove() with the shared PlayPatterns(). */ 00324 HexPoint HexUctPolicy::GeneratePatternMove(const PatternState& pastate, 00325 HexColor toPlay, 00326 HexPoint lastMove) 00327 { 00328 return PickRandomPatternMove(pastate, m_shared->PlayPatterns(toPlay), 00329 toPlay, lastMove); 00330 } 00331 00332 //----------------------------------------------------------------------------