diff --git a/include/plasp/pddl/Expression.h b/include/plasp/pddl/Expression.h index 52d0919..8eeaa22 100644 --- a/include/plasp/pddl/Expression.h +++ b/include/plasp/pddl/Expression.h @@ -120,11 +120,14 @@ class Expression virtual Type expressionType() const = 0; + virtual ExpressionPointer copy(); + ExpressionPointer normalized(); virtual ExpressionPointer reduced(); virtual ExpressionPointer negationNormalized(); virtual ExpressionPointer prenex(Expression::Type lastQuantifierType = Expression::Type::Exists); virtual ExpressionPointer simplified(); + virtual ExpressionPointer disjunctionNormalized(); ExpressionPointer negated(); virtual void print(std::ostream &ostream) const = 0; diff --git a/include/plasp/pddl/expressions/And.h b/include/plasp/pddl/expressions/And.h index 358e80c..505a5f0 100644 --- a/include/plasp/pddl/expressions/And.h +++ b/include/plasp/pddl/expressions/And.h @@ -22,6 +22,9 @@ class And: public NAry static const Expression::Type ExpressionType = Expression::Type::And; static const std::string Identifier; + + public: + ExpressionPointer disjunctionNormalized() override; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/plasp/pddl/expressions/At.h b/include/plasp/pddl/expressions/At.h index 3996f1c..1983ba8 100644 --- a/include/plasp/pddl/expressions/At.h +++ b/include/plasp/pddl/expressions/At.h @@ -32,6 +32,8 @@ class At: public ExpressionCRTP public: At(); + ExpressionPointer copy() override; + size_t timePoint() const; void setArgument(ExpressionPointer argument); @@ -41,6 +43,7 @@ class At: public ExpressionCRTP ExpressionPointer negationNormalized() override; ExpressionPointer prenex(Expression::Type lastExpressionType) override; ExpressionPointer simplified() override; + ExpressionPointer disjunctionNormalized() override; void print(std::ostream &ostream) const override; diff --git a/include/plasp/pddl/expressions/Binary.h b/include/plasp/pddl/expressions/Binary.h index 5483518..fc6b8a2 100644 --- a/include/plasp/pddl/expressions/Binary.h +++ b/include/plasp/pddl/expressions/Binary.h @@ -29,12 +29,16 @@ class Binary: public ExpressionCRTP ExpressionContext &expressionContext, ExpressionParser parseExpression); public: + ExpressionPointer copy() override; + void setArgument(size_t i, ExpressionPointer argument); const std::array &arguments() const; ExpressionPointer reduced() override; ExpressionPointer negationNormalized() override; ExpressionPointer prenex(Expression::Type lastExpressionType) override; + ExpressionPointer simplified() override; + ExpressionPointer disjunctionNormalized() override; void print(std::ostream &ostream) const override; @@ -73,6 +77,19 @@ boost::intrusive_ptr Binary::parse(Context &context, //////////////////////////////////////////////////////////////////////////////////////////////////// +template +ExpressionPointer Binary::copy() +{ + auto result = new Derived; + + for (size_t i = 0; i < m_arguments.size(); i++) + result->m_arguments[i] = m_arguments[i]->copy(); + + return result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template void Binary::setArgument(size_t i, ExpressionPointer expression) { @@ -130,6 +147,36 @@ inline ExpressionPointer Binary::prenex(Expression::Type) //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline ExpressionPointer Binary::simplified() +{ + for (size_t i = 0; i < m_arguments.size(); i++) + { + BOOST_ASSERT(m_arguments[i]); + + m_arguments[i] = m_arguments[i]->simplified(); + } + + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline ExpressionPointer Binary::disjunctionNormalized() +{ + for (size_t i = 0; i < m_arguments.size(); i++) + { + BOOST_ASSERT(m_arguments[i]); + + m_arguments[i] = m_arguments[i]->disjunctionNormalized(); + } + + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template inline void Binary::print(std::ostream &ostream) const { diff --git a/include/plasp/pddl/expressions/NAry.h b/include/plasp/pddl/expressions/NAry.h index 2c289e1..db6420a 100644 --- a/include/plasp/pddl/expressions/NAry.h +++ b/include/plasp/pddl/expressions/NAry.h @@ -29,6 +29,8 @@ class NAry: public ExpressionCRTP ExpressionContext &expressionContext, ExpressionParser parseExpression); public: + ExpressionPointer copy() override; + void setArgument(size_t i, ExpressionPointer argument); void addArgument(ExpressionPointer argument); Expressions &arguments(); @@ -38,6 +40,7 @@ class NAry: public ExpressionCRTP ExpressionPointer negationNormalized() override; ExpressionPointer prenex(Expression::Type lastExpressionType) override; ExpressionPointer simplified() override; + ExpressionPointer disjunctionNormalized() override; void print(std::ostream &ostream) const override; @@ -85,6 +88,21 @@ boost::intrusive_ptr NAry::parse(Context &context, //////////////////////////////////////////////////////////////////////////////////////////////////// +template +ExpressionPointer NAry::copy() +{ + auto result = new Derived; + + result->m_arguments.resize(m_arguments.size()); + + for (size_t i = 0; i < m_arguments.size(); i++) + result->m_arguments[i] = m_arguments[i]->copy(); + + return result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template void NAry::setArgument(size_t i, ExpressionPointer expression) { @@ -273,6 +291,21 @@ inline ExpressionPointer NAry::simplified() //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline ExpressionPointer NAry::disjunctionNormalized() +{ + for (size_t i = 0; i < m_arguments.size(); i++) + { + BOOST_ASSERT(m_arguments[i]); + + m_arguments[i] = m_arguments[i]->disjunctionNormalized(); + } + + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template inline void NAry::print(std::ostream &ostream) const { diff --git a/include/plasp/pddl/expressions/Not.h b/include/plasp/pddl/expressions/Not.h index 3ef2827..e3785f7 100644 --- a/include/plasp/pddl/expressions/Not.h +++ b/include/plasp/pddl/expressions/Not.h @@ -29,6 +29,8 @@ class Not: public ExpressionCRTP public: Not(); + ExpressionPointer copy() override; + void setArgument(ExpressionPointer argument); ExpressionPointer argument() const; @@ -36,6 +38,7 @@ class Not: public ExpressionCRTP ExpressionPointer negationNormalized() override; ExpressionPointer prenex(Expression::Type lastExpressionType) override; ExpressionPointer simplified() override; + ExpressionPointer disjunctionNormalized() override; void print(std::ostream &ostream) const override; diff --git a/include/plasp/pddl/expressions/Quantified.h b/include/plasp/pddl/expressions/Quantified.h index 761744b..192e420 100644 --- a/include/plasp/pddl/expressions/Quantified.h +++ b/include/plasp/pddl/expressions/Quantified.h @@ -49,10 +49,13 @@ class QuantifiedCRTP: public Quantified return Derived::ExpressionType; } + ExpressionPointer copy() override; + ExpressionPointer reduced() override; ExpressionPointer negationNormalized() override; ExpressionPointer prenex(Expression::Type lastExpressionType) override; ExpressionPointer simplified() override; + ExpressionPointer disjunctionNormalized() override; void print(std::ostream &ostream) const override; }; @@ -98,6 +101,18 @@ boost::intrusive_ptr QuantifiedCRTP::parse(Context &context, //////////////////////////////////////////////////////////////////////////////////////////////////// +template +ExpressionPointer QuantifiedCRTP::copy() +{ + auto result = new Derived; + + result->m_argument = m_argument->copy(); + + return result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline void Quantified::setArgument(ExpressionPointer expression) { m_argument = expression; @@ -190,6 +205,18 @@ inline ExpressionPointer QuantifiedCRTP::simplified() //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline ExpressionPointer QuantifiedCRTP::disjunctionNormalized() +{ + BOOST_ASSERT(m_argument); + + m_argument = m_argument->disjunctionNormalized(); + + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template inline void QuantifiedCRTP::print(std::ostream &ostream) const { diff --git a/src/plasp/pddl/Expression.cpp b/src/plasp/pddl/Expression.cpp index 2ffc8e5..7a5fd57 100644 --- a/src/plasp/pddl/Expression.cpp +++ b/src/plasp/pddl/Expression.cpp @@ -26,9 +26,16 @@ namespace pddl // //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer Expression::copy() +{ + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + ExpressionPointer Expression::normalized() { - return reduced()->negationNormalized()->prenex()->simplified(); + return reduced()->negationNormalized()->prenex()->simplified()->disjunctionNormalized()->simplified(); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -97,6 +104,13 @@ ExpressionPointer Expression::simplified() //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer Expression::disjunctionNormalized() +{ + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + ExpressionPointer Expression::negated() { if (expressionType() == Type::Not) diff --git a/src/plasp/pddl/expressions/And.cpp b/src/plasp/pddl/expressions/And.cpp index 912b966..a0a83ce 100644 --- a/src/plasp/pddl/expressions/And.cpp +++ b/src/plasp/pddl/expressions/And.cpp @@ -3,6 +3,8 @@ #include #include +#include + namespace plasp { namespace pddl @@ -20,6 +22,51 @@ const std::string And::Identifier = "and"; //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer And::disjunctionNormalized() +{ + for (size_t i = 0; i < m_arguments.size(); i++) + { + BOOST_ASSERT(m_arguments[i]); + + m_arguments[i] = m_arguments[i]->disjunctionNormalized(); + } + + const auto match = std::find_if(m_arguments.begin(), m_arguments.end(), + [](const auto &argument) + { + return argument->expressionType() == Expression::Type::Or; + }); + + if (match == m_arguments.end()) + return this; + + auto orExpression = OrPointer(dynamic_cast(match->get())); + const size_t orExpressionIndex = match - m_arguments.begin(); + + // Apply the distributive law + // Copy this and expression for each argument of the or expression + for (size_t i = 0; i < orExpression->arguments().size(); i++) + { + auto newAndExpression = new expressions::And; + newAndExpression->arguments().resize(m_arguments.size()); + + for (size_t j = 0; j < m_arguments.size(); j++) + { + if (j == orExpressionIndex) + newAndExpression->arguments()[j] = orExpression->arguments()[i]->copy(); + else + newAndExpression->arguments()[j] = m_arguments[j]->copy(); + } + + // Replace the respective argument with the new, recursively normalized and expression + orExpression->arguments()[i] = newAndExpression->disjunctionNormalized(); + } + + return orExpression; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } } } diff --git a/src/plasp/pddl/expressions/At.cpp b/src/plasp/pddl/expressions/At.cpp index ec92601..54ce459 100644 --- a/src/plasp/pddl/expressions/At.cpp +++ b/src/plasp/pddl/expressions/At.cpp @@ -22,6 +22,17 @@ At::At() //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer At::copy() +{ + auto result = new At; + + result->m_argument = m_argument->copy(); + + return result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + void At::setArgument(ExpressionPointer argument) { m_argument = argument; @@ -76,6 +87,17 @@ ExpressionPointer At::simplified() //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer At::disjunctionNormalized() +{ + BOOST_ASSERT(m_argument); + + m_argument = m_argument->disjunctionNormalized(); + + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + void At::print(std::ostream &ostream) const { ostream << "(at " << m_timePoint << " "; diff --git a/src/plasp/pddl/expressions/Not.cpp b/src/plasp/pddl/expressions/Not.cpp index a002a73..4234e01 100644 --- a/src/plasp/pddl/expressions/Not.cpp +++ b/src/plasp/pddl/expressions/Not.cpp @@ -25,6 +25,17 @@ Not::Not() //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer Not::copy() +{ + auto result = new Not; + + result->m_argument = m_argument->copy(); + + return result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + void Not::setArgument(ExpressionPointer argument) { m_argument = argument; @@ -150,6 +161,17 @@ ExpressionPointer Not::simplified() //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer Not::disjunctionNormalized() +{ + BOOST_ASSERT(m_argument); + + m_argument = m_argument->disjunctionNormalized(); + + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + void Not::print(std::ostream &ostream) const { ostream << "(not "; diff --git a/tests/TestPDDLNormalization.cpp b/tests/TestPDDLNormalization.cpp index 7c28b08..28db35a 100644 --- a/tests/TestPDDLNormalization.cpp +++ b/tests/TestPDDLNormalization.cpp @@ -302,3 +302,51 @@ TEST(PDDLNormalizationTests, PrenexGroupSameType) ASSERT_EQ(output.str(), "(forall (?v1 ?v2 ?v6 ?v7) (exists (?v3 ?v8) (forall (?v4 ?v9) (exists (?v5) (and (a) (b))))))"); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(PDDLNormalizationTests, DisjunctiveNormalForm) +{ + auto f = expressions::ForAllPointer(new expressions::ForAll); + auto e = expressions::ExistsPointer(new expressions::Exists); + auto a = expressions::AndPointer(new expressions::And); + auto o1 = expressions::OrPointer(new expressions::Or); + auto o2 = expressions::OrPointer(new expressions::Or); + auto o3 = expressions::OrPointer(new expressions::Or); + + f->variables() = {new expressions::Variable("v1")}; + f->setArgument(e); + + e->variables() = {new expressions::Variable("v2")}; + e->setArgument(o1); + + o1->addArgument(a); + o1->addArgument(new expressions::Dummy("h")); + + a->addArgument(new expressions::Dummy("a")); + a->addArgument(new expressions::Dummy("b")); + a->addArgument(o2); + a->addArgument(o3); + + o2->addArgument(new expressions::Dummy("c")); + o2->addArgument(new expressions::Dummy("d")); + o2->addArgument(new expressions::Dummy("e")); + + o3->addArgument(new expressions::Dummy("f")); + o3->addArgument(new expressions::Dummy("g")); + + auto normalized = f->normalized(); + + std::stringstream output; + normalized->print(output); + + ASSERT_EQ(output.str(), "(forall (?v1) (exists (?v2) (or " + "(and (a) (b) (c) (f)) " + "(h) " + "(and (a) (b) (d) (f)) " + "(and (a) (b) (e) (f)) " + "(and (a) (b) (c) (g)) " + "(and (a) (b) (d) (g)) " + "(and (a) (b) (e) (g))" + ")))"); +}