diff --git a/src/format_tptp.rs b/src/format_tptp.rs index 03556c0..0d059a1 100644 --- a/src/format_tptp.rs +++ b/src/format_tptp.rs @@ -25,7 +25,66 @@ fn is_arithmetic_term(term: &foliage::Term) -> bool } } +fn collect_predicate_declarations_in_formula<'a>( + predicate_declarations: &mut std::collections::HashSet<&'a foliage::PredicateDeclaration>, formula: &'a foliage::Formula) +{ + match formula + { + foliage::Formula::Exists(ref exists) => collect_predicate_declarations_in_formula(predicate_declarations, &exists.argument), + foliage::Formula::ForAll(ref for_all) => collect_predicate_declarations_in_formula(predicate_declarations, &for_all.argument), + foliage::Formula::Not(ref argument) => collect_predicate_declarations_in_formula(predicate_declarations, argument), + foliage::Formula::And(ref arguments) => + for argument in arguments + { + collect_predicate_declarations_in_formula(predicate_declarations, argument); + }, + foliage::Formula::Or(ref arguments) => + for argument in arguments + { + collect_predicate_declarations_in_formula(predicate_declarations, argument); + }, + foliage::Formula::Implies(ref left, ref right) => + { + collect_predicate_declarations_in_formula(predicate_declarations, left); + collect_predicate_declarations_in_formula(predicate_declarations, right); + }, + foliage::Formula::Biconditional(ref left, ref right) => + { + collect_predicate_declarations_in_formula(predicate_declarations, left); + collect_predicate_declarations_in_formula(predicate_declarations, right); + }, + foliage::Formula::Less(_, _) => (), + foliage::Formula::LessOrEqual(_, _) => (), + foliage::Formula::Greater(_, _) => (), + foliage::Formula::GreaterOrEqual(_, _) => (), + foliage::Formula::Equal(_, _) => (), + foliage::Formula::NotEqual(_, _) => (), + foliage::Formula::Boolean(_) => (), + foliage::Formula::Predicate(ref predicate) => + { + predicate_declarations.insert(&predicate.declaration); + }, + } +} + +fn collect_predicate_declarations_in_project<'a>(project: &'a crate::Project) + -> std::collections::HashSet<&'a foliage::PredicateDeclaration> +{ + let mut predicate_declarations = std::collections::HashSet::new(); + + for (_, formulas) in project.statements.iter() + { + for formula in formulas.iter() + { + collect_predicate_declarations_in_formula(&mut predicate_declarations, formula); + } + } + + predicate_declarations +} + struct VariableDeclarationDisplay<'a>(&'a foliage::VariableDeclaration); +struct PredicateDeclarationDisplay<'a>(&'a foliage::PredicateDeclaration); struct TermDisplay<'a>(&'a foliage::Term); struct FormulaDisplay<'a>(&'a foliage::Formula); struct StatementKindDisplay<'a>(&'a crate::project::StatementKind); @@ -39,6 +98,14 @@ impl<'a> DisplayTPTP<'a, VariableDeclarationDisplay<'a>> for foliage::VariableDe } } +impl<'a> DisplayTPTP<'a, PredicateDeclarationDisplay<'a>> for foliage::PredicateDeclaration +{ + fn display_tptp(&'a self) -> PredicateDeclarationDisplay<'a> + { + PredicateDeclarationDisplay(self) + } +} + impl<'a> DisplayTPTP<'a, TermDisplay<'a>> for foliage::Term { fn display_tptp(&'a self) -> TermDisplay<'a> @@ -91,6 +158,32 @@ impl<'a> std::fmt::Display for VariableDeclarationDisplay<'a> } } +impl<'a> std::fmt::Debug for PredicateDeclarationDisplay<'a> +{ + fn fmt(&self, format: &mut std::fmt::Formatter) -> std::fmt::Result + { + write!(format, "{}: (", self.0.name)?; + + let mut separator = ""; + + for _ in 0..self.0.arity + { + write!(format, "{}object", separator)?; + separator = " * " + } + + write!(format, ") > $o") + } +} + +impl<'a> std::fmt::Display for PredicateDeclarationDisplay<'a> +{ + fn fmt(&self, format: &mut std::fmt::Formatter) -> std::fmt::Result + { + write!(format, "{:?}", &self) + } +} + impl<'a> std::fmt::Debug for TermDisplay<'a> { fn fmt(&self, format: &mut std::fmt::Formatter) -> std::fmt::Result @@ -305,41 +398,41 @@ impl<'a> std::fmt::Debug for ProjectDisplay<'a> { fn fmt(&self, format: &mut std::fmt::Formatter) -> std::fmt::Result { - let mut line_separator = ""; - let mut section_separator = ""; + write!(format, "tff(types, type, object: $tType).")?; + + let predicate_declarations = collect_predicate_declarations_in_project(self.0); + + if !predicate_declarations.is_empty() + { + for predicate_declaration in predicate_declarations + { + write!(format, "\ntff(type, type, {:?}).", predicate_declaration.display_tptp())?; + } + } if let Some(axioms) = self.0.statements.get(&crate::project::StatementKind::Axiom) { + write!(format, "\n")?; + for axiom in axioms { - write!(format, "{}tff({:?}, {:?}).", line_separator, crate::project::StatementKind::Axiom.display_tptp(), axiom.display_tptp())?; - line_separator = "\n"; + write!(format, "\ntff({:?}, {:?}).", crate::project::StatementKind::Axiom.display_tptp(), axiom.display_tptp())?; } - - section_separator = "\n"; } if let Some(lemmas) = self.0.statements.get(&crate::project::StatementKind::Lemma) { - write!(format, "{}", section_separator)?; - for lemma in lemmas { - write!(format, "{}tff({:?}, {:?}).", line_separator, crate::project::StatementKind::Lemma.display_tptp(), lemma.display_tptp())?; - line_separator = "\n"; + write!(format, "\ntff({:?}, {:?}).", crate::project::StatementKind::Lemma.display_tptp(), lemma.display_tptp())?; } - - section_separator = "\n"; } if let Some(conjectures) = self.0.statements.get(&crate::project::StatementKind::Conjecture) { - write!(format, "{}", section_separator)?; - for conjecture in conjectures { - write!(format, "{}tff({:?}, {:?}).", line_separator, crate::project::StatementKind::Conjecture.display_tptp(), conjecture.display_tptp())?; - line_separator = "\n"; + write!(format, "\ntff({:?}, {:?}).", crate::project::StatementKind::Conjecture.display_tptp(), conjecture.display_tptp())?; } }