Refactored expressions thanks to intrusive pointers.

This commit is contained in:
Patrick Lühne 2016-09-04 18:48:46 +02:00
parent 9afabacde3
commit 7aa20a5820
15 changed files with 56 additions and 111 deletions

View File

@ -34,17 +34,15 @@ class At: public ExpressionCRTP<At>
size_t timePoint() const; size_t timePoint() const;
void setArgument(const Expression *argument); void setArgument(ExpressionPointer argument);
void setArgument(ExpressionPointer &&argument); ExpressionPointer argument() const;
const Expression *argument() const;
ExpressionPointer normalize() override; ExpressionPointer normalize() override;
protected: protected:
size_t m_timePoint; size_t m_timePoint;
const Expression *m_argument; ExpressionPointer m_argument;
ExpressionPointer m_argumentStorage;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -28,17 +28,13 @@ class Binary: public ExpressionCRTP<Derived>
ExpressionContext &expressionContext, ExpressionParser parseExpression); ExpressionContext &expressionContext, ExpressionParser parseExpression);
public: public:
template<size_t i> void setArgument(size_t i, ExpressionPointer argument);
void setArgument(const Expression *argument); const std::array<ExpressionPointer, 2> &arguments() const;
template<size_t i>
void setArgument(ExpressionPointer &&argument);
const std::array<const Expression *, 2> &arguments() const;
ExpressionPointer normalize() override; ExpressionPointer normalize() override;
protected: protected:
std::array<const Expression *, 2> m_arguments; std::array<ExpressionPointer, 2> m_arguments;
std::array<ExpressionPointer, 2> m_argumentStorage;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -63,8 +59,8 @@ boost::intrusive_ptr<Derived> Binary<Derived>::parse(Context &context,
// Assume that expression identifier (imply, exists, etc.) is already parsed // Assume that expression identifier (imply, exists, etc.) is already parsed
// Parse arguments of the expression // Parse arguments of the expression
expression->Binary<Derived>::setArgument<0>(parseExpression(context, expressionContext)); expression->Binary<Derived>::setArgument(0, parseExpression(context, expressionContext));
expression->Binary<Derived>::setArgument<1>(parseExpression(context, expressionContext)); expression->Binary<Derived>::setArgument(1, parseExpression(context, expressionContext));
parser.expect<std::string>(")"); parser.expect<std::string>(")");
@ -74,31 +70,17 @@ boost::intrusive_ptr<Derived> Binary<Derived>::parse(Context &context,
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived> template<class Derived>
template<size_t i> void Binary<Derived>::setArgument(size_t i, ExpressionPointer expression)
void Binary<Derived>::setArgument(const Expression *expression)
{ {
static_assert(i <= 2, "Index out of range"); BOOST_ASSERT_MSG(i <= 2, "Index out of range");
m_argumentStorage[i] = nullptr;
m_arguments[i] = expression; m_arguments[i] = expression;
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived> template<class Derived>
template<size_t i> const std::array<ExpressionPointer, 2> &Binary<Derived>::arguments() const
void Binary<Derived>::setArgument(ExpressionPointer &&expression)
{
static_assert(i <= 2, "Index out of range");
m_argumentStorage[i] = std::move(expression);
m_arguments[i] = m_argumentStorage[i].get();
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived>
const std::array<const Expression *, 2> &Binary<Derived>::arguments() const
{ {
return m_arguments; return m_arguments;
} }
@ -108,18 +90,17 @@ const std::array<const Expression *, 2> &Binary<Derived>::arguments() const
template<class Derived> template<class Derived>
inline ExpressionPointer Binary<Derived>::normalize() inline ExpressionPointer Binary<Derived>::normalize()
{ {
for (size_t i = 0; i < m_argumentStorage.size(); i++) for (size_t i = 0; i < m_arguments.size(); i++)
{ {
BOOST_ASSERT(m_argumentStorage[i]); BOOST_ASSERT(m_arguments[i]);
auto normalizedArgument = m_argumentStorage[i]->normalize(); auto normalizedArgument = m_arguments[i]->normalize();
// Replace argument if changed by normalization // Replace argument if changed by normalization
if (!normalizedArgument) if (!normalizedArgument)
continue; continue;
m_argumentStorage[i] = std::move(normalizedArgument); m_arguments[i] = std::move(normalizedArgument);
m_arguments[i] = m_argumentStorage[i].get();
} }
return nullptr; return nullptr;

View File

@ -28,15 +28,13 @@ class NAry: public ExpressionCRTP<Derived>
ExpressionContext &expressionContext, ExpressionParser parseExpression); ExpressionContext &expressionContext, ExpressionParser parseExpression);
public: public:
void addArgument(const Expression *argument); void addArgument(ExpressionPointer argument);
void addArgument(ExpressionPointer &&argument); const Expressions &arguments() const;
const std::vector<const Expression *> &arguments() const;
ExpressionPointer normalize() override; ExpressionPointer normalize() override;
protected: protected:
std::vector<const Expression *> m_arguments; Expressions m_arguments;
Expressions m_argumentStorage;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -81,7 +79,7 @@ boost::intrusive_ptr<Derived> NAry<Derived>::parse(Context &context,
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived> template<class Derived>
void NAry<Derived>::addArgument(const Expression *argument) void NAry<Derived>::addArgument(ExpressionPointer argument)
{ {
if (!argument) if (!argument)
return; return;
@ -92,19 +90,7 @@ void NAry<Derived>::addArgument(const Expression *argument)
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived> template<class Derived>
void NAry<Derived>::addArgument(ExpressionPointer &&argument) const Expressions &NAry<Derived>::arguments() const
{
if (!argument)
return;
m_argumentStorage.emplace_back(std::move(argument));
m_arguments.emplace_back(m_argumentStorage.back().get());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived>
const std::vector<const Expression *> &NAry<Derived>::arguments() const
{ {
return m_arguments; return m_arguments;
} }
@ -114,18 +100,17 @@ const std::vector<const Expression *> &NAry<Derived>::arguments() const
template<class Derived> template<class Derived>
inline ExpressionPointer NAry<Derived>::normalize() inline ExpressionPointer NAry<Derived>::normalize()
{ {
for (size_t i = 0; i < m_argumentStorage.size(); i++) for (size_t i = 0; i < m_arguments.size(); i++)
{ {
BOOST_ASSERT(m_argumentStorage[i]); BOOST_ASSERT(m_arguments[i]);
auto normalizedArgument = m_argumentStorage[i]->normalize(); auto normalizedArgument = m_arguments[i]->normalize();
// Replace argument if changed by normalization // Replace argument if changed by normalization
if (!normalizedArgument) if (!normalizedArgument)
continue; continue;
m_argumentStorage[i] = std::move(normalizedArgument); m_arguments[i] = std::move(normalizedArgument);
m_arguments[i] = m_argumentStorage[i].get();
} }
return nullptr; return nullptr;

View File

@ -29,15 +29,13 @@ class Not: public ExpressionCRTP<Not>
public: public:
Not(); Not();
void setArgument(const Expression *argument); void setArgument(ExpressionPointer argument);
void setArgument(ExpressionPointer &&argument); ExpressionPointer argument() const;
const Expression *argument() const;
ExpressionPointer normalize() override; ExpressionPointer normalize() override;
protected: protected:
const Expression *m_argument; ExpressionPointer m_argument;
ExpressionPointer m_argumentStorage;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -26,7 +26,7 @@ class Predicate: public ExpressionCRTP<Predicate>
public: public:
const std::string &name() const; const std::string &name() const;
const std::vector<ExpressionPointer> &arguments() const; const Expressions &arguments() const;
bool isDeclared() const; bool isDeclared() const;
@ -40,7 +40,7 @@ class Predicate: public ExpressionCRTP<Predicate>
bool m_isDeclared; bool m_isDeclared;
std::string m_name; std::string m_name;
std::vector<ExpressionPointer> m_arguments; Expressions m_arguments;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -33,7 +33,7 @@ class PrimitiveType: public ExpressionCRTP<PrimitiveType>
PrimitiveType(std::string name); PrimitiveType(std::string name);
const std::string &name() const; const std::string &name() const;
const std::vector<PrimitiveTypePointer> &parentTypes() const; const PrimitiveTypes &parentTypes() const;
ExpressionPointer normalize() override; ExpressionPointer normalize() override;
@ -45,7 +45,7 @@ class PrimitiveType: public ExpressionCRTP<PrimitiveType>
std::string m_name; std::string m_name;
std::vector<PrimitiveTypePointer> m_parentTypes; PrimitiveTypes m_parentTypes;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -16,7 +16,7 @@ namespace expressions
// //
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer parseExistingPrimitiveType(Context &context, PrimitiveTypePointer parseExistingPrimitiveType(Context &context,
ExpressionContext &expressionContext); ExpressionContext &expressionContext);
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -261,7 +261,7 @@ void TranslatorASP::translateActions() const
const auto &andExpression = dynamic_cast<const expressions::And &>(precondition); const auto &andExpression = dynamic_cast<const expressions::And &>(precondition);
std::for_each(andExpression.arguments().cbegin(), andExpression.arguments().cend(), std::for_each(andExpression.arguments().cbegin(), andExpression.arguments().cend(),
[&](const auto *argument) [&](const auto argument)
{ {
translateLiteral("precondition", *argument); translateLiteral("precondition", *argument);
}); });
@ -287,7 +287,7 @@ void TranslatorASP::translateActions() const
const auto &andExpression = dynamic_cast<const expressions::And &>(effect); const auto &andExpression = dynamic_cast<const expressions::And &>(effect);
std::for_each(andExpression.arguments().cbegin(), andExpression.arguments().cend(), std::for_each(andExpression.arguments().cbegin(), andExpression.arguments().cend(),
[&](const auto *argument) [&](const auto argument)
{ {
translateLiteral("postcondition", *argument, true); translateLiteral("postcondition", *argument, true);
}); });
@ -557,7 +557,7 @@ void TranslatorASP::translateGoal() const
const auto &andExpression = dynamic_cast<const expressions::And &>(goal); const auto &andExpression = dynamic_cast<const expressions::And &>(goal);
std::for_each(andExpression.arguments().cbegin(), andExpression.arguments().cend(), std::for_each(andExpression.arguments().cbegin(), andExpression.arguments().cend(),
[&](const auto *argument) [&](const auto argument)
{ {
m_outputStream << std::endl << utils::RuleName("goal") << "("; m_outputStream << std::endl << utils::RuleName("goal") << "(";

View File

@ -20,23 +20,14 @@ At::At()
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
void At::setArgument(const Expression *argument) void At::setArgument(ExpressionPointer argument)
{ {
m_argumentStorage = nullptr;
m_argument = argument; m_argument = argument;
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
void At::setArgument(ExpressionPointer &&argument) ExpressionPointer At::argument() const
{
m_argumentStorage = std::move(argument);
m_argument = m_argumentStorage.get();
}
////////////////////////////////////////////////////////////////////////////////////////////////////
const Expression *At::argument() const
{ {
return m_argument; return m_argument;
} }
@ -45,9 +36,9 @@ const Expression *At::argument() const
ExpressionPointer At::normalize() ExpressionPointer At::normalize()
{ {
BOOST_ASSERT(m_argumentStorage); BOOST_ASSERT(m_argument);
auto normalizedArgument = m_argumentStorage->normalize(); auto normalizedArgument = m_argument->normalize();
// Replace argument if changed by normalization // Replace argument if changed by normalization
if (normalizedArgument) if (normalizedArgument)

View File

@ -23,15 +23,15 @@ const std::string Imply::Identifier = "imply";
ExpressionPointer Imply::normalize() ExpressionPointer Imply::normalize()
{ {
BOOST_ASSERT(m_argumentStorage[0]); BOOST_ASSERT(m_arguments[0]);
BOOST_ASSERT(m_argumentStorage[1]); BOOST_ASSERT(m_arguments[1]);
auto notArgument0 = NotPointer(new Not); auto notArgument0 = NotPointer(new Not);
notArgument0->setArgument(std::move(m_argumentStorage[0])); notArgument0->setArgument(std::move(m_arguments[0]));
auto orExpression = OrPointer(new Or); auto orExpression = OrPointer(new Or);
orExpression->addArgument(std::move(notArgument0)); orExpression->addArgument(std::move(notArgument0));
orExpression->addArgument(std::move(m_argumentStorage[1])); orExpression->addArgument(std::move(m_arguments[1]));
auto normalizedOrExpression = orExpression->normalize(); auto normalizedOrExpression = orExpression->normalize();

View File

@ -20,23 +20,14 @@ Not::Not()
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
void Not::setArgument(const Expression *argument) void Not::setArgument(ExpressionPointer argument)
{ {
m_argumentStorage = nullptr;
m_argument = argument; m_argument = argument;
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
void Not::setArgument(ExpressionPointer &&argument) ExpressionPointer Not::argument() const
{
m_argumentStorage = std::move(argument);
m_argument = m_argumentStorage.get();
}
////////////////////////////////////////////////////////////////////////////////////////////////////
const Expression *Not::argument() const
{ {
return m_argument; return m_argument;
} }
@ -45,14 +36,14 @@ const Expression *Not::argument() const
ExpressionPointer Not::normalize() ExpressionPointer Not::normalize()
{ {
BOOST_ASSERT(m_argumentStorage); BOOST_ASSERT(m_argument);
// Remove double negations immediately // Remove double negations immediately
if (m_argumentStorage->expressionType() == Expression::Type::Not) if (m_argument->expressionType() == Expression::Type::Not)
{ {
auto &argument = dynamic_cast<Not &>(*m_argumentStorage); auto &argument = dynamic_cast<Not &>(*m_argument);
auto normalized = std::move(argument.m_argumentStorage); auto normalized = std::move(argument.m_argument);
auto normalizedInner = normalized->normalize(); auto normalizedInner = normalized->normalize();
if (normalizedInner) if (normalizedInner)
@ -61,7 +52,7 @@ ExpressionPointer Not::normalize()
return normalized; return normalized;
} }
auto normalizedArgument = m_argumentStorage->normalize(); auto normalizedArgument = m_argument->normalize();
// Replace argument if changed by normalization // Replace argument if changed by normalization
if (normalizedArgument) if (normalizedArgument)

View File

@ -163,7 +163,7 @@ const std::string &Predicate::name() const
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
const std::vector<ExpressionPointer> &Predicate::arguments() const const Expressions &Predicate::arguments() const
{ {
return m_arguments; return m_arguments;
} }

View File

@ -162,7 +162,7 @@ const std::string &PrimitiveType::name() const
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
const std::vector<PrimitiveTypePointer> &PrimitiveType::parentTypes() const const PrimitiveTypes &PrimitiveType::parentTypes() const
{ {
return m_parentTypes; return m_parentTypes;
} }

View File

@ -17,7 +17,8 @@ namespace expressions
// //
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer parseExistingPrimitiveType(Context &context, ExpressionContext &expressionContext) PrimitiveTypePointer parseExistingPrimitiveType(Context &context,
ExpressionContext &expressionContext)
{ {
return PrimitiveType::parseAndFind(context, expressionContext.domain); return PrimitiveType::parseAndFind(context, expressionContext.domain);
} }

View File

@ -17,8 +17,8 @@ TEST(PDDLNormalizationTests, Implication)
auto d2 = expressions::DummyPointer(new expressions::Dummy); auto d2 = expressions::DummyPointer(new expressions::Dummy);
const auto d2p = d2.get(); const auto d2p = d2.get();
i->setArgument<0>(d1); i->setArgument(0, d1);
i->setArgument<1>(d2); i->setArgument(1, d2);
auto normalized = i->normalize(); auto normalized = i->normalize();