diff --git a/src/anthem/IntegerVariableDetection.cpp b/src/anthem/IntegerVariableDetection.cpp index 16b78cd..ea2932f 100644 --- a/src/anthem/IntegerVariableDetection.cpp +++ b/src/anthem/IntegerVariableDetection.cpp @@ -237,9 +237,28 @@ struct VariableDomainInFormulaVisitor return ast::Domain::Unknown; } - static ast::Domain visit(ast::Predicate &, ast::VariableDeclaration &) + static ast::Domain visit(ast::Predicate &predicate, ast::VariableDeclaration &variableDeclaration) { - // TODO: implement correctly + // TODO: check implementation for nested arguments + + // Inherit the domain of the predicate’s parameters + for (size_t i = 0; i < predicate.arguments.size(); i++) + { + auto &argument = predicate.arguments[i]; + + if (!argument.is()) + continue; + + auto &variable = argument.get(); + + if (variable.declaration != &variableDeclaration) + continue; + + auto ¶meter = predicate.declaration->parameters[i]; + + return parameter.domain; + } + return ast::Domain::Unknown; } }; @@ -291,11 +310,15 @@ struct DetectIntegerVariablesVisitor operationResult = OperationResult::Changed; for (auto &variableDeclaration : exists.variables) - if (variableDeclaration->domain == ast::Domain::Unknown - && exists.argument.accept(VariableDomainInFormulaVisitor(), *variableDeclaration) == ast::Domain::Integer) + if (variableDeclaration->domain != ast::Domain::General) { + auto newDomain = exists.argument.accept(VariableDomainInFormulaVisitor(), *variableDeclaration); + + if (variableDeclaration->domain == newDomain) + continue; + operationResult = OperationResult::Changed; - variableDeclaration->domain = ast::Domain::Integer; + variableDeclaration->domain = newDomain; } return operationResult; @@ -309,11 +332,15 @@ struct DetectIntegerVariablesVisitor operationResult = OperationResult::Changed; for (auto &variableDeclaration : forAll.variables) - if (variableDeclaration->domain == ast::Domain::Unknown - && forAll.argument.accept(VariableDomainInFormulaVisitor(), *variableDeclaration) == ast::Domain::Integer) + if (variableDeclaration->domain != ast::Domain::General) { + auto newDomain = forAll.argument.accept(VariableDomainInFormulaVisitor(), *variableDeclaration); + + if (variableDeclaration->domain == newDomain) + continue; + operationResult = OperationResult::Changed; - variableDeclaration->domain = ast::Domain::Integer; + variableDeclaration->domain = newDomain; } return operationResult; @@ -387,18 +414,42 @@ void detectIntegerVariables(std::vector &completedFormulas) if (!biconditional.left.is()) continue; + auto &predicate = biconditional.left.get(); auto &definition = biconditional.right; if (definition.accept(DetectIntegerVariablesVisitor()) == OperationResult::Changed) operationResult = OperationResult::Changed; for (auto &variableDeclaration : forAll.variables) - if (variableDeclaration->domain == ast::Domain::Unknown - && definition.accept(VariableDomainInFormulaVisitor(), *variableDeclaration) == ast::Domain::Integer) + if (variableDeclaration->domain != ast::Domain::General) { + auto newDomain = forAll.argument.accept(VariableDomainInFormulaVisitor(), *variableDeclaration); + + if (variableDeclaration->domain == newDomain) + continue; + operationResult = OperationResult::Changed; - variableDeclaration->domain = ast::Domain::Integer; + variableDeclaration->domain = newDomain; } + + assert(predicate.arguments.size() == predicate.declaration->arity()); + + // Update parameter domains + for (size_t i = 0; i < predicate.arguments.size(); i++) + { + auto &variableArgument = predicate.arguments[i]; + + assert(variableArgument.is()); + + auto &variable = variableArgument.get(); + auto ¶meter = predicate.declaration->parameters[i]; + + if (parameter.domain == variable.declaration->domain) + continue; + + operationResult = OperationResult::Changed; + parameter.domain = variable.declaration->domain; + } } } }