15 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
16 #define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
47 State(
const arma::colvec& data) : data(data)
51 arma::colvec&
Data() {
return data; }
64 double Angle()
const {
return data[2]; }
66 double&
Angle() {
return data[2]; }
74 const arma::colvec&
Encode()
const {
return data; }
109 const double massCart = 1.0,
110 const double massPole = 0.1,
111 const double length = 0.5,
112 const double forceMag = 10.0,
113 const double tau = 0.02,
114 const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
115 const double xThreshold = 2.4,
116 const double doneReward = 0.0) :
120 totalMass(massCart + massPole),
122 poleMassLength(massPole * length),
125 thetaThresholdRadians(thetaThresholdRadians),
126 xThreshold(xThreshold),
127 doneReward(doneReward)
141 State& nextState)
const
144 double force = action ? forceMag : -forceMag;
145 double cosTheta = std::cos(state.
Angle());
146 double sinTheta = std::sin(state.
Angle());
149 double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
150 (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
151 double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
184 return Sample(state, action, nextState);
194 return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
205 return std::abs(state.
Position()) > xThreshold ||
206 std::abs(state.
Angle()) > thetaThresholdRadians;
226 double poleMassLength;
235 double thetaThresholdRadians;
State(const arma::colvec &data)
Construct a state instance from given data.
double & Velocity()
Modify the velocity.
State()
Construct a state instance.
bool IsTerminal(const State &state) const
Whether given state is a terminal state.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Action
Implementation of action of Cart Pole.
State InitialSample() const
The initial state is randomly generated: each of its four components is drawn uniformly from [-0.05, 0.05].
double Sample(const State &state, const Action &action) const
Dynamics of Cart Pole.
double Velocity() const
Get the velocity.
Implementation of the state of Cart Pole.
CartPole(const double gravity=9.8, const double massCart=1.0, const double massPole=0.1, const double length=0.5, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=12 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0)
Construct a Cart Pole instance using the given constants.
const arma::colvec & Encode() const
Encode the state to a column vector.
double & Angle()
Modify the angle.
double & Position()
Modify the position.
double Sample(const State &state, const Action &action, State &nextState) const
Dynamics of Cart Pole instance.
double & AngularVelocity()
Modify the angular velocity.
arma::colvec & Data()
Modify the internal representation of the state.
static constexpr size_t dimension
Dimension of the encoded state.
Implementation of Cart Pole task.
double AngularVelocity() const
Get the angular velocity.
double Angle() const
Get the angle.
double Position() const
Get the position.