D:/simple_rts/include/AStarSearch.h

Go to the documentation of this file.
00001 #ifndef __ASTAR_SEARCH_H__
00002 #define __ASTAR_SEARCH_H__
00003 
00004 #pragma once
00005 
00006 //#define _USE_HEAP_
00007 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <exception>
#include <iostream>
#include <stdexcept>
#include <vector>

#include "WaypointList.h"

using std::cout;
using std::endl;

using std::vector;
using std::exception;
using std::find;
00022 
00023 
00024 /*
00025         Implement better searching algorithms instead of using an iterator inside of
00026                 _isNodeOpen
00027                 _isNodeClosed
00028 
00029 
00030         A UserState class must implement the following methods
00031                 getNumNeighbours
00032                         Returns the number of neighbours this UserState has
00033                 
00034                 getNeighbour
00035                         Returns a pointer to a particular neighbouring UserState based on an index
00036                 
00037                 getTraversalCost
00038                         How much will it "cost" to use this UserState in the solution (a cost of -1 marks the state as impassable)
00039                 
00040                 getEstimatedCostToState
00041                         Approximately how far the EndState is from this current UserState
00042                 
00043                 operator==
00044                         Compares two UserStates for equality
00045                 isNode
00046                         Called with the EndState; returning true ends the search with a path to this UserState
00047 */
00048 
00049 // ----------------------------------------------------------------------------
00050 template <class UserState>
00051 class AStarSearch
00052 {
00053         // ----------------------------------------------------------------------------
00054         class AStarNode
00055         {
00056         public:
00057                 // TODO
00058                 //              Make sure that this becomes a constant reference
00059                 UserState*              mUserState;
00060 
00061                 // The node that is currently being used to "enter" into this node
00062                 AStarNode*              mSuccessor;
00063 
00064                 // Data Storage
00065                 // ----------------------------------------------------------------------------
00066         private:
00067                 float           mEntryCost;
00068                 float           mEstimatedGoalCost;
00069 
00070                 // TODO
00071                 //              Seriously implement this better
00072                 AStarNode*              mNeighbours[8];
00073 
00074                 static int                      mNumNodes;
00075                 unsigned int            mID;
00076 
00077 
00078         // Construction
00079         // ----------------------------------------------------------------------------
00080         public:
00081                 AStarNode ( UserState* state )
00082                         : mSuccessor ( NULL ), mUserState ( state),
00083                           mEntryCost ( 0.0 ), mEstimatedGoalCost ( 0.0 )
00084                 {
00085                         mID = mNumNodes++;
00086 
00087                         // Create new nodes for each of the neighbours of this one
00088                         for ( int i = 0; i < mUserState->getNumNeighbours(); i++ )
00089                                 mNeighbours[i] = NULL;
00090 
00091                         //cout << "Node Constructed: " << ++nodeCount << endl;\e
00092                 }
00093 
00094 
00095                 // ----------------------------------------------------------------------------
00096                 ~AStarNode ()
00097                 {
00098                         //cout << "Node released: " << --nodeCount << endl;
00099                         // Our node may not have any neighbours created yet
00100                         if ( mUserState == NULL )
00101                                 return;
00102 
00103 
00104                         for ( int i = 0; i < mUserState->getNumNeighbours(); i++ )
00105                                 delete mNeighbours[i];
00106                 }
00107 
00108 
00109                 // ----------------------------------------------------------------------------
00110                 void createNeighbourNodes ()
00111                 {
00112                         for ( int i = 0; i < mUserState->getNumNeighbours(); i++ )
00113                                 mNeighbours[i] = new AStarNode( mUserState->getNeighbour(i) );
00114                 }
00115 
00116 
00117                 // ----------------------------------------------------------------------------
00118                 bool operator== ( AStarNode& rhs ) const
00119                 {
00120                         return ( mUserState == rhs.mUserState );
00121                 }
00122 
00123 
00124                 // ----------------------------------------------------------------------------
00125                 bool operator== ( const AStarNode& rhs ) const
00126                 {
00127                         return ( mUserState == rhs.mUserState );
00128                 }
00129 
00130                 
00131         // Retrieval
00132         // ----------------------------------------------------------------------------
00133         public:
00134                 float getTotalCost () const
00135                 {
00136                         return mEntryCost + mEstimatedGoalCost;
00137                 }
00138 
00139                 // ----------------------------------------------------------------------------
00140                 float getEntryCost () const
00141                 {
00142                         return mEntryCost;
00143                 }
00144 
00145                 // ----------------------------------------------------------------------------
00146                 float getEstimatedGoalCost () const
00147                 {
00148                         return mEstimatedGoalCost;
00149                 }
00150 
00151                 // ----------------------------------------------------------------------------
00152                 float getTraversalCost () const
00153                 {
00154                         return mUserState->getTraversalCost();
00155                 }
00156 
00157                 // ----------------------------------------------------------------------------
00158                 AStarNode* getNeighbour ( int index )
00159                 {
00160                         assert ( index < 8 );
00161 
00162                         return mNeighbours[index];
00163                 }
00164 
00165 
00166         // Storing
00167         // ----------------------------------------------------------------------------
00168         public:
00169                 void setEntryCost ( float entryCost )
00170                 {
00171                         mEntryCost = entryCost;
00172                 }
00173 
00174                 // ----------------------------------------------------------------------------
00175                 void setEstimatedGoalCost ( float cost )
00176                 {
00177                         mEstimatedGoalCost = cost;
00178                 }
00179 
00180                 // ----------------------------------------------------------------------------
00181                 void setSuccessor ( AStarNode* parent )
00182                 {
00183                         mSuccessor = parent;
00184                 }
00185         };
00186 
00187 
00188 
00189 // Data Storage
00190 // ----------------------------------------------------------------------------
00191 public:
00192         enum SearchState
00193         {
00194                 SS_NOT_STARTED,
00195                 SS_RUNNING,
00196                 SS_PATH_NOT_FOUND,
00197                 SS_PATH_FOUND
00198         };
00199 
00200 private:
00201         SearchState                     mCurState;
00202         AStarNode*                      mStartNode;
00203         UserState*                      mEndState;
00204 
00205         AStarNode*                      mGoalNode;
00206 
00207         vector<AStarNode*>      mOpenList;
00208         vector<AStarNode*>      mClosedList;
00209 
00210         WaypointList<UserState>*        mSolutionList;
00211 
00212 
00213         // ----------------------------------------------------------------------------
00214         public:
00215                 class HeapCompare
00216                 {
00217                 public:
00218 
00219                         bool operator() ( const AStarNode* x, const AStarNode* y ) const
00220                         {
00221                                 /*
00222                                 if ( fabs( x->getTotalCost () - y->getTotalCost () ) < 0.01 )
00223                                         return true;
00224                                 */
00225 
00226                                 return x->getTotalCost() > y->getTotalCost();
00227                         }
00228                 };
00229 
00230 // Construction
00231 // ----------------------------------------------------------------------------
00232 public:
00233         AStarSearch(void)
00234         {
00235                 mCurState = SS_NOT_STARTED;
00236 
00237                 mGoalNode = NULL;
00238                 mStartNode = NULL;
00239 
00240                 mSolutionList = NULL;
00241 
00242                 mOpenList.clear ();
00243                 mClosedList.clear ();
00244         }
00245 
00246         // ----------------------------------------------------------------------------
00247         ~AStarSearch(void)
00248         {
00249                 _releaseNodes ();
00250                 _releaseSolution ();
00251         }
00252 
00253 
00254         // ----------------------------------------------------------------------------
00255         void _releaseSolution ()
00256         {
00257                 delete mSolutionList;
00258                 mSolutionList = NULL;
00259         }
00260 
00261 
00262 // Setup
00263 // ----------------------------------------------------------------------------
00264 public:
00265         void setStartState ( UserState* startState, UserState* endState )
00266         {
00267                 _releaseNodes ();
00268                 _releaseSolution ();
00269 
00270                 mStartNode = NULL;
00271                 mEndState = NULL;
00272 
00273                 if ( fabs(endState->getTraversalCost() + 1.0f) < 0.01f )
00274                 {
00275                         mCurState = SS_PATH_NOT_FOUND;
00276                         return;
00277                 }
00278 
00279                 mStartNode = new AStarNode ( startState );
00280                 mStartNode->createNeighbourNodes ();
00281                 mStartNode->setEstimatedGoalCost ( mStartNode->mUserState->getEstimatedCostToState ( *endState ) );
00282 
00283                 mEndState = endState;
00284 
00285                 mOpenList.push_back ( mStartNode );
00286         }
00287 
00288 
00289 // Stepping
00290 // ----------------------------------------------------------------------------
00291 public:
00292         SearchState getSearchState ()
00293         {
00294                 return mCurState;
00295         }
00296 
00300         bool isRunnable ()
00301         {
00302                 return (mCurState == SS_RUNNING || mCurState == SS_NOT_STARTED );
00303         }
00304 
00308         bool isSolved ()
00309         {
00310                 return (mCurState == SS_PATH_FOUND );
00311         }
00312 
00313 
00317         bool isUnsolvable ()
00318         {
00319                 return (mCurState == SS_PATH_NOT_FOUND );
00320         }
00321 
00322 
00326         bool advanceSearch ()
00327         {
00328                 if ( mCurState == SS_NOT_STARTED )
00329                         mCurState = SS_RUNNING;
00330 
00331                 if ( mCurState != SS_RUNNING )
00332                         throw exception ( "AStarSearch::singleStep - Search not in a runnable state" );
00333 
00334                 // We couldn't find a path
00335                 if ( mOpenList.empty () )
00336                 {
00337                         mCurState = SS_PATH_NOT_FOUND;
00338                 }
00339 
00340 
00341                 // Grab a node from the Open list
00342                 AStarNode* curNode = _findBestOpenNode ();
00343 
00344                 // We've reached the goal node so build our path
00345                 if ( curNode->mUserState->isNode ( *mEndState ) )
00346                 {
00347                         mCurState = SS_PATH_FOUND;
00348                         _constructionWaypointList ( curNode );
00349                         _releaseNodes ();
00350                         return true;
00351                 }
00352 
00353                 // Check all the neighbours for better matches and update any we find
00354                 for ( int i = 0; i < curNode->mUserState->getNumNeighbours(); i++ )
00355                 {
00356                         AStarNode* neighbour = curNode->getNeighbour ( i );
00357                         if ( neighbour == NULL )
00358                                 continue;
00359 
00360                         // This neighbour is inpassable
00361                         if ( fabs( neighbour->getTraversalCost () + 1.0f ) < 0.01f )
00362                                 continue;
00363 
00364                         // Determine if the node is on the closed list
00365                         AStarNode* closedNode = _getFromClosed ( neighbour );
00366                         if ( closedNode )
00367                         {
00368                                 // If a node was already on the closed list, then it has been
00369                                 // explored and has had it's neighbours explored. We may
00370                                 // have a better way "in to" the node so consider it again
00371                                 // with the new information
00372                                 float newEntryCost = curNode->getEntryCost() + neighbour->getTraversalCost();
00373                                 if ( newEntryCost < closedNode->getEntryCost() )
00374                                 {
00375                                         // Update the neighbour to use the current node as the entry point
00376                                         neighbour->setSuccessor ( curNode );
00377                                         neighbour->setEntryCost ( newEntryCost );
00378 
00379                                         // We need to re-examine this node
00380                                         _removeFromClosed ( closedNode );
00381                                         _addToOpen ( neighbour );
00382                                 }
00383                                 continue;
00384                         }
00385 
00386                         // The node was set to be explored but we have found a better way in
00387                         // so consider this way into the node
00388                         AStarNode* openNode = _getFromOpen ( neighbour );
00389                         if ( openNode )
00390                         {
00391                                 // If a node was already on the open list, it's estimated cost
00392                                 // to the goal has already been set
00393 
00394                                 float newEntryCost = curNode->getEntryCost() + neighbour->getTraversalCost();
00395                                 if ( newEntryCost < neighbour->getEntryCost() )
00396                                 {
00397                                         // Update the neighbour to use the current node as the entry point
00398                                         neighbour->setSuccessor ( curNode );
00399                                         neighbour->setEntryCost ( newEntryCost );
00400                                         
00401                                         _removeFromOpen ( openNode );
00402                                         _addToOpen ( neighbour );
00403                                 }
00404                                 continue;
00405                         }
00406 
00407                         // The neighbour state hasn't been considered before
00408                         // so set it up and add it to the open list for consideration
00409                         else
00410                         {
00411                                 neighbour->setSuccessor ( curNode );
00412                                 neighbour->setEntryCost ( curNode->getEntryCost() + neighbour->getTraversalCost () );
00413                                 neighbour->setEstimatedGoalCost( neighbour->mUserState->getEstimatedCostToState( *mEndState ));
00414                                 neighbour->createNeighbourNodes ();
00415                                 _addToOpen ( neighbour );
00416                         }
00417                 }
00418 
00419                 _addToClosed ( curNode );
00420 
00421                 return false;
00422         }
00423 
00424         // ----------------------------------------------------------------------------
00425         const WaypointList<UserState>*  getSolutionPath ()
00426         {
00427                 if ( !mSolutionList )
00428                         throw exception ( "AStarSearch::getSolutionPath - No solution exists" );
00429 
00430                 return mSolutionList;
00431         }
00432 
00433 
00434 // Helpers
00435 // ----------------------------------------------------------------------------
00436 private:
00437 
00443         AStarNode* _findBestOpenNode ()
00444         {
00445                 if ( mOpenList.empty () )
00446                         throw exception ( "AStarSearch::_findBestOpenNode - empty open list" );
00447 
00448                 
00449 #ifdef _USE_HEAP_
00450                 AStarNode* bestNode = mOpenList.front();
00451                 pop_heap ( mOpenList.begin(), mOpenList.end(), HeapCompare() );
00452                 mOpenList.pop_back();
00453 #else
00454                 // Search the list backwards until we implement a priority heap
00455                 AStarNode* bestNode = mOpenList.front ();
00456                 float bestCost = bestNode->getTotalCost ();
00457 
00458                 vector<AStarNode*>::iterator bestNodeIter = mOpenList.begin();
00459                 vector<AStarNode*>::iterator iter = mOpenList.begin ();
00460 
00461                 while ( iter != mOpenList.end () )
00462                 {
00463                         if ( (*iter)->getTotalCost() < bestCost )
00464                         {
00465                                 bestNode = (*iter);
00466                                 bestCost = bestNode->getTotalCost();
00467                                 bestNodeIter = iter;
00468                         }
00469 
00470                         iter++;
00471                 }
00472 
00473                 if ( bestNodeIter != mOpenList.end() )
00474                 {
00475                         mOpenList.erase ( bestNodeIter );
00476                 }
00477 #endif
00478 
00479 #if _DEBUG
00480                 //cout << "Best Open Node: (" << bestNode->mUserState->mX << ", " << bestNode->mUserState->mY << ")" << endl;
00481                 cout << "\tEntry Cost: " << bestNode->getEntryCost () << endl;
00482                 cout << "\tGoal Cost: " << bestNode->getEstimatedGoalCost () << endl;
00483 #endif
00484 
00485                 return bestNode;
00486         }
00487 
00488 
00489 
00490         // ----------------------------------------------------------------------------
00491         AStarNode* _getFromClosed ( AStarNode* node )
00492         {
00493                 vector<AStarNode*>::iterator iter = mClosedList.begin ();
00494                 while ( iter != mClosedList.end () )
00495                 {
00496                         if ( *(*iter) == (*node) )
00497                                 return *iter;
00498 
00499                         iter++;
00500                 }
00501 
00502                 return NULL;
00503         }
00504 
00505 
00506         // ----------------------------------------------------------------------------
00507         AStarNode* _getFromOpen ( AStarNode* node )
00508         {
00509                 vector<AStarNode*>::iterator iter = mOpenList.begin ();
00510                 while ( iter != mOpenList.end () )
00511                 {
00512                         if ( *(*iter) == (*node) )
00513                                 return *iter;
00514 
00515                         iter++;
00516                 }
00517 
00518                 return NULL;
00519         }
00520 
00521 
00522         // ----------------------------------------------------------------------------
00523         bool _isNodeClosed ( AStarNode* node )
00524         {
00525                 vector<AStarNode*>::iterator iter = mClosedList.begin ();
00526                 while ( iter != mClosedList.end () )
00527                 {
00528                         if ( *(*iter) == (*node) )
00529                                 return true;
00530 
00531                         iter++;
00532                 }
00533                 return false;
00534         }
00535 
00536         // ----------------------------------------------------------------------------
00537         bool _isNodeOpen ( AStarNode* node )
00538         {
00539                 vector<AStarNode*>::iterator iter = mClosedList.begin ();
00540                 while ( iter != mClosedList.end () )
00541                 {
00542                         if ( *(*iter) == (*node) )
00543                                 return true;
00544 
00545                         iter++;
00546                 }
00547                 return false;
00548         }
00549 
00550         // ----------------------------------------------------------------------------
00551         void _removeFromClosed ( AStarNode* node )
00552         {
00553                 vector<AStarNode*>::iterator iter = mClosedList.begin ();
00554                 while ( iter != mClosedList.end () )
00555                 {
00556                         if ( *(*iter) == (*node) )
00557                         {
00558                                 mClosedList.erase ( iter );
00559                                 return;
00560                         }
00561 
00562                         iter++;
00563                 }
00564 
00565                 throw exception ( "AStarSearch::_removeFromClosed - Node not found" );
00566         }
00567 
00568         // ----------------------------------------------------------------------------
00569         void _removeFromOpen ( AStarNode* node )
00570         {
00571                 vector<AStarNode*>::iterator iter = mOpenList.begin ();
00572                 while ( iter != mOpenList.end () )
00573                 {
00574                         if ( *(*iter) == (*node) )
00575                         {
00576                                 mOpenList.erase ( iter );
00577 #ifdef _USE_HEAP_
00578                                 make_heap ( mOpenList.begin(), mOpenList.end(), HeapCompare() );
00579 #endif
00580                                 return;
00581                         }
00582 
00583                         iter++;
00584                 }
00585 
00586                 throw exception ( "AStarSearch::_removeFromOpen - Node not found" );
00587         }
00588 
00589         // ----------------------------------------------------------------------------
00590         void _addToOpen ( AStarNode* node )
00591         {
00592                 mOpenList.push_back ( node );
00593 #ifdef _USE_HEAP_
00594                 push_heap ( mOpenList.begin(), mOpenList.end(), HeapCompare() );
00595 #endif
00596         }
00597 
00598         // ----------------------------------------------------------------------------
00599         void _addToClosed ( AStarNode* node )
00600         {
00601                 mClosedList.push_back ( node );
00602         }
00603 
00604         // ----------------------------------------------------------------------------
00618         void _constructionWaypointList ( AStarNode* endNode )
00619         {
00620                 mSolutionList = new WaypointList<UserState>();
00621 
00622                 AStarNode* curNode = endNode;
00623                 while ( curNode != NULL )
00624                 {
00625                         mSolutionList->addWaypointToFront ( curNode->mUserState );
00626                         curNode = curNode->mSuccessor;
00627                 }
00628         }
00629 
00630         // ----------------------------------------------------------------------------
00631         void _releaseNodes ()
00632         {
00633                 // All nodes in the search are part of a connected tree with
00634                 // mStartNode as the root
00635                 delete mStartNode;
00636                 mStartNode = NULL;
00637         }
00638 
00639 };
00640 
00641 template<typename UserState> int AStarSearch<UserState>::AStarNode::mNumNodes = 0;
00642 
00643 #endif

Generated on Sun Jun 25 19:23:43 2006 for Valors End by  doxygen 1.4.7