diff --git a/include/anthem/AST.h b/include/anthem/AST.h index bc97e47..1c49ed0 100644 --- a/include/anthem/AST.h +++ b/include/anthem/AST.h @@ -148,6 +148,7 @@ struct FunctionDeclaration std::string name; size_t arity; + Domain domain{Domain::Noninteger}; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/anthem/StatementVisitor.h b/include/anthem/StatementVisitor.h index c6517d2..1701113 100644 --- a/include/anthem/StatementVisitor.h +++ b/include/anthem/StatementVisitor.h @@ -190,7 +190,7 @@ struct StatementVisitor const auto fail = [&]() { - throw LogicException(statement.location, "only #external declarations of the form “#external ().” supported"); + throw LogicException(statement.location, "only #external declarations of the form “#external ().” or “#external integer(()).” supported"); }; if (!external.body.empty()) @@ -204,6 +204,47 @@ struct StatementVisitor if (predicate.arguments.size() != 1) fail(); + const auto handleIntegerDeclaration = + [&]() + { + // Integer function declarations are treated separately if applicable + if (strcmp(predicate.name, "integer") != 0) + return false; + + if (predicate.arguments.size() != 1) + return false; + + const auto &functionArgument = predicate.arguments.front(); + + if (!functionArgument.data.is()) + return false; + + const auto &function = functionArgument.data.get(); + + if (function.arguments.size() != 1) + return false; + + const auto &arityArgument = function.arguments.front(); + + if (!arityArgument.data.is()) + return false; + + const auto &aritySymbol = arityArgument.data.get(); + + if (aritySymbol.type() != Clingo::SymbolType::Number) + return false; + + const size_t arity = aritySymbol.number(); + + auto functionDeclaration = context.findOrCreateFunctionDeclaration(function.name, arity); + functionDeclaration->domain = Domain::Integer; + + return true; + }; + + if (handleIntegerDeclaration()) + return; + const auto &arityArgument = predicate.arguments.front(); if (!arityArgument.data.is())