/*++
Copyright (c) 2006 Microsoft Corporation

Module Name:

    theory_datatype.h

Abstract:

    Theory solver for datatypes used by the SMT core (class smt::theory_datatype).

Author:

    Leonardo de Moura (leonardo) 2008-10-31.

Revision History:

--*/
#pragma once

#include "util/union_find.h"
#include "ast/array_decl_plugin.h"
#include "ast/seq_decl_plugin.h"
#include "ast/datatype_decl_plugin.h"
#include "model/datatype_factory.h"
#include "smt/smt_theory.h"
#include "params/theory_datatype_params.h"

namespace smt {
    class theory_datatype : public theory {
        typedef union_find<theory_datatype>  th_union_find;

        /**
           \brief Record kept for an equivalence class: its watched recognizers,
           its constructor (if any) and the subterm predicates that mention it.
        */
        struct var_data {
            //! Recognizers of this equivalence class that are being watched.
            ptr_vector<enode> m_recognizers;

            //! Constructor of this equivalence class; nullptr when the class has none.
            enode *           m_constructor = nullptr;

            /**
             *  \brief Subterm predicates that involve this equivalence class.
             *
             *  That is, all terms of the shape `a ⊑ b` where `var_data`
             *  represents either `a` or `b`.
             *
             *  Conceptually this is a set rather than a vector. A `ptr_vector`
             *  is used because it is easy to backtrack on and is expected to
             *  stay small enough to outperform a hashmap.
             */
            ptr_vector<enode> m_subterms;

            var_data() = default;
        };

        /**
           \brief Event counters kept by the theory in m_stats.

           Each counter is zeroed by its default member initializer, so a
           freshly constructed object and an object after reset() are in the
           same state. No memset is used, which keeps the struct correct even
           if a non-trivially-copyable member is added later and removes the
           dependency on <cstring>.
        */
        struct stats {
            unsigned   m_occurs_check        = 0;
            unsigned   m_splits              = 0;
            unsigned   m_assert_cnstr        = 0;
            unsigned   m_assert_accessor     = 0;
            unsigned   m_assert_update_field = 0;

            //! Sets every counter back to zero.
            void reset() { *this = stats(); }

            stats() = default;
        };
        
        // Core state of the theory. Members are constructed in declaration
        // order, so keep this order stable when editing.
        datatype_util             m_util;        //!< datatype helper; the is_* predicates of this class forward to it.
        array_util                m_autil;       //!< array helper. NOTE(review): presumably used by get_array_args -- confirm.
        seq_util                  m_sutil;       //!< sequence helper. NOTE(review): presumably used by get_seq_args -- confirm.
        ptr_vector<var_data>      m_var_data;    //!< pointers to var_data records. NOTE(review): presumably indexed by theory_var -- confirm.
        th_union_find             m_find;        //!< union-find structure instantiated over this theory (see th_union_find).
        trail_stack               m_trail_stack; //!< trail stack returned by get_trail_stack().
        datatype_factory *        m_factory;     //!< datatype_factory pointer. NOTE(review): ownership and lifetime are not visible in this header.
        stats                     m_stats;       //!< event counters (see struct stats).

        // Classification of application nodes: each predicate forwards to the
        // datatype_util member function of the same name.
        bool is_constructor(app * f) const {
            return m_util.is_constructor(f);
        }
        bool is_recognizer(app * f) const {
            return m_util.is_recognizer(f);
        }
        bool is_subterm_predicate(app * f) const {
            return m_util.is_subterm_predicate(f);
        }
        bool is_accessor(app * f) const {
            return m_util.is_accessor(f);
        }
        bool is_update_field(app * f) const {
            return m_util.is_update_field(f);
        }

        // The same classification for enodes: each overload looks at the
        // expression attached to the enode and reuses the predicate above.
        bool is_constructor(enode * n) const {
            return is_constructor(n->get_expr());
        }
        bool is_recognizer(enode * n) const {
            return is_recognizer(n->get_expr());
        }
        bool is_subterm_predicate(enode * n) const {
            return is_subterm_predicate(n->get_expr());
        }
        bool is_accessor(enode * n) const {
            return is_accessor(n->get_expr());
        }
        bool is_update_field(enode * n) const {
            return is_update_field(n->get_expr());
        }

