diff --git a/include/plasp/pddl/Domain.h b/include/plasp/pddl/Domain.h index 81ea1ce..d908699 100644 --- a/include/plasp/pddl/Domain.h +++ b/include/plasp/pddl/Domain.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -49,6 +50,9 @@ class Domain std::vector> &actions(); const std::vector> &actions() const; + expressions::DerivedPredicates &derivedPredicates(); + const expressions::DerivedPredicates &derivedPredicates() const; + void checkConsistency(); void normalize(); @@ -85,6 +89,8 @@ class Domain std::vector m_actionPositions; std::vector> m_actions; + + expressions::DerivedPredicates m_derivedPredicates; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/plasp/pddl/Expression.h b/include/plasp/pddl/Expression.h index 07b12e2..aeb1d4e 100644 --- a/include/plasp/pddl/Expression.h +++ b/include/plasp/pddl/Expression.h @@ -40,6 +40,10 @@ class Constant; using ConstantPointer = boost::intrusive_ptr; using Constants = std::vector; +class DerivedPredicate; +using DerivedPredicatePointer = boost::intrusive_ptr; +using DerivedPredicates = std::vector; + class Dummy; using DummyPointer = boost::intrusive_ptr; @@ -103,6 +107,7 @@ class Expression At, Binary, Constant, + DerivedPredicate, Dummy, Either, Exists, @@ -132,6 +137,7 @@ class Expression virtual ExpressionPointer prenex(Expression::Type lastQuantifierType = Expression::Type::Exists); virtual ExpressionPointer simplified(); virtual ExpressionPointer disjunctionNormalized(); + virtual ExpressionPointer decomposed(expressions::DerivedPredicates &derivedPredicates); ExpressionPointer negated(); virtual void print(std::ostream &ostream) const = 0; diff --git a/include/plasp/pddl/expressions/And.h b/include/plasp/pddl/expressions/And.h index 505a5f0..97d885b 100644 --- a/include/plasp/pddl/expressions/And.h +++ b/include/plasp/pddl/expressions/And.h @@ -25,6 +25,7 @@ class And: public NAry public: ExpressionPointer disjunctionNormalized() override; + ExpressionPointer decomposed(DerivedPredicates &derivedPredicates) override; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/plasp/pddl/expressions/DerivedPredicate.h b/include/plasp/pddl/expressions/DerivedPredicate.h new file mode 100644 index 0000000..53087f7 --- /dev/null +++ b/include/plasp/pddl/expressions/DerivedPredicate.h @@ -0,0 +1,42 @@ +#ifndef __PLASP__PDDL__EXPRESSIONS__DERIVED_PREDICATE_H +#define __PLASP__PDDL__EXPRESSIONS__DERIVED_PREDICATE_H + +#include + +namespace plasp +{ +namespace pddl +{ +namespace expressions +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// DerivedPredicate +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +class DerivedPredicate: public ExpressionCRTP +{ + public: + static const Expression::Type ExpressionType = Expression::Type::DerivedPredicate; + + // TODO: consider implementing parsing functions for compatibility with older PDDL versions + + public: + void setArgument(ExpressionPointer argument); + ExpressionPointer argument() const; + + void print(std::ostream &ostream) const override; + + private: + ExpressionPointer m_argument; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} + +#endif diff --git a/include/plasp/pddl/expressions/Exists.h b/include/plasp/pddl/expressions/Exists.h index 815b34f..6f05108 100644 --- a/include/plasp/pddl/expressions/Exists.h +++ b/include/plasp/pddl/expressions/Exists.h @@ -22,6 +22,9 @@ class Exists: public QuantifiedCRTP static const Expression::Type ExpressionType = Expression::Type::Exists; static const std::string Identifier; + + public: + ExpressionPointer decomposed(DerivedPredicates &derivedPredicates) override; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/plasp/pddl/expressions/ForAll.h b/include/plasp/pddl/expressions/ForAll.h index 2c33be4..369f601 100644 --- a/include/plasp/pddl/expressions/ForAll.h +++ b/include/plasp/pddl/expressions/ForAll.h @@ -22,6 +22,9 @@ class ForAll: public QuantifiedCRTP static const Expression::Type ExpressionType = Expression::Type::ForAll; static const std::string Identifier; + + public: + ExpressionPointer decomposed(DerivedPredicates &derivedPredicates) override; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/plasp/pddl/expressions/Not.h b/include/plasp/pddl/expressions/Not.h index e3785f7..600bd65 100644 --- a/include/plasp/pddl/expressions/Not.h +++ b/include/plasp/pddl/expressions/Not.h @@ -39,6 +39,7 @@ class Not: public ExpressionCRTP ExpressionPointer prenex(Expression::Type lastExpressionType) override; ExpressionPointer simplified() override; ExpressionPointer disjunctionNormalized() override; + ExpressionPointer decomposed(DerivedPredicates &derivedPredicates) override; void print(std::ostream &ostream) const override; diff --git a/include/plasp/pddl/expressions/Or.h b/include/plasp/pddl/expressions/Or.h index ac75c2e..2d78071 100644 --- a/include/plasp/pddl/expressions/Or.h +++ b/include/plasp/pddl/expressions/Or.h @@ -22,6 +22,9 @@ class Or: public NAry static const Expression::Type ExpressionType = Expression::Type::Or; static const std::string Identifier; + + public: + ExpressionPointer decomposed(DerivedPredicates &derivedPredicates) override; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/plasp/pddl/Domain.cpp b/src/plasp/pddl/Domain.cpp index ef10936..469cabd 100644 --- a/src/plasp/pddl/Domain.cpp +++ b/src/plasp/pddl/Domain.cpp @@ -232,6 +232,20 @@ const std::vector> &Domain::actions() const //////////////////////////////////////////////////////////////////////////////////////////////////// +expressions::DerivedPredicates &Domain::derivedPredicates() +{ + return m_derivedPredicates; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +const expressions::DerivedPredicates &Domain::derivedPredicates() const +{ + return m_derivedPredicates; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + void Domain::parseRequirementSection() { auto &parser = m_context.parser; diff --git a/src/plasp/pddl/Expression.cpp b/src/plasp/pddl/Expression.cpp index 4539e42..e463664 100644 --- a/src/plasp/pddl/Expression.cpp +++ b/src/plasp/pddl/Expression.cpp @@ -15,6 +15,7 @@ #include #include #include +#include namespace plasp { @@ -112,6 +113,13 @@ ExpressionPointer Expression::disjunctionNormalized() //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer Expression::decomposed(expressions::DerivedPredicates &) +{ + throw utils::TranslatorException("Expression is not in first-order negation normal form and cannot be decomposed"); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + ExpressionPointer Expression::negated() { if (expressionType() == Type::Not) diff --git a/src/plasp/pddl/expressions/And.cpp b/src/plasp/pddl/expressions/And.cpp index a0a83ce..d2de6bc 100644 --- a/src/plasp/pddl/expressions/And.cpp +++ b/src/plasp/pddl/expressions/And.cpp @@ -4,6 +4,7 @@ #include #include +#include namespace plasp { @@ -67,6 +68,29 @@ ExpressionPointer And::disjunctionNormalized() //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer And::decomposed(DerivedPredicates &derivedPredicates) +{ + // Check that all children are simple or negated predicates + std::for_each(m_arguments.begin(), m_arguments.end(), + [&](auto &argument) + { + if (argument->expressionType() == Expression::Type::Not) + { + argument = argument->decomposed(derivedPredicates); + return; + } + + if (argument->expressionType() != Expression::Type::Predicate) + return; + + throw utils::TranslatorException("Expression is not in first-order negation normal form and cannot be decomposed"); + }); + + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } } } diff --git a/src/plasp/pddl/expressions/DerivedPredicate.cpp b/src/plasp/pddl/expressions/DerivedPredicate.cpp new file mode 100644 index 0000000..c211c9b --- /dev/null +++ b/src/plasp/pddl/expressions/DerivedPredicate.cpp @@ -0,0 +1,44 @@ +#include + +#include +#include + +namespace plasp +{ +namespace pddl +{ +namespace expressions +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// DerivedPredicate +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void DerivedPredicate::setArgument(ExpressionPointer argument) +{ + m_argument = argument; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +ExpressionPointer DerivedPredicate::argument() const +{ + return m_argument; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void DerivedPredicate::print(std::ostream &ostream) const +{ + ostream << "(:derived "; + m_argument->print(ostream); + ostream << ")"; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} diff --git a/src/plasp/pddl/expressions/Exists.cpp b/src/plasp/pddl/expressions/Exists.cpp index 5169f5e..eccd1dc 100644 --- a/src/plasp/pddl/expressions/Exists.cpp +++ b/src/plasp/pddl/expressions/Exists.cpp @@ -20,6 +20,15 @@ const std::string Exists::Identifier = "exists"; //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer Exists::decomposed(DerivedPredicates &derivedPredicates) +{ + m_argument = m_argument->decomposed(derivedPredicates); + + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } } } diff --git a/src/plasp/pddl/expressions/ForAll.cpp b/src/plasp/pddl/expressions/ForAll.cpp index 566d516..4013f7f 100644 --- a/src/plasp/pddl/expressions/ForAll.cpp +++ b/src/plasp/pddl/expressions/ForAll.cpp @@ -3,6 +3,8 @@ #include #include +#include + namespace plasp { namespace pddl @@ -20,6 +22,18 @@ const std::string ForAll::Identifier = "forall"; //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer ForAll::decomposed(DerivedPredicates &derivedPredicates) +{ + auto derivedPredicate = DerivedPredicatePointer(new DerivedPredicate); + derivedPredicates.push_back(derivedPredicate); + + derivedPredicate->setArgument(this); + + return derivedPredicate; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } } } diff --git a/src/plasp/pddl/expressions/Not.cpp b/src/plasp/pddl/expressions/Not.cpp index 4234e01..19066cd 100644 --- a/src/plasp/pddl/expressions/Not.cpp +++ b/src/plasp/pddl/expressions/Not.cpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace plasp { @@ -183,6 +184,19 @@ void Not::print(std::ostream &ostream) const //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer Not::decomposed(DerivedPredicates &) +{ + if (m_argument->expressionType() != Expression::Type::Not + && m_argument->expressionType() != Expression::Type::Predicate) + { + throw utils::TranslatorException("Expression is not in first-order negation normal form and cannot be decomposed"); + } + + return this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } } } diff --git a/src/plasp/pddl/expressions/Or.cpp b/src/plasp/pddl/expressions/Or.cpp index 67d027b..5fbedb3 100644 --- a/src/plasp/pddl/expressions/Or.cpp +++ b/src/plasp/pddl/expressions/Or.cpp @@ -1,5 +1,7 @@ #include +#include + namespace plasp { namespace pddl @@ -17,6 +19,25 @@ const std::string Or::Identifier = "or"; //////////////////////////////////////////////////////////////////////////////////////////////////// +ExpressionPointer Or::decomposed(DerivedPredicates &derivedPredicates) +{ + // Check that all children are simple or negated predicates + std::for_each(m_arguments.begin(), m_arguments.end(), + [&](auto &argument) + { + argument = argument->decomposed(derivedPredicates); + }); + + auto derivedPredicate = DerivedPredicatePointer(new DerivedPredicate); + derivedPredicates.push_back(derivedPredicate); + + derivedPredicate->setArgument(this); + + return derivedPredicate; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } } }