Main   Namespaces   Classes   Hierarchy   Annotated   Files   Compound   Global   Pages  

HexUctPolicy.cpp

Go to the documentation of this file.
00001 //----------------------------------------------------------------------------
00002 /** @file HexUctPolicy.cpp
00003     
00004     @todo Pattern statistics are collected for each thread. Add
00005     functionality to combine the stats from each thread before
00006     displaying them. Only do this if pattern statistics are actually
00007     required, obviously.
00008  */
00009 //----------------------------------------------------------------------------
00010 
00011 #include "Hex.hpp"
00012 #include "Misc.hpp"
00013 #include "PatternState.hpp"
00014 #include "HexUctPolicy.hpp"
00015 
00016 #include <boost/filesystem/path.hpp>
00017 
00018 using namespace benzene;
00019 
00020 //----------------------------------------------------------------------------
00021 
00022 namespace 
00023 {
00024 
00025 /** Shuffle a vector with the given random number generator. 
00026     @todo Refactor this out somewhere.
00027 */
00028 template<typename T>
00029 void ShuffleVector(std::vector<T>& v, SgRandom& random)
00030 {
00031     for (int i = static_cast<int>(v.size() - 1); i > 0; --i) 
00032     {
00033         int j = random.Int(i+1);
00034         std::swap(v[i], v[j]);
00035     }
00036 }
00037 
00038 /** Returns true 'percent' of the time. */
00039 bool PercentChance(int percent, SgRandom& random)
00040 {
00041     if (percent >= 100) 
00042         return true;
00043     unsigned int threshold = random.PercentageThreshold(percent);
00044     return random.RandomEvent(threshold);
00045 }
00046 
00047 } // annonymous namespace
00048 
00049 //----------------------------------------------------------------------------
00050 
00051 HexUctPolicyConfig::HexUctPolicyConfig()
00052     : patternHeuristic(true),
00053       responseHeuristic(false),
00054       pattern_update_radius(1),
00055       pattern_check_percent(100),
00056       response_threshold(100)
00057 {
00058 }
00059 
00060 //----------------------------------------------------------------------------
00061 
00062 HexUctSharedPolicy::HexUctSharedPolicy()
00063     : m_config()
00064 {
00065     LogFine() << "--- HexUctSharedPolicy\n";
00066     LoadPatterns();
00067 }
00068 
00069 HexUctSharedPolicy::~HexUctSharedPolicy()
00070 {
00071 }
00072 
00073 void HexUctSharedPolicy::LoadPatterns()
00074 {
00075     using namespace boost::filesystem;
00076     path p = path(ABS_TOP_SRCDIR) / "share" / "mohex-patterns.txt";
00077     p.normalize();
00078     LoadPlayPatterns(p.native_file_string());
00079 }
00080 
00081 void HexUctSharedPolicy::LoadPlayPatterns(const std::string& filename)
00082 {
00083     std::vector<Pattern> patterns;
00084     Pattern::LoadPatternsFromFile(filename.c_str(), patterns);
00085     LogInfo() << "HexUctSharedPolicy: Read " << patterns.size()
00086           << " patterns from '" << filename << "'.\n";
00087 
00088     // can only load patterns once!
00089     HexAssert(m_patterns[BLACK].empty());
00090 
00091     for (std::size_t i = 0; i < patterns.size(); ++i) {
00092         Pattern p = patterns[i];
00093         switch(p.getType()) {
00094         case Pattern::MOHEX:
00095             m_patterns[BLACK].push_back(p);
00096             p.flipColors();
00097             m_patterns[WHITE].push_back(p);
00098             break;
00099         default:
00100             LogWarning() << "Pattern type = " << p.getType() << '\n';
00101             HexAssert(false);
00102         }
00103     }
00104     // create the hashed pattern sets for fast checking
00105     for (BWIterator color; color; ++color)
00106         m_hash_patterns[*color].hash(m_patterns[*color]);
00107 }
00108 
00109 //----------------------------------------------------------------------------
00110 
00111 HexUctPolicy::HexUctPolicy(const HexUctSharedPolicy* shared)
00112     : m_shared(shared)
00113 #if COLLECT_PATTERN_STATISTICS
00114     , m_statistics()
00115 #endif
00116 {
00117 }
00118 
00119 HexUctPolicy::~HexUctPolicy()
00120 {
00121 }
00122 
00123 //----------------------------------------------------------------------------
00124 
00125 /** @todo Pass initialial tree and initialize off of that? */
00126 void HexUctPolicy::InitializeForSearch()
00127 {
00128     for (int i = 0; i < BITSETSIZE; ++i)
00129     {
00130         m_response[BLACK][i].clear();
00131         m_response[WHITE][i].clear();
00132     }
00133 }
00134 
00135 void HexUctPolicy::InitializeForRollout(const StoneBoard& brd)
00136 {
00137     BitsetUtil::BitsetToVector(brd.GetEmpty(), m_moves);
00138     ShuffleVector(m_moves, m_random);
00139 }
00140 
00141 HexPoint HexUctPolicy::GenerateMove(PatternState& pastate, 
00142                                     HexColor toPlay, 
00143                                     HexPoint lastMove)
00144 {
00145     HexPoint move = INVALID_POINT;
00146     bool pattern_move = false;
00147     const HexUctPolicyConfig& config = m_shared->Config();
00148 #if COLLECT_PATTERN_STATISTICS
00149     HexUctPolicyStatistics& stats = m_statistics;
00150 #endif
00151 
00152     // patterns applied probabilistically (if heuristic is turned on)
00153     if (config.patternHeuristic 
00154         && PercentChance(config.pattern_check_percent, m_random))
00155     {
00156         move = GeneratePatternMove(pastate, toPlay, lastMove);
00157     }
00158     
00159     if (move == INVALID_POINT
00160         && config.responseHeuristic)
00161     {
00162         move = GenerateResponseMove(toPlay, lastMove, pastate.Board());
00163     }
00164 
00165     // select random move if invalid point from pattern heuristic
00166     if (move == INVALID_POINT) 
00167     {
00168 #if COLLECT_PATTERN_STATISTICS
00169     stats.random_moves++;
00170 #endif
00171         move = GenerateRandomMove(pastate.Board());
00172     } 
00173     else 
00174     {
00175     pattern_move = true;
00176 #if COLLECT_PATTERN_STATISTICS
00177         stats.pattern_moves++;
00178 #endif
00179     }
00180     
00181     HexAssert(pastate.Board().IsEmpty(move));
00182 #if COLLECT_PATTERN_STATISTICS
00183     stats.total_moves++;
00184 #endif
00185     return move;
00186 }
00187 
#if COLLECT_PATTERN_STATISTICS
/** Formats the collected pattern statistics as a text table.

    For each pattern, the table shows how often it hit for Black and for
    White, then how often it was picked for Black and for White. After
    that come the pattern/random move totals as percentages of all
    generated moves. Percentages print as 0 when no move has been
    generated yet, instead of dividing by zero.

    NOTE(review): this function reads m_patterns, which in the visible
    code is filled by HexUctSharedPolicy::LoadPlayPatterns. It only
    compiles when COLLECT_PATTERN_STATISTICS is enabled, so confirm that
    HexUctPolicy actually has access to those patterns.
*/
std::string HexUctPolicy::DumpStatistics()
{
    std::ostringstream os;

    os << '\n';
    os << "Pattern statistics:" << '\n';
    os << std::setw(12) << "Name" << "  "
       << std::setw(10) << "Black" << " "
       << std::setw(10) << "White" << " "
       << std::setw(10) << "Black" << " "
       << std::setw(10) << "White" << '\n';

    os << "     ------------------------------------------------------" 
       << '\n';

    HexUctPolicyStatistics& stats = Statistics();
    for (unsigned i = 0; i < m_patterns[BLACK].size(); ++i) {
        const Pattern* black = &m_patterns[BLACK][i];
        const Pattern* white = &m_patterns[WHITE][i];
        os << std::setw(12) << m_patterns[BLACK][i].getName() << ": "
           << std::setw(10) << stats.pattern_counts[BLACK][black] << " "
           << std::setw(10) << stats.pattern_counts[WHITE][white] << " " 
           << std::setw(10) << stats.pattern_picked[BLACK][black] << " "
           << std::setw(10) << stats.pattern_picked[WHITE][white]
           << '\n';
    }

    os << "     ------------------------------------------------------" 
       << '\n';

    // Avoid 0/0 (printed as nan) before any move has been generated.
    const bool haveMoves = stats.total_moves > 0;
    const double patternPct = haveMoves
        ? stats.pattern_moves * 100.0 / stats.total_moves : 0.0;
    const double randomPct = haveMoves
        ? stats.random_moves * 100.0 / stats.total_moves : 0.0;

    os << '\n';
    os << std::setw(12) << "Pattern" << ": " 
       << std::setw(10) << stats.pattern_moves << " "
       << std::setw(10) << std::setprecision(3) << patternPct << "%" 
       << '\n';
    os << std::setw(12) << "Random" << ": " 
       << std::setw(10) << stats.random_moves << " "
       << std::setw(10) << std::setprecision(3) << randomPct << "%"  
       << '\n';
    os << std::setw(12) << "Total" << ": " 
       << std::setw(10) << stats.total_moves << '\n';

    os << '\n';

    return os.str();
}
#endif
00240 
00241 //--------------------------------------------------------------------------
00242 
00243 HexPoint HexUctPolicy::GenerateResponseMove(HexColor toPlay, HexPoint lastMove,
00244                                             const StoneBoard& brd)
00245 {
00246     std::size_t num = m_response[toPlay][lastMove].size();
00247     if (num > m_shared->Config().response_threshold)
00248     {
00249         HexPoint move = m_response[toPlay][lastMove][m_random.Int(num)];
00250         if (brd.IsEmpty(move))
00251             return move;
00252     }
00253     return INVALID_POINT;
00254 }
00255 
00256 /** Selects random move among the empty cells on the board. */
00257 HexPoint HexUctPolicy::GenerateRandomMove(const StoneBoard& brd)
00258 {
00259     HexPoint ret = INVALID_POINT;
00260     while (true) 
00261     {
00262     HexAssert(!m_moves.empty());
00263         ret = m_moves.back();
00264         m_moves.pop_back();
00265         if (brd.IsEmpty(ret))
00266             break;
00267     }
00268     return ret;
00269 }
00270 
00271 /** Randomly picks a pattern move from the set of patterns that hit
00272     the last move, weighted by the pattern's weight. 
00273     If no pattern matches, returns INVALID_POINT. */
00274 HexPoint HexUctPolicy::PickRandomPatternMove(const PatternState& pastate, 
00275                                              const HashedPatternSet& patterns, 
00276                                              HexColor toPlay,
00277                                              HexPoint lastMove)
00278 {
00279     UNUSED(toPlay);
00280 
00281     if (lastMove == INVALID_POINT)
00282     return INVALID_POINT;
00283     
00284     int num = 0;
00285     int patternIndex[MAX_VOTES];
00286     HexPoint patternMoves[MAX_VOTES];
00287 
00288     PatternHits hits;
00289     pastate.MatchOnCell(patterns, lastMove, PatternState::MATCH_ALL, hits);
00290 
00291     for (unsigned i = 0; i < hits.size(); ++i) 
00292     {
00293 #if COLLECT_PATTERN_STATISTICS
00294         // record that this pattern hit
00295         m_shared->Statistics().pattern_counts[toPlay][hits[i].pattern()]++;
00296 #endif
00297             
00298         // number of entries added to array is equal to the pattern's weight
00299         for (int j = 0; j < hits[i].pattern()->getWeight(); ++j) 
00300         {
00301             patternIndex[num] = i;
00302             patternMoves[num] = hits[i].moves1()[0];
00303             num++;
00304             HexAssert(num < MAX_VOTES);
00305         }
00306     }
00307     
00308     // abort if no pattern hit
00309     if (num == 0) 
00310         return INVALID_POINT;
00311     
00312     // select move at random (biased according to number of entries)
00313     int i = m_random.Int(num);
00314 
00315 #if COLLECT_PATTERN_STATISTICS
00316     m_shared->Statistics().pattern_picked
00317         [toPlay][hits[patternIndex[i]].pattern()]++;
00318 #endif
00319 
00320     return patternMoves[i];
00321 }
00322 
00323 /** Uses PickRandomPatternMove() with the shared PlayPatterns(). */
00324 HexPoint HexUctPolicy::GeneratePatternMove(const PatternState& pastate, 
00325                                            HexColor toPlay, 
00326                                            HexPoint lastMove)
00327 {
00328     return PickRandomPatternMove(pastate, m_shared->PlayPatterns(toPlay),
00329                                  toPlay, lastMove);
00330 }
00331 
00332 //----------------------------------------------------------------------------


6 Jan 2011 Doxygen 1.6.3