00001 #ifndef __ASTAR_SEARCH_H__
00002 #define __ASTAR_SEARCH_H__
00003
00004 #pragma once
00005
00006
00007
#include "WaypointList.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <exception>
#include <iostream>
#include <stdexcept>
#include <vector>

using std::cout;
using std::endl;

using std::vector;
using std::exception;
using std::find;
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050 template <class UserState>
00051 class AStarSearch
00052 {
00053
00054 class AStarNode
00055 {
00056 public:
00057
00058
00059 UserState* mUserState;
00060
00061
00062 AStarNode* mSuccessor;
00063
00064
00065
00066 private:
00067 float mEntryCost;
00068 float mEstimatedGoalCost;
00069
00070
00071
00072 AStarNode* mNeighbours[8];
00073
00074 static int mNumNodes;
00075 unsigned int mID;
00076
00077
00078
00079
00080 public:
00081 AStarNode ( UserState* state )
00082 : mSuccessor ( NULL ), mUserState ( state),
00083 mEntryCost ( 0.0 ), mEstimatedGoalCost ( 0.0 )
00084 {
00085 mID = mNumNodes++;
00086
00087
00088 for ( int i = 0; i < mUserState->getNumNeighbours(); i++ )
00089 mNeighbours[i] = NULL;
00090
00091
00092 }
00093
00094
00095
00096 ~AStarNode ()
00097 {
00098
00099
00100 if ( mUserState == NULL )
00101 return;
00102
00103
00104 for ( int i = 0; i < mUserState->getNumNeighbours(); i++ )
00105 delete mNeighbours[i];
00106 }
00107
00108
00109
00110 void createNeighbourNodes ()
00111 {
00112 for ( int i = 0; i < mUserState->getNumNeighbours(); i++ )
00113 mNeighbours[i] = new AStarNode( mUserState->getNeighbour(i) );
00114 }
00115
00116
00117
00118 bool operator== ( AStarNode& rhs ) const
00119 {
00120 return ( mUserState == rhs.mUserState );
00121 }
00122
00123
00124
00125 bool operator== ( const AStarNode& rhs ) const
00126 {
00127 return ( mUserState == rhs.mUserState );
00128 }
00129
00130
00131
00132
00133 public:
00134 float getTotalCost () const
00135 {
00136 return mEntryCost + mEstimatedGoalCost;
00137 }
00138
00139
00140 float getEntryCost () const
00141 {
00142 return mEntryCost;
00143 }
00144
00145
00146 float getEstimatedGoalCost () const
00147 {
00148 return mEstimatedGoalCost;
00149 }
00150
00151
00152 float getTraversalCost () const
00153 {
00154 return mUserState->getTraversalCost();
00155 }
00156
00157
00158 AStarNode* getNeighbour ( int index )
00159 {
00160 assert ( index < 8 );
00161
00162 return mNeighbours[index];
00163 }
00164
00165
00166
00167
00168 public:
00169 void setEntryCost ( float entryCost )
00170 {
00171 mEntryCost = entryCost;
00172 }
00173
00174
00175 void setEstimatedGoalCost ( float cost )
00176 {
00177 mEstimatedGoalCost = cost;
00178 }
00179
00180
00181 void setSuccessor ( AStarNode* parent )
00182 {
00183 mSuccessor = parent;
00184 }
00185 };
00186
00187
00188
00189
00190
00191 public:
00192 enum SearchState
00193 {
00194 SS_NOT_STARTED,
00195 SS_RUNNING,
00196 SS_PATH_NOT_FOUND,
00197 SS_PATH_FOUND
00198 };
00199
00200 private:
00201 SearchState mCurState;
00202 AStarNode* mStartNode;
00203 UserState* mEndState;
00204
00205 AStarNode* mGoalNode;
00206
00207 vector<AStarNode*> mOpenList;
00208 vector<AStarNode*> mClosedList;
00209
00210 WaypointList<UserState>* mSolutionList;
00211
00212
00213
00214 public:
00215 class HeapCompare
00216 {
00217 public:
00218
00219 bool operator() ( const AStarNode* x, const AStarNode* y ) const
00220 {
00221
00222
00223
00224
00225
00226 return x->getTotalCost() > y->getTotalCost();
00227 }
00228 };
00229
00230
00231
00232 public:
00233 AStarSearch(void)
00234 {
00235 mCurState = SS_NOT_STARTED;
00236
00237 mGoalNode = NULL;
00238 mStartNode = NULL;
00239
00240 mSolutionList = NULL;
00241
00242 mOpenList.clear ();
00243 mClosedList.clear ();
00244 }
00245
00246
00247 ~AStarSearch(void)
00248 {
00249 _releaseNodes ();
00250 _releaseSolution ();
00251 }
00252
00253
00254
00255 void _releaseSolution ()
00256 {
00257 delete mSolutionList;
00258 mSolutionList = NULL;
00259 }
00260
00261
00262
00263
00264 public:
00265 void setStartState ( UserState* startState, UserState* endState )
00266 {
00267 _releaseNodes ();
00268 _releaseSolution ();
00269
00270 mStartNode = NULL;
00271 mEndState = NULL;
00272
00273 if ( fabs(endState->getTraversalCost() + 1.0f) < 0.01f )
00274 {
00275 mCurState = SS_PATH_NOT_FOUND;
00276 return;
00277 }
00278
00279 mStartNode = new AStarNode ( startState );
00280 mStartNode->createNeighbourNodes ();
00281 mStartNode->setEstimatedGoalCost ( mStartNode->mUserState->getEstimatedCostToState ( *endState ) );
00282
00283 mEndState = endState;
00284
00285 mOpenList.push_back ( mStartNode );
00286 }
00287
00288
00289
00290
00291 public:
00292 SearchState getSearchState ()
00293 {
00294 return mCurState;
00295 }
00296
00300 bool isRunnable ()
00301 {
00302 return (mCurState == SS_RUNNING || mCurState == SS_NOT_STARTED );
00303 }
00304
00308 bool isSolved ()
00309 {
00310 return (mCurState == SS_PATH_FOUND );
00311 }
00312
00313
00317 bool isUnsolvable ()
00318 {
00319 return (mCurState == SS_PATH_NOT_FOUND );
00320 }
00321
00322
00326 bool advanceSearch ()
00327 {
00328 if ( mCurState == SS_NOT_STARTED )
00329 mCurState = SS_RUNNING;
00330
00331 if ( mCurState != SS_RUNNING )
00332 throw exception ( "AStarSearch::singleStep - Search not in a runnable state" );
00333
00334
00335 if ( mOpenList.empty () )
00336 {
00337 mCurState = SS_PATH_NOT_FOUND;
00338 }
00339
00340
00341
00342 AStarNode* curNode = _findBestOpenNode ();
00343
00344
00345 if ( curNode->mUserState->isNode ( *mEndState ) )
00346 {
00347 mCurState = SS_PATH_FOUND;
00348 _constructionWaypointList ( curNode );
00349 _releaseNodes ();
00350 return true;
00351 }
00352
00353
00354 for ( int i = 0; i < curNode->mUserState->getNumNeighbours(); i++ )
00355 {
00356 AStarNode* neighbour = curNode->getNeighbour ( i );
00357 if ( neighbour == NULL )
00358 continue;
00359
00360
00361 if ( fabs( neighbour->getTraversalCost () + 1.0f ) < 0.01f )
00362 continue;
00363
00364
00365 AStarNode* closedNode = _getFromClosed ( neighbour );
00366 if ( closedNode )
00367 {
00368
00369
00370
00371
00372 float newEntryCost = curNode->getEntryCost() + neighbour->getTraversalCost();
00373 if ( newEntryCost < closedNode->getEntryCost() )
00374 {
00375
00376 neighbour->setSuccessor ( curNode );
00377 neighbour->setEntryCost ( newEntryCost );
00378
00379
00380 _removeFromClosed ( closedNode );
00381 _addToOpen ( neighbour );
00382 }
00383 continue;
00384 }
00385
00386
00387
00388 AStarNode* openNode = _getFromOpen ( neighbour );
00389 if ( openNode )
00390 {
00391
00392
00393
00394 float newEntryCost = curNode->getEntryCost() + neighbour->getTraversalCost();
00395 if ( newEntryCost < neighbour->getEntryCost() )
00396 {
00397
00398 neighbour->setSuccessor ( curNode );
00399 neighbour->setEntryCost ( newEntryCost );
00400
00401 _removeFromOpen ( openNode );
00402 _addToOpen ( neighbour );
00403 }
00404 continue;
00405 }
00406
00407
00408
00409 else
00410 {
00411 neighbour->setSuccessor ( curNode );
00412 neighbour->setEntryCost ( curNode->getEntryCost() + neighbour->getTraversalCost () );
00413 neighbour->setEstimatedGoalCost( neighbour->mUserState->getEstimatedCostToState( *mEndState ));
00414 neighbour->createNeighbourNodes ();
00415 _addToOpen ( neighbour );
00416 }
00417 }
00418
00419 _addToClosed ( curNode );
00420
00421 return false;
00422 }
00423
00424
00425 const WaypointList<UserState>* getSolutionPath ()
00426 {
00427 if ( !mSolutionList )
00428 throw exception ( "AStarSearch::getSolutionPath - No solution exists" );
00429
00430 return mSolutionList;
00431 }
00432
00433
00434
00435
00436 private:
00437
00443 AStarNode* _findBestOpenNode ()
00444 {
00445 if ( mOpenList.empty () )
00446 throw exception ( "AStarSearch::_findBestOpenNode - empty open list" );
00447
00448
00449 #ifdef _USE_HEAP_
00450 AStarNode* bestNode = mOpenList.front();
00451 pop_heap ( mOpenList.begin(), mOpenList.end(), HeapCompare() );
00452 mOpenList.pop_back();
00453 #else
00454
00455 AStarNode* bestNode = mOpenList.front ();
00456 float bestCost = bestNode->getTotalCost ();
00457
00458 vector<AStarNode*>::iterator bestNodeIter = mOpenList.begin();
00459 vector<AStarNode*>::iterator iter = mOpenList.begin ();
00460
00461 while ( iter != mOpenList.end () )
00462 {
00463 if ( (*iter)->getTotalCost() < bestCost )
00464 {
00465 bestNode = (*iter);
00466 bestCost = bestNode->getTotalCost();
00467 bestNodeIter = iter;
00468 }
00469
00470 iter++;
00471 }
00472
00473 if ( bestNodeIter != mOpenList.end() )
00474 {
00475 mOpenList.erase ( bestNodeIter );
00476 }
00477 #endif
00478
00479 #if _DEBUG
00480
00481 cout << "\tEntry Cost: " << bestNode->getEntryCost () << endl;
00482 cout << "\tGoal Cost: " << bestNode->getEstimatedGoalCost () << endl;
00483 #endif
00484
00485 return bestNode;
00486 }
00487
00488
00489
00490
00491 AStarNode* _getFromClosed ( AStarNode* node )
00492 {
00493 vector<AStarNode*>::iterator iter = mClosedList.begin ();
00494 while ( iter != mClosedList.end () )
00495 {
00496 if ( *(*iter) == (*node) )
00497 return *iter;
00498
00499 iter++;
00500 }
00501
00502 return NULL;
00503 }
00504
00505
00506
00507 AStarNode* _getFromOpen ( AStarNode* node )
00508 {
00509 vector<AStarNode*>::iterator iter = mOpenList.begin ();
00510 while ( iter != mOpenList.end () )
00511 {
00512 if ( *(*iter) == (*node) )
00513 return *iter;
00514
00515 iter++;
00516 }
00517
00518 return NULL;
00519 }
00520
00521
00522
00523 bool _isNodeClosed ( AStarNode* node )
00524 {
00525 vector<AStarNode*>::iterator iter = mClosedList.begin ();
00526 while ( iter != mClosedList.end () )
00527 {
00528 if ( *(*iter) == (*node) )
00529 return true;
00530
00531 iter++;
00532 }
00533 return false;
00534 }
00535
00536
00537 bool _isNodeOpen ( AStarNode* node )
00538 {
00539 vector<AStarNode*>::iterator iter = mClosedList.begin ();
00540 while ( iter != mClosedList.end () )
00541 {
00542 if ( *(*iter) == (*node) )
00543 return true;
00544
00545 iter++;
00546 }
00547 return false;
00548 }
00549
00550
00551 void _removeFromClosed ( AStarNode* node )
00552 {
00553 vector<AStarNode*>::iterator iter = mClosedList.begin ();
00554 while ( iter != mClosedList.end () )
00555 {
00556 if ( *(*iter) == (*node) )
00557 {
00558 mClosedList.erase ( iter );
00559 return;
00560 }
00561
00562 iter++;
00563 }
00564
00565 throw exception ( "AStarSearch::_removeFromClosed - Node not found" );
00566 }
00567
00568
00569 void _removeFromOpen ( AStarNode* node )
00570 {
00571 vector<AStarNode*>::iterator iter = mOpenList.begin ();
00572 while ( iter != mOpenList.end () )
00573 {
00574 if ( *(*iter) == (*node) )
00575 {
00576 mOpenList.erase ( iter );
00577 #ifdef _USE_HEAP_
00578 make_heap ( mOpenList.begin(), mOpenList.end(), HeapCompare() );
00579 #endif
00580 return;
00581 }
00582
00583 iter++;
00584 }
00585
00586 throw exception ( "AStarSearch::_removeFromOpen - Node not found" );
00587 }
00588
00589
00590 void _addToOpen ( AStarNode* node )
00591 {
00592 mOpenList.push_back ( node );
00593 #ifdef _USE_HEAP_
00594 push_heap ( mOpenList.begin(), mOpenList.end(), HeapCompare() );
00595 #endif
00596 }
00597
00598
00599 void _addToClosed ( AStarNode* node )
00600 {
00601 mClosedList.push_back ( node );
00602 }
00603
00604
00618 void _constructionWaypointList ( AStarNode* endNode )
00619 {
00620 mSolutionList = new WaypointList<UserState>();
00621
00622 AStarNode* curNode = endNode;
00623 while ( curNode != NULL )
00624 {
00625 mSolutionList->addWaypointToFront ( curNode->mUserState );
00626 curNode = curNode->mSuccessor;
00627 }
00628 }
00629
00630
00631 void _releaseNodes ()
00632 {
00633
00634
00635 delete mStartNode;
00636 mStartNode = NULL;
00637 }
00638
00639 };
00640
00641 template<typename UserState> int AStarSearch<UserState>::AStarNode::mNumNodes = 0;
00642
00643 #endif