        // Axiom-assertion helpers. Only the declarations are in this header;
        // the axioms actually produced are defined in the implementation file.
        void assert_eq_axiom(enode * lhs, expr * rhs, literal antecedent);
        void assert_is_constructor_axiom(enode * n, func_decl * c, literal antecedent);
        void assert_accessor_axioms(enode * n);
        void assert_update_field_axioms(enode * n);
        void assert_subterm_axioms(enode * n);
        // NOTE(review): the two add_* functions presumably record the given
        // enode in var_data::m_recognizers / var_data::m_subterms of theory
        // variable v -- confirm in the implementation.
        void add_recognizer(theory_var v, enode * recognizer);
        void add_subterm_predicate(theory_var v, enode *predicate);
        // Propagation helpers for subterm predicates and recognizers
        // (declarations only; behaviour is defined out of line).
        void propagate_subterm(enode * n, bool is_true);
        void propagate_is_subterm(enode * n);
        void propagate_not_is_subterm(enode *n);
        void split_leaf_root(smt::enode *arg2);
        void propagate_subterm_with_constructor(theory_var v);
        void propagate_recognizer(theory_var v, enode *r);
        // NOTE(review): judging by the parameter names, c is a constructor
        // enode and r a recognizer enode -- confirm in the implementation.
        void sign_recognizer_conflict(enode * c, enode * r);

        //! Tag stored in a stack_entry: ENTER or EXIT.
        enum stack_op { ENTER, EXIT };
        //! Map from an enode to an enode.
        using parent_tbl  = obj_map<enode, enode*>;
        //! A stack_op paired with the enode it applies to.
        using stack_entry = std::pair<stack_op, enode*>;

        // Working state. The inline comments on m_parent and m_stack tie them
        // to occurs_check.
        // NOTE(review): m_to_unmark / m_to_unmark2 presumably hold the enodes
        // whose mark / mark2 flag has to be cleared by clear_mark() -- confirm.
        ptr_vector<enode>     m_to_unmark;
        ptr_vector<enode>     m_to_unmark2;
        enode_pair_vector     m_used_eqs; // conflict, if any
        parent_tbl            m_parent; // parent explanation for occurs_check
        svector<stack_entry>  m_stack; // stack for DFS for occurs_check
        literal_vector        m_lits; // literal buffer; its use is not visible in this header

        void clear_mark();

        // "On stack" state: queried through the mark flag of the root of n.
        void oc_mark_on_stack(enode * n);
        bool oc_on_stack(enode * n) const {
            auto * root = n->get_root();
            return root->is_marked();
        }

        // "Cycle free" state: queried through the mark2 flag of the root of n.
        void oc_mark_cycle_free(enode * n);
        bool oc_cycle_free(enode * n) const {
            auto * root = n->get_root();
            return root->is_marked2();
        }

        void oc_push_stack(enode * n);

        // Enode buffers, plus two helpers that return a reference to an
        // enode vector (defined out of line).
        ptr_vector<enode> m_args;
        ptr_vector<enode> m_todo;
        ptr_vector<enode> const& get_array_args(enode* n);
        ptr_vector<enode> const& get_seq_args(enode* n, enode*& sibling);

        // Class for managing the state of final_check.
        // It stores a pointer to the owning theory_datatype. The constructor
        // and destructor are defined out of line.
        // NOTE(review): presumably used as a scope guard inside final_check_eh
        // (set up in the constructor, cleaned up in the destructor) -- confirm.
        class final_check_st {
            theory_datatype * th; // theory this object operates on
        public:
            final_check_st(theory_datatype * th);
            ~final_check_st();
        };

        // Occurs-check routines (declarations only). m_parent and m_stack are
        // documented as the parent-explanation table and DFS stack of
        // occurs_check.
        enode * oc_get_cstor(enode * n);
        bool occurs_check(enode * n);
        bool occurs_check_enter(enode * n);
        void occurs_check_explain(enode * top, enode * root);
        void explain_is_child(enode* parent, enode* child);

