diff --git a/include/anthem/AST.h b/include/anthem/AST.h index 8a770ec..41b0865 100644 --- a/include/anthem/AST.h +++ b/include/anthem/AST.h @@ -105,13 +105,13 @@ struct Comparison struct Function { - explicit Function(std::string &&name) - : name{std::move(name)} + explicit Function(FunctionDeclaration *declaration) + : declaration{declaration} { } - explicit Function(std::string &&name, std::vector &&arguments) - : name{std::move(name)}, + explicit Function(FunctionDeclaration *declaration, std::vector &&arguments) + : declaration{declaration}, arguments{std::move(arguments)} { } @@ -121,12 +121,36 @@ struct Function Function(Function &&other) noexcept = default; Function &operator=(Function &&other) noexcept = default; - std::string name; + FunctionDeclaration *declaration; std::vector arguments; }; //////////////////////////////////////////////////////////////////////////////////////////////////// +struct FunctionDeclaration +{ + explicit FunctionDeclaration(std::string &&name) + : name{std::move(name)} + { + } + + explicit FunctionDeclaration(std::string &&name, size_t arity) + : name{std::move(name)}, + arity{arity} + { + } + + FunctionDeclaration(const FunctionDeclaration &other) = delete; + FunctionDeclaration &operator=(const FunctionDeclaration &other) = delete; + FunctionDeclaration(FunctionDeclaration &&other) noexcept = default; + FunctionDeclaration &operator=(FunctionDeclaration &&other) noexcept = default; + + std::string name; + size_t arity; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // TODO: refactor (limit element type to primitive terms) struct In { diff --git a/include/anthem/ASTForward.h b/include/anthem/ASTForward.h index 10c1a65..f78f465 100644 --- a/include/anthem/ASTForward.h +++ b/include/anthem/ASTForward.h @@ -27,6 +27,7 @@ struct Comparison; struct Exists; struct ForAll; struct Function; +struct FunctionDeclaration; struct Implies; struct In; struct Integer; diff --git a/include/anthem/Body.h b/include/anthem/Body.h index 9995c86..2233eb1 100644 --- a/include/anthem/Body.h +++ b/include/anthem/Body.h @@ -76,7 +76,7 @@ struct BodyTermTranslateVisitor for (size_t i = 0; i < function.arguments.size(); i++) { auto &argument = function.arguments[i]; - conjunction.arguments.emplace_back(ast::Formula::make(ast::Variable(parameters[i].get()), translate(argument, ruleContext, variableStack))); + conjunction.arguments.emplace_back(ast::Formula::make(ast::Variable(parameters[i].get()), translate(argument, ruleContext, context, variableStack))); } variableStack.pop(); @@ -121,7 +121,7 @@ struct BodyLiteralTranslateVisitor } // TODO: refactor - std::optional visit(const Clingo::AST::Comparison &comparison, const Clingo::AST::Literal &literal, RuleContext &ruleContext, Context &, ast::VariableStack &variableStack) + std::optional visit(const Clingo::AST::Comparison &comparison, const Clingo::AST::Literal &literal, RuleContext &ruleContext, Context &context, ast::VariableStack &variableStack) { // Comparisons should never have a sign, because these are converted to positive comparisons by clingo if (literal.sign != Clingo::AST::Sign::None) @@ -136,8 +136,8 @@ struct BodyLiteralTranslateVisitor ast::And conjunction; conjunction.arguments.reserve(3); - conjunction.arguments.emplace_back(ast::Formula::make(ast::Variable(parameters[0].get()), translate(comparison.left, ruleContext, variableStack))); - conjunction.arguments.emplace_back(ast::Formula::make(ast::Variable(parameters[1].get()), translate(comparison.right, ruleContext, variableStack))); + conjunction.arguments.emplace_back(ast::Formula::make(ast::Variable(parameters[0].get()), translate(comparison.left, ruleContext, context, variableStack))); + conjunction.arguments.emplace_back(ast::Formula::make(ast::Variable(parameters[1].get()), translate(comparison.right, ruleContext, context, variableStack))); conjunction.arguments.emplace_back(ast::Formula::make(operator_, ast::Variable(parameters[0].get()), ast::Variable(parameters[1].get()))); return ast::Formula::make(std::move(parameters), std::move(conjunction)); diff --git a/include/anthem/Context.h b/include/anthem/Context.h index 24acbdc..06cadda 100644 --- a/include/anthem/Context.h +++ b/include/anthem/Context.h @@ -53,6 +53,26 @@ struct Context return predicateDeclarations.back().get(); } + ast::FunctionDeclaration *findOrCreateFunctionDeclaration(const char *name, size_t arity) + { + const auto matchesExistingFunctionDeclaration = + [&](const auto &functionDeclarations) + { + return (functionDeclarations->arity == arity + && strcmp(functionDeclarations->name.c_str(), name) == 0); + }; + + auto matchingFunctionDeclaration = std::find_if(functionDeclarations.begin(), + functionDeclarations.end(), matchesExistingFunctionDeclaration); + + if (matchingFunctionDeclaration != functionDeclarations.end()) + return matchingFunctionDeclaration->get(); + + functionDeclarations.emplace_back(std::make_unique(name, arity)); + + return functionDeclarations.back().get(); + } + output::Logger logger; bool performSimplification{false}; @@ -61,6 +81,8 @@ struct Context std::vector> predicateDeclarations; ast::PredicateDeclaration::Visibility defaultPredicateVisibility{ast::PredicateDeclaration::Visibility::Visible}; + std::vector> functionDeclarations; + bool externalStatementsUsed{false}; bool showStatementsUsed{false}; diff --git a/include/anthem/Equality.h b/include/anthem/Equality.h index 7c5c83b..683bef6 100644 --- a/include/anthem/Equality.h +++ b/include/anthem/Equality.h @@ -305,7 +305,7 @@ struct TermEqualityVisitor const auto &otherFunction = otherTerm.get(); - if (function.name != otherFunction.name) + if (function.declaration != otherFunction.declaration) return Tristate::False; if (function.arguments.size() != otherFunction.arguments.size()) diff --git a/include/anthem/StatementVisitor.h b/include/anthem/StatementVisitor.h index bd0d21d..c6517d2 100644 --- a/include/anthem/StatementVisitor.h +++ b/include/anthem/StatementVisitor.h @@ -87,7 +87,7 @@ struct StatementVisitor const auto auxiliaryHeadVariableID = ruleContext.headVariablesStartIndex + i - ruleContext.headTerms.cbegin(); auto element = ast::Variable(ruleContext.freeVariables[auxiliaryHeadVariableID].get()); - auto set = translate(headTerm, ruleContext, variableStack); + auto set = translate(headTerm, ruleContext, context, variableStack); auto in = ast::In(std::move(element), std::move(set)); antecedent.arguments.emplace_back(std::move(in)); diff --git a/include/anthem/Term.h b/include/anthem/Term.h index 5c2cbd7..62c64ba 100644 --- a/include/anthem/Term.h +++ b/include/anthem/Term.h @@ -65,13 +65,13 @@ ast::UnaryOperation::Operator translate(Clingo::AST::UnaryOperator unaryOperator //////////////////////////////////////////////////////////////////////////////////////////////////// -ast::Term translate(const Clingo::AST::Term &term, RuleContext &ruleContext, const ast::VariableStack &variableStack); +ast::Term translate(const Clingo::AST::Term &term, RuleContext &ruleContext, Context &context, const ast::VariableStack &variableStack); //////////////////////////////////////////////////////////////////////////////////////////////////// struct TermTranslateVisitor { - std::optional visit(const Clingo::Symbol &symbol, const Clingo::AST::Term &term, RuleContext &ruleContext, const ast::VariableStack &variableStack) + std::optional visit(const Clingo::Symbol &symbol, const Clingo::AST::Term &term, RuleContext &ruleContext, Context &context, const ast::VariableStack &variableStack) { switch (symbol.type()) { @@ -85,19 +85,19 @@ struct TermTranslateVisitor return ast::Term::make(std::string(symbol.string())); case Clingo::SymbolType::Function: { - auto function = ast::Term::make(symbol.name()); - // TODO: remove workaround - auto &functionRaw = function.get(); - functionRaw.arguments.reserve(symbol.arguments().size()); + auto functionDeclaration = context.findOrCreateFunctionDeclaration(symbol.name(), symbol.arguments().size()); + + auto function = ast::Function(functionDeclaration); + function.arguments.reserve(symbol.arguments().size()); for (const auto &argument : symbol.arguments()) { - auto translatedArgument = visit(argument, term, ruleContext, variableStack); + auto translatedArgument = visit(argument, term, ruleContext, context, variableStack); if (!translatedArgument) throw TranslationException(term.location, "could not translate argument"); - functionRaw.arguments.emplace_back(std::move(translatedArgument.value())); + function.arguments.emplace_back(std::move(translatedArgument.value())); } return std::move(function); @@ -107,7 +107,7 @@ struct TermTranslateVisitor return std::nullopt; } - std::optional visit(const Clingo::AST::Variable &variable, const Clingo::AST::Term &, RuleContext &ruleContext, const ast::VariableStack &variableStack) + std::optional visit(const Clingo::AST::Variable &variable, const Clingo::AST::Term &, RuleContext &ruleContext, Context &, const ast::VariableStack &variableStack) { const auto matchingVariableDeclaration = variableStack.findUserVariableDeclaration(variable.name); const auto isAnonymousVariable = (strcmp(variable.name, "_") == 0); @@ -120,35 +120,36 @@ struct TermTranslateVisitor auto variableDeclaration = std::make_unique(ast::VariableDeclaration::Type::UserDefined, std::string(variable.name)); ruleContext.freeVariables.emplace_back(std::move(variableDeclaration)); + // TODO: ast::Term::make is unnecessary and can be removed return ast::Term::make(ruleContext.freeVariables.back().get()); } - std::optional visit(const Clingo::AST::BinaryOperation &binaryOperation, const Clingo::AST::Term &term, RuleContext &ruleContext, const ast::VariableStack &variableStack) + std::optional visit(const Clingo::AST::BinaryOperation &binaryOperation, const Clingo::AST::Term &term, RuleContext &ruleContext, Context &context, const ast::VariableStack &variableStack) { const auto operator_ = translate(binaryOperation.binary_operator, term); - auto left = translate(binaryOperation.left, ruleContext, variableStack); - auto right = translate(binaryOperation.right, ruleContext, variableStack); + auto left = translate(binaryOperation.left, ruleContext, context, variableStack); + auto right = translate(binaryOperation.right, ruleContext, context, variableStack); return ast::Term::make(operator_, std::move(left), std::move(right)); } - std::optional visit(const Clingo::AST::UnaryOperation &unaryOperation, const Clingo::AST::Term &term, RuleContext &ruleContext, const ast::VariableStack &variableStack) + std::optional visit(const Clingo::AST::UnaryOperation &unaryOperation, const Clingo::AST::Term &term, RuleContext &ruleContext, Context &context, const ast::VariableStack &variableStack) { const auto operator_ = translate(unaryOperation.unary_operator, term); - auto argument = translate(unaryOperation.argument, ruleContext, variableStack); + auto argument = translate(unaryOperation.argument, ruleContext, context, variableStack); return ast::Term::make(operator_, std::move(argument)); } - std::optional visit(const Clingo::AST::Interval &interval, const Clingo::AST::Term &, RuleContext &ruleContext, const ast::VariableStack &variableStack) + std::optional visit(const Clingo::AST::Interval &interval, const Clingo::AST::Term &, RuleContext &ruleContext, Context &context, const ast::VariableStack &variableStack) { - auto left = translate(interval.left, ruleContext, variableStack); - auto right = translate(interval.right, ruleContext, variableStack); + auto left = translate(interval.left, ruleContext, context, variableStack); + auto right = translate(interval.right, ruleContext, context, variableStack); return ast::Term::make(std::move(left), std::move(right)); } - std::optional visit(const Clingo::AST::Function &function, const Clingo::AST::Term &term, RuleContext &ruleContext, const ast::VariableStack &variableStack) + std::optional visit(const Clingo::AST::Function &function, const Clingo::AST::Term &term, RuleContext &ruleContext, Context &context, const ast::VariableStack &variableStack) { if (function.external) throw TranslationException(term.location, "external functions currently unsupported"); @@ -157,12 +158,14 @@ struct TermTranslateVisitor arguments.reserve(function.arguments.size()); for (const auto &argument : function.arguments) - arguments.emplace_back(translate(argument, ruleContext, variableStack)); + arguments.emplace_back(translate(argument, ruleContext, context, variableStack)); - return ast::Term::make(function.name, std::move(arguments)); + auto functionDeclaration = context.findOrCreateFunctionDeclaration(function.name, function.arguments.size()); + + return ast::Term::make(functionDeclaration, std::move(arguments)); } - std::optional visit(const Clingo::AST::Pool &, const Clingo::AST::Term &term, RuleContext &, const ast::VariableStack &) + std::optional visit(const Clingo::AST::Pool &, const Clingo::AST::Term &term, RuleContext &, Context &, const ast::VariableStack &) { throw TranslationException(term.location, "“pool” terms currently unsupported"); return std::nullopt; @@ -171,9 +174,9 @@ struct TermTranslateVisitor //////////////////////////////////////////////////////////////////////////////////////////////////// -ast::Term translate(const Clingo::AST::Term &term, RuleContext &ruleContext, const ast::VariableStack &variableStack) +ast::Term translate(const Clingo::AST::Term &term, RuleContext &ruleContext, Context &context, const ast::VariableStack &variableStack) { - auto translatedTerm = term.data.accept(TermTranslateVisitor(), term, ruleContext, variableStack); + auto translatedTerm = term.data.accept(TermTranslateVisitor(), term, ruleContext, context, variableStack); if (!translatedTerm) throw TranslationException(term.location, "could not translate term"); diff --git a/include/anthem/output/AST.h b/include/anthem/output/AST.h index 3703174..b12cee4 100644 --- a/include/anthem/output/AST.h +++ b/include/anthem/output/AST.h @@ -169,7 +169,7 @@ inline output::ColorStream &print(output::ColorStream &stream, const Comparison inline output::ColorStream &print(output::ColorStream &stream, const Function &function, PrintContext &printContext, bool) { - stream << function.name; + stream << function.declaration->name; if (function.arguments.empty()) return stream; @@ -184,7 +184,7 @@ inline output::ColorStream &print(output::ColorStream &stream, const Function &f print(stream, *i, printContext); } - if (function.name.empty() && function.arguments.size() == 1) + if (function.declaration->name.empty() && function.arguments.size() == 1) stream << ","; stream << ")"; diff --git a/src/anthem/ASTCopy.cpp b/src/anthem/ASTCopy.cpp index e19113d..1b33198 100644 --- a/src/anthem/ASTCopy.cpp +++ b/src/anthem/ASTCopy.cpp @@ -105,7 +105,7 @@ Comparison prepareCopy(const Comparison &other) Function prepareCopy(const Function &other) { - return Function(std::string(other.name), prepareCopy(other.arguments)); + return Function(other.declaration, prepareCopy(other.arguments)); } ////////////////////////////////////////////////////////////////////////////////////////////////////