diff --git a/lib/pddlparse/src/pddlparse/detail/parsing/PrimitiveTypeDeclaration.cpp b/lib/pddlparse/src/pddlparse/detail/parsing/PrimitiveTypeDeclaration.cpp index 87cb6a5..6a19d6f 100644 --- a/lib/pddlparse/src/pddlparse/detail/parsing/PrimitiveTypeDeclaration.cpp +++ b/lib/pddlparse/src/pddlparse/detail/parsing/PrimitiveTypeDeclaration.cpp @@ -15,11 +15,8 @@ namespace detail // //////////////////////////////////////////////////////////////////////////////////////////////////// -ast::PrimitiveTypeDeclarationPointer &parseAndAddUntypedPrimitiveTypeDeclaration(Context &context, ast::Domain &domain) +std::experimental::optional findPrimitiveTypeDeclaration(ast::Domain &domain, const std::string &typeName) { - auto &tokenizer = context.tokenizer; - auto typeName = tokenizer.getIdentifier(); - auto &types = domain.types; const auto matchingPrimitiveType = std::find_if(types.begin(), types.end(), @@ -28,11 +25,28 @@ ast::PrimitiveTypeDeclarationPointer &parseAndAddUntypedPrimitiveTypeDeclaration return primitiveType->name == typeName; }); - // Return existing primitive type - if (matchingPrimitiveType != types.cend()) - return *matchingPrimitiveType; + if (matchingPrimitiveType != types.end()) + return &*matchingPrimitiveType; + + return std::experimental::nullopt; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +ast::PrimitiveTypeDeclarationPointer &parseAndAddUntypedPrimitiveTypeDeclaration(Context &context, ast::Domain &domain, std::vector &flaggedTypes) +{ + auto &tokenizer = context.tokenizer; + auto typeName = tokenizer.getIdentifier(); + + auto &types = domain.types; + + auto matchingPrimitiveTypeDeclaration = findPrimitiveTypeDeclaration(domain, typeName); + + if (matchingPrimitiveTypeDeclaration) + return *matchingPrimitiveTypeDeclaration.value(); types.emplace_back(std::make_unique(std::move(typeName))); + flaggedTypes.emplace_back(false); return types.back(); } @@ -44,13 +58,15 @@ void parseAndAddPrimitiveTypeDeclarations(Context &context, ast::Domain &domain) auto &tokenizer = context.tokenizer; tokenizer.skipWhiteSpace(); - const auto position = tokenizer.position(); - const auto typeStartIndex = domain.types.size(); + auto &types = domain.types; + + std::vector flaggedTypes; + flaggedTypes.resize(types.size(), false); - // First pass: collect all primitive types while (tokenizer.currentCharacter() != ')') { - parseAndAddUntypedPrimitiveTypeDeclaration(context, domain); + auto &childType = parseAndAddUntypedPrimitiveTypeDeclaration(context, domain, flaggedTypes); + flaggedTypes[&childType - &types.front()] = true; tokenizer.skipWhiteSpace(); @@ -58,43 +74,17 @@ void parseAndAddPrimitiveTypeDeclarations(Context &context, ast::Domain &domain) continue; // Skip parent type information for now - tokenizer.getIdentifier(); + auto &parentType = parseAndAddUntypedPrimitiveTypeDeclaration(context, domain, flaggedTypes); + + for (size_t i = 0; i < flaggedTypes.size(); i++) + if (flaggedTypes[i]) + { + flaggedTypes[i] = false; + types[i]->parentTypes.emplace_back(std::make_unique(parentType.get())); + } + tokenizer.skipWhiteSpace(); } - - tokenizer.seek(position); - - // Second pass: link parent types correctly - // Index on the first element of the current inheritance list - // TODO: test correct implementation of offset if this function is called multiple times - size_t inheritanceIndex = typeStartIndex; - size_t i = typeStartIndex; - - while (tokenizer.currentCharacter() != ')') - { - // Skip type declaration - tokenizer.getIdentifier(); - tokenizer.skipWhiteSpace(); - - if (!tokenizer.testAndSkip('-')) - { - i++; - continue; - } - - // If existing, parse and store parent type - auto parentType = parsePrimitiveType(context, domain); - tokenizer.skipWhiteSpace(); - - auto &types = domain.types; - - for (size_t j = inheritanceIndex; j <= i; j++) - types[j]->parentTypes.emplace_back(ast::deepCopy(parentType)); - - // All types up to now are labeled with their parent types - inheritanceIndex = i + 1; - i++; - } } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/lib/pddlparse/tests/TestOfficialPDDLInstances.cpp b/lib/pddlparse/tests/TestOfficialPDDLInstances.cpp index a02613a..4350798 100644 --- a/lib/pddlparse/tests/TestOfficialPDDLInstances.cpp +++ b/lib/pddlparse/tests/TestOfficialPDDLInstances.cpp @@ -254,25 +254,25 @@ TEST_CASE("[PDDL instances] The official PDDL instances are parsed correctly", " CHECK(types[1]->name == "airplane"); REQUIRE(types[1]->parentTypes.size() == 1); CHECK(types[1]->parentTypes[0]->declaration->name == "vehicle"); - CHECK(types[2]->name == "package"); + CHECK(types[2]->name == "vehicle"); REQUIRE(types[2]->parentTypes.size() == 1); CHECK(types[2]->parentTypes[0]->declaration->name == "physobj"); - CHECK(types[3]->name == "vehicle"); + CHECK(types[3]->name == "package"); REQUIRE(types[3]->parentTypes.size() == 1); CHECK(types[3]->parentTypes[0]->declaration->name == "physobj"); - CHECK(types[4]->name == "airport"); + CHECK(types[4]->name == "physobj"); REQUIRE(types[4]->parentTypes.size() == 1); - CHECK(types[4]->parentTypes[0]->declaration->name == "place"); - CHECK(types[5]->name == "location"); + CHECK(types[4]->parentTypes[0]->declaration->name == "object"); + CHECK(types[5]->name == "airport"); REQUIRE(types[5]->parentTypes.size() == 1); CHECK(types[5]->parentTypes[0]->declaration->name == "place"); - CHECK(types[6]->name == "city"); + CHECK(types[6]->name == "location"); REQUIRE(types[6]->parentTypes.size() == 1); - CHECK(types[6]->parentTypes[0]->declaration->name == "object"); + CHECK(types[6]->parentTypes[0]->declaration->name == "place"); CHECK(types[7]->name == "place"); REQUIRE(types[7]->parentTypes.size() == 1); CHECK(types[7]->parentTypes[0]->declaration->name == "object"); - CHECK(types[8]->name == "physobj"); + CHECK(types[8]->name == "city"); REQUIRE(types[8]->parentTypes.size() == 1); CHECK(types[8]->parentTypes[0]->declaration->name == "object"); CHECK(types[9]->name == "object"); diff --git a/lib/pddlparse/tests/TestParser.cpp b/lib/pddlparse/tests/TestParser.cpp new file mode 100644 index 0000000..3129507 --- /dev/null +++ b/lib/pddlparse/tests/TestParser.cpp @@ -0,0 +1,47 @@ +#include + +#include + +#include +#include + +namespace fs = std::experimental::filesystem; + +const pddl::Context::WarningCallback ignoreWarnings = [](const auto &, const auto &warning){std::cout << warning << std::endl;}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST_CASE("[PDDL parser] Check past issues", "[PDDL parser]") +{ + pddl::Tokenizer tokenizer; + pddl::Context context(std::move(tokenizer), ignoreWarnings); + + // Check that no infinite loop occurs + SECTION("“either” in typing section") + { + const auto domainFile = fs::path("data") / "test-cases" / "typing-1.pddl"; + context.tokenizer.read(domainFile); + const auto description = pddl::parseDescription(context); + + const auto &types = description.domain->types; + + REQUIRE(types.size() == 5); + CHECK(types[0]->name == "object"); + REQUIRE(types[0]->parentTypes.size() == 1); + CHECK(types[0]->parentTypes[0]->declaration == types[0].get()); + CHECK(types[1]->name == "a1"); + REQUIRE(types[1]->parentTypes.size() == 1); + CHECK(types[1]->parentTypes[0]->declaration == types[0].get()); + CHECK(types[2]->name == "a2"); + REQUIRE(types[2]->parentTypes.size() == 1); + CHECK(types[2]->parentTypes[0]->declaration == types[0].get()); + CHECK(types[3]->name == "a3"); + REQUIRE(types[3]->parentTypes.size() == 1); + CHECK(types[3]->parentTypes[0]->declaration == types[0].get()); + CHECK(types[4]->name == "bx"); + REQUIRE(types[4]->parentTypes.size() == 3); + CHECK(types[4]->parentTypes[0]->declaration == types[1].get()); + CHECK(types[4]->parentTypes[1]->declaration == types[2].get()); + CHECK(types[4]->parentTypes[2]->declaration == types[3].get()); + } +} diff --git a/tests/data/test-cases/typing-1.pddl b/tests/data/test-cases/typing-1.pddl new file mode 100644 index 0000000..07ff4b4 --- /dev/null +++ b/tests/data/test-cases/typing-1.pddl @@ -0,0 +1,11 @@ +(define + (domain test) + (:requirements :typing) + (:types + object + a1 a2 a3 - object + bx - a1 + bx - a2 + bx - a3 + ) +)