        // NOTE(review): presumably performs a case split for theory variable v
        // (cf. stats::m_splits) -- confirm in the implementation.
        void mk_split(theory_var v);

        // Writes information about theory variable v to out.
        void display_var(std::ostream & out, theory_var v) const;
        // Returns a vector of enodes by value; how it is computed from arg is
        // defined out of line.
        ptr_vector<enode> list_subterms(enode* arg);

    protected:
        // Overrides of virtual functions inherited from theory (each one is
        // marked override). All are defined out of line except restart_eh.
        theory_var mk_var(enode * n) override;
        bool internalize_atom(app * atom, bool gate_ctx) override;
        bool internalize_term(app * term) override;
        void apply_sort_cnstr(enode * n, sort * s) override;
        void new_eq_eh(theory_var v1, theory_var v2) override;
        bool use_diseqs() const override;
        void new_diseq_eh(theory_var v1, theory_var v2) override;
        void assign_eh(bool_var v, bool is_true) override;
        void relevant_eh(app * n) override;
        void push_scope_eh() override;
        void pop_scope_eh(unsigned num_scopes) override;
        final_check_status final_check_eh(unsigned) override;
        void reset_eh() override;
        // On restart, reset the datatype_util helper.
        void restart_eh() override { m_util.reset(); }
        bool is_shared(theory_var v) const override;
        // Accessor for the datatype theory parameters (not an override).
        theory_datatype_params const& params() const;
        struct iterator_factory;

        /**
           \brief Iterator produced by iterator_factory::begin() and end().

           It offers the three operations a range-based for loop needs
           (operator!=, operator* and operator++). Advancing is delegated to
           next(), which is defined out of line. Two iterators compare
           different exactly when their current enodes differ, so the iterator
           built from a null start node acts as the end marker.
        */
        struct subterm_iterator {
            iterator_factory & f;                   //!< factory this iterator was created by
            ptr_vector<enode>  m_todo;              //!< enode work list; receives the start node
            enode *            m_current = nullptr; //!< enode returned by operator*

            subterm_iterator(iterator_factory & f, enode * start) : f(f) {
                // A null start node leaves the iterator in its end state.
                if (!start)
                    return;
                m_todo.push_back(start);
                next();
            }

            void next();

            enode * operator*() const { return m_current; }

            void operator++() { next(); }

            bool operator!=(subterm_iterator const & other) const {
                return !(m_current == other.m_current);
            }
        };

        struct iterator_factory {
            theory_datatype &th;
            ptr_vector<enode> m_marked;
            enode *start;
            iterator_factory(theory_datatype &th, enode* start) : th(th), start(start) {}
            subterm_iterator begin() {
                return subterm_iterator(*this, start);
            }
            subterm_iterator end() {
                return subterm_iterator(*this, nullptr);
            }
            void reset() {
                for (enode* n : m_marked) 
                    n->unset_mark();                
                m_marked.reset();
            }
        };

        //! Creates an iterator_factory bound to this theory that starts at arg.
        iterator_factory iterate_subterms(enode *arg) {
            return iterator_factory{ *this, arg };
        }
    public:
        // Construction / destruction (defined out of line).
        theory_datatype(context& ctx);
        ~theory_datatype() override;
        // Overrides of the public theory interface.
        theory * mk_fresh(context * new_ctx) override;
        void display(std::ostream & out) const override;
        void collect_statistics(::statistics & st) const override;
        void init_model(model_generator & m) override;
        model_value_proc * mk_value(enode * n, model_generator & m) override;
        // Gives access to the theory's own trail stack.
        trail_stack & get_trail_stack() { return m_trail_stack; }
        // NOTE(review): get_trail_stack, merge_eh, after_merge_eh and
        // unmerge_eh look like the callbacks union_find<theory_datatype>
        // expects from its context type -- confirm against util/union_find.h.
        virtual void merge_eh(theory_var v1, theory_var v2, theory_var, theory_var);
        // Has an empty body: nothing is done after a merge.
        static void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) {}
        void unmerge_eh(theory_var v1, theory_var v2);
        // Name of this theory.
        char const * get_name() const override { return "datatype"; }
        bool include_func_interp(func_decl* f) override;

    };
};


