# This file is part of Patsy
# Copyright (C) 2011-2012 Nathaniel Smith <njs@pobox.com>
# See file LICENSE.txt for license information.

# This file defines the ModelDesc class, which describes a model at a high
# level, as a list of interactions of factors. It also has the code to convert
# a formula parse tree (from patsy.parse_formula) into a ModelDesc.

from patsy import PatsyError
from patsy.parse_formula import ParseNode, Token, parse_formula
from patsy.eval import EvalEnvironment, EvalFactor
from patsy.util import uniqueify_list
from patsy.util import repr_pretty_delegate, repr_pretty_impl
from patsy.util import no_pickling, assert_no_pickling

# Public names of this module; these are made available in the patsy.*
# namespace.
__all__ = ["Term", "ModelDesc", "INTERCEPT"]


# One might think it would make more sense for 'factors' to be a set, rather
# than a tuple-with-guaranteed-unique-entries-that-compares-like-a-set. The
# reason we do it this way is that it preserves the order that the user typed
# and is expecting, which then ends up producing nicer names in our final
# output, nicer column ordering, etc. (A similar comment applies to the
# ordering of terms in ModelDesc objects as a whole.)
class Term(object):
    """An interaction among a collection of factor objects.

    Terms are one of the basic building blocks of a formula: a term
    corresponds to an expression such as ``"a:b:c"`` in a formula string.
    See :ref:`formulas` and :ref:`expert-model-specification` for details.

    Terms are hashable and compare by value: two terms are equal when they
    contain the same set of factors, regardless of the order given.

    Attributes:

    .. attribute:: factors

       A tuple of factor objects, with duplicates dropped and the original
       order otherwise preserved.
    """

    def __init__(self, factors):
        self.factors = tuple(uniqueify_list(factors))

    def __eq__(self, other):
        if not isinstance(other, Term):
            return False
        return frozenset(self.factors) == frozenset(other.factors)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Hash on the same order-insensitive key that __eq__ compares.
        return hash((Term, frozenset(self.factors)))

    __repr__ = repr_pretty_delegate

    def _repr_pretty_(self, p, cycle):
        assert not cycle
        factor_list = list(self.factors)
        repr_pretty_impl(p, self, [factor_list])

    def name(self):
        """Return a human-readable name for this term.

        Factor names are joined with ``":"``; the factor-less term is
        called ``"Intercept"``.
        """
        if not self.factors:
            return "Intercept"
        return ":".join([factor.name() for factor in self.factors])

    __getstate__ = no_pickling


# The intercept is represented as the term with no factors; Term.name()
# renders it as "Intercept", and ModelDesc.describe() renders it as "1".
INTERCEPT = Term([])


class _MockFactor(object):
    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


def test_Term():
    # Duplicate factors are dropped; first-seen order is kept.
    assert Term([1, 2, 1]).factors == (1, 2)

    # Equality and hashing ignore factor order.
    forward = Term([1, 2])
    backward = Term([2, 1])
    assert forward == backward
    assert hash(forward) == hash(backward)

    # Names follow the order the factors were given in.
    fa = _MockFactor("a")
    fb = _MockFactor("b")
    assert Term([fa, fb]).name() == "a:b"
    assert Term([fb, fa]).name() == "b:a"
    assert Term([]).name() == "Intercept"

    assert_no_pickling(Term([]))


class ModelDesc(object):
    """Container for the left- and right-hand termlists of a formula.

    A :class:`ModelDesc` carries exactly the same information as a formula
    string, but as plain Python objects. You can build one by hand and pass
    it to functions such as :func:`dmatrix` or :func:`incr_dbuilder`
    wherever a formula string is accepted, without any string manipulation.
    For details see :ref:`expert-model-specification`.

    Attributes:

    .. attribute:: lhs_termlist
                   rhs_termlist

       Two termlists representing the left- and right-hand sides of a
       formula, suitable for passing to :func:`design_matrix_builders`.
    """

    def __init__(self, lhs_termlist, rhs_termlist):
        self.lhs_termlist = uniqueify_list(lhs_termlist)
        self.rhs_termlist = uniqueify_list(rhs_termlist)

    __repr__ = repr_pretty_delegate

    def _repr_pretty_(self, p, cycle):
        assert not cycle
        named_fields = [
            ("lhs_termlist", self.lhs_termlist),
            ("rhs_termlist", self.rhs_termlist),
        ]
        return repr_pretty_impl(p, self, [], named_fields)

    def describe(self):
        """Return a human-readable, pseudo-formula rendering of this
        :class:`ModelDesc`.

        The intercept is written as ``1``. On the right-hand side the
        intercept is implicit, so a RHS without it gets a leading ``0``.

        .. warning:: The returned string is a best-effort description meant
           for humans, and is not guaranteed to parse back as a formula.
           For a ModelDesc that was itself created by parsing a formula it
           should round-trip in practice, but do not rely on this.
        """

        def render(term):
            return "1" if term == INTERCEPT else term.name()

        lhs_text = " + ".join([render(term) for term in self.lhs_termlist])
        head = lhs_text + " ~ " if lhs_text else "~ "
        if self.rhs_termlist == [INTERCEPT]:
            return head + render(INTERCEPT)
        rhs_parts = [] if INTERCEPT in self.rhs_termlist else ["0"]
        rhs_parts.extend(
            render(term) for term in self.rhs_termlist if term != INTERCEPT
        )
        return head + " + ".join(rhs_parts)

    @classmethod
    def from_formula(cls, tree_or_string):
        """Construct a :class:`ModelDesc` from a formula string.

        :arg tree_or_string: A formula string. (Or an unevaluated formula
          parse tree, but the API for generating those isn't public yet. Shh,
          it can be our secret.)
        :returns: A new :class:`ModelDesc`.
        """
        tree = (
            tree_or_string
            if isinstance(tree_or_string, ParseNode)
            else parse_formula(tree_or_string)
        )
        value = Evaluator().eval(tree, require_evalexpr=False)
        assert isinstance(value, cls)
        return value

    __getstate__ = no_pickling


def test_ModelDesc():
    fa = _MockFactor("a")
    fb = _MockFactor("b")
    desc = ModelDesc([INTERCEPT, Term([fa])], [Term([fa]), Term([fa, fb])])
    assert desc.lhs_termlist == [INTERCEPT, Term([fa])]
    assert desc.rhs_termlist == [Term([fa]), Term([fa, fb])]
    print(desc.describe())
    assert desc.describe() == "1 + a ~ 0 + a + a:b"

    assert_no_pickling(desc)

    # (lhs, rhs, expected describe() output)
    cases = [
        ([], [], "~ 0"),
        ([INTERCEPT], [], "1 ~ 0"),
        ([INTERCEPT], [INTERCEPT], "1 ~ 1"),
        ([INTERCEPT], [INTERCEPT, Term([fb])], "1 ~ b"),
    ]
    for lhs, rhs, expected in cases:
        assert ModelDesc(lhs, rhs).describe() == expected


def test_ModelDesc_from_formula():
    # Both a formula string and a pre-parsed tree must give the same result.
    for formula in ("y ~ x", parse_formula("y ~ x")):
        desc = ModelDesc.from_formula(formula)
        assert desc.lhs_termlist == [Term([EvalFactor("y")])]
        assert desc.rhs_termlist == [INTERCEPT, Term([EvalFactor("x")])]


class IntermediateExpr(object):
    """Hold an intermediate result while a formula parse tree is evaluated.

    Attributes:

    - ``intercept``: True when an explicit intercept was requested
      (e.g. ``1`` or ``- 0``).
    - ``intercept_origin``: where that intercept came from, used for error
      reporting; it must be truthy whenever ``intercept`` is true.
    - ``intercept_removed``: True when the intercept was explicitly removed
      (e.g. ``0`` or ``- 1``). It is never true together with ``intercept``.
    - ``terms``: tuple of terms, de-duplicated with first-seen order kept.
    """

    def __init__(self, intercept, intercept_origin, intercept_removed, terms):
        self.intercept = intercept
        self.intercept_origin = intercept_origin
        self.intercept_removed = intercept_removed
        self.terms = tuple(uniqueify_list(terms))
        if self.intercept:
            assert self.intercept_origin
        assert not (self.intercept and self.intercept_removed)

    __repr__ = repr_pretty_delegate

    # This hook must be named _repr_pretty_ (the IPython pretty-printing
    # protocol that repr_pretty_delegate dispatches to). It was previously
    # misspelled _pretty_repr_, which made repr() raise AttributeError.
    def _repr_pretty_(self, p, cycle):
        assert not cycle
        return repr_pretty_impl(
            p,
            self,
            [self.intercept, self.intercept_origin, self.intercept_removed, self.terms],
        )

    __getstate__ = no_pickling


def _maybe_add_intercept(doit, terms):
    if doit:
        return (INTERCEPT,) + terms
    else:
        return terms


def _eval_any_tilde(evaluator, tree):
    """Evaluate ``lhs ~ rhs`` (or ``~ rhs``) into a :class:`ModelDesc`.

    The LHS gets an intercept only if one was written explicitly. The RHS
    gets an implicit intercept unless it was explicitly removed.
    """
    sides = [evaluator.eval(arg) for arg in tree.args]
    if len(sides) == 1:
        # "~ foo" is treated as if it were "0 ~ foo".
        sides = [IntermediateExpr(False, None, True, [])] + sides
    assert len(sides) == 2
    lhs, rhs = sides
    return ModelDesc(
        _maybe_add_intercept(lhs.intercept, lhs.terms),
        _maybe_add_intercept(not rhs.intercept_removed, rhs.terms),
    )


def _eval_binary_plus(evaluator, tree):
    """Evaluate ``left + right``: concatenate terms and track the intercept.

    ``x + 0`` removes the intercept. Otherwise an explicit intercept on the
    right (``x + 1``) wins. If the right side has none, the left side's
    intercept state carries through unchanged.
    """
    left = evaluator.eval(tree.args[0])
    if tree.args[1].type == "ZERO":
        return IntermediateExpr(False, None, True, left.terms)
    right = evaluator.eval(tree.args[1])
    combined = left.terms + right.terms
    if not right.intercept:
        return IntermediateExpr(
            left.intercept, left.intercept_origin, left.intercept_removed, combined
        )
    return IntermediateExpr(True, right.intercept_origin, False, combined)


def _eval_binary_minus(evaluator, tree):
    """Evaluate ``left - right``.

    - ``x - 0`` adds an explicit intercept.
    - ``x - 1`` removes the intercept.
    - ``x - y`` drops y's terms from x's terms. If y itself carries an
      explicit intercept (e.g. ``x - (y + 1)``), the intercept is removed as
      well. Otherwise x's intercept state is kept.
    """
    left_expr = evaluator.eval(tree.args[0])
    right_node = tree.args[1]
    if right_node.type == "ZERO":
        # Store the intercept's origin as an Origin (right_node.origin), as
        # _eval_one and _eval_unary_minus do, rather than as the raw parse
        # node.
        return IntermediateExpr(True, right_node.origin, False, left_expr.terms)
    if right_node.type == "ONE":
        return IntermediateExpr(False, None, True, left_expr.terms)
    right_expr = evaluator.eval(right_node)
    kept = [term for term in left_expr.terms if term not in right_expr.terms]
    if right_expr.intercept:
        return IntermediateExpr(False, None, True, kept)
    return IntermediateExpr(
        left_expr.intercept,
        left_expr.intercept_origin,
        left_expr.intercept_removed,
        kept,
    )


def _check_interactable(expr):
    if expr.intercept:
        raise PatsyError(
            "intercept term cannot interact with " "anything else",
            expr.intercept_origin,
        )


def _interaction(left_expr, right_expr):
    """Interact every left term with every right term (``left:right``).

    Both sides are first checked with _check_interactable, so an explicit
    intercept on either side raises PatsyError.
    """
    _check_interactable(left_expr)
    _check_interactable(right_expr)
    combined = [
        Term(l_term.factors + r_term.factors)
        for l_term in left_expr.terms
        for r_term in right_expr.terms
    ]
    return IntermediateExpr(False, None, False, combined)


def _eval_binary_prod(evaluator, tree):
    """Evaluate ``a * b`` as ``a + b + a:b``."""
    exprs = [evaluator.eval(arg) for arg in tree.args]
    main_effects = exprs[0].terms + exprs[1].terms
    interactions = _interaction(*exprs).terms
    return IntermediateExpr(False, None, False, main_effects + interactions)


# Division (nesting) is right-ward distributive:
#   a / (b + c) -> a/b + a/c -> a + a:b + a:c
# But left-ward, in S/R it has a quirky behavior:
#   (a + b)/c -> a + b + a:b:c
# This is because it's meaningless for a factor to be "nested" under two
# different factors. (This is documented in Chambers and Hastie (page 30) as a
# "Slightly more subtle..." rule, with no further elaboration. Hopefully we
# will do better.)
def _eval_binary_div(evaluator, tree):
    """Evaluate nesting ``left / right``.

    The result keeps every left term, then adds the interaction of ALL
    left-hand factors, combined into one term, with each right term.
    """
    left_expr = evaluator.eval(tree.args[0])
    right_expr = evaluator.eval(tree.args[1])
    _check_interactable(left_expr)
    # One big combined term for everything on the left...
    nested_factors = [factor for term in left_expr.terms for factor in term.factors]
    left_combined = IntermediateExpr(False, None, False, [Term(nested_factors)])
    # ...interacted with everything on the right.
    nested_terms = list(_interaction(left_combined, right_expr).terms)
    return IntermediateExpr(False, None, False, list(left_expr.terms) + nested_terms)


def _eval_binary_interact(evaluator, tree):
    """Evaluate ``a:b`` as the pure interaction of both sides."""
    left_expr, right_expr = [evaluator.eval(arg) for arg in tree.args]
    return _interaction(left_expr, right_expr)


def _eval_binary_power(evaluator, tree):
    """Evaluate ``expr ** n``: all interactions of expr's terms up to order n.

    Raises PatsyError unless n is a positive integer literal.
    """
    left_expr = evaluator.eval(tree.args[0])
    _check_interactable(left_expr)
    exponent_node = tree.args[1]
    power = -1
    if exponent_node.type in ("ONE", "NUMBER"):
        exponent_text = exponent_node.token.extra
        try:
            power = int(exponent_text)
        except ValueError:
            pass
    if power < 1:
        raise PatsyError("'**' requires a positive integer", exponent_node)
    # Raising to a power beyond the number of terms only yields duplicates,
    # e.g. (a + b)**100 is the same as (a + b)**2.
    power = min(len(left_expr.terms), power)
    all_terms = left_expr.terms
    product = left_expr
    for _ in range(power - 1):
        product = _interaction(left_expr, product)
        all_terms = all_terms + product.terms
    return IntermediateExpr(False, None, False, all_terms)


def _eval_unary_plus(evaluator, tree):
    return evaluator.eval(tree.args[0])


def _eval_unary_minus(evaluator, tree):
    """Evaluate ``-0`` (add intercept) or ``-1`` (remove intercept).

    Any other operand raises PatsyError.
    """
    operand_type = tree.args[0].type
    if operand_type == "ZERO":
        return IntermediateExpr(True, tree.origin, False, [])
    if operand_type == "ONE":
        return IntermediateExpr(False, None, True, [])
    raise PatsyError("Unary minus can only be applied to 1 or 0", tree)


def _eval_zero(evaluator, tree):
    """A bare ``0`` explicitly removes the intercept."""
    return IntermediateExpr(
        intercept=False, intercept_origin=None, intercept_removed=True, terms=[]
    )


def _eval_one(evaluator, tree):
    """A bare ``1`` explicitly requests the intercept."""
    return IntermediateExpr(
        intercept=True,
        intercept_origin=tree.origin,
        intercept_removed=False,
        terms=[],
    )


def _eval_number(evaluator, tree):
    """Reject numeric literals other than 0/1 outside of ``**``."""
    message = "numbers besides '0' and '1' are only allowed with **"
    raise PatsyError(message, tree)


def _eval_python_expr(evaluator, tree):
    """Wrap an arbitrary Python expression as a single-factor term."""
    term = Term([EvalFactor(tree.token.extra, origin=tree.origin)])
    return IntermediateExpr(False, None, False, [term])


class Evaluator(object):
    """Evaluate formula parse trees.

    Handlers are dispatched on ``(node type, number of child nodes)`` and
    called as ``handler(evaluator, tree)``. The top-level ``~`` node yields
    a :class:`ModelDesc`; every other node yields an IntermediateExpr.
    """

    def __init__(self):
        self._evaluators = {}
        builtin_ops = [
            ("~", 2, _eval_any_tilde),
            ("~", 1, _eval_any_tilde),
            ("+", 2, _eval_binary_plus),
            ("-", 2, _eval_binary_minus),
            ("*", 2, _eval_binary_prod),
            ("/", 2, _eval_binary_div),
            (":", 2, _eval_binary_interact),
            ("**", 2, _eval_binary_power),
            ("+", 1, _eval_unary_plus),
            ("-", 1, _eval_unary_minus),
            ("ZERO", 0, _eval_zero),
            ("ONE", 0, _eval_one),
            ("NUMBER", 0, _eval_number),
            ("PYTHON_EXPR", 0, _eval_python_expr),
        ]
        for op, arity, handler in builtin_ops:
            self.add_op(op, arity, handler)

        # Not used by Patsy -- provided for the convenience of eventual
        # user-defined operators.
        self.stash = {}

    # This should not be considered a public API yet (to use for actually
    # adding new operator semantics) because I wrote in some of the relevant
    # code sort of speculatively, but it isn't actually tested.
    def add_op(self, op, arity, evaluator):
        self._evaluators[op, arity] = evaluator

    def eval(self, tree, require_evalexpr=True):
        """Evaluate *tree* by dispatching to the registered handler.

        When *require_evalexpr* is true (the default, used for subtrees),
        the handler must return an IntermediateExpr; otherwise PatsyError is
        raised.
        """
        assert isinstance(tree, ParseNode)
        key = (tree.type, len(tree.args))
        if key not in self._evaluators:
            raise PatsyError(
                "I don't know how to evaluate this '%s' operator" % (tree.type,),
                tree.token,
            )
        result = self._evaluators[key](self, tree)
        if not require_evalexpr or isinstance(result, IntermediateExpr):
            return result
        if isinstance(result, ModelDesc):
            # A nested "~" produced a ModelDesc where a subexpression was
            # expected.
            raise PatsyError(
                "~ can only be used once, and only at the top level", tree
            )
        raise PatsyError(
            "custom operator returned an object that I don't know how to handle",
            tree,
        )


#############

# Formula evaluation test cases, consumed by _do_eval_formula_tests().
# Each key is a formula string. Each value is either
#   (rhs_intercept, rhs_terms)                             -- LHS is empty
#   (lhs_intercept, lhs_terms, rhs_intercept, rhs_terms)
# where each *_intercept flag says whether that termlist should begin with
# INTERCEPT. Each entry of *_terms is either a factor code string (a
# single-factor term) or a tuple of factor code strings (an interaction
# term), listed in the order the terms are expected to appear.
_eval_tests = {
    "": (True, []),
    " ": (True, []),
    " \n ": (True, []),
    "a": (True, ["a"]),
    "1": (True, []),
    "0": (False, []),
    "- 1": (False, []),
    "- 0": (True, []),
    "+ 1": (True, []),
    "+ 0": (False, []),
    "0 + 1": (True, []),
    "1 + 0": (False, []),
    "1 - 0": (True, []),
    "0 - 1": (False, []),
    "1 + a": (True, ["a"]),
    "0 + a": (False, ["a"]),
    "a - 1": (False, ["a"]),
    "a - 0": (True, ["a"]),
    "1 - a": (True, []),
    "a + b": (True, ["a", "b"]),
    "(a + b)": (True, ["a", "b"]),
    "a + ((((b))))": (True, ["a", "b"]),
    "a + ((((+b))))": (True, ["a", "b"]),
    "a + ((((b - a))))": (True, ["a", "b"]),
    "a + a + a": (True, ["a"]),
    "a + (b - a)": (True, ["a", "b"]),
    "a + np.log(a, base=10)": (True, ["a", "np.log(a, base=10)"]),
    # Note different spacing:
    "a + np.log(a, base=10) - np . log(a , base = 10)": (True, ["a"]),
    "a + (I(b) + c)": (True, ["a", "I(b)", "c"]),
    "a + I(b + c)": (True, ["a", "I(b + c)"]),
    "a:b": (True, [("a", "b")]),
    "a:b:a": (True, [("a", "b")]),
    "a:(b + c)": (True, [("a", "b"), ("a", "c")]),
    "(a + b):c": (True, [("a", "c"), ("b", "c")]),
    "a:(b - c)": (True, [("a", "b")]),
    "c + a:c + a:(b - c)": (True, ["c", ("a", "c"), ("a", "b")]),
    "(a - b):c": (True, [("a", "c")]),
    "b + b:c + (a - b):c": (True, ["b", ("b", "c"), ("a", "c")]),
    "a:b - a:b": (True, []),
    "a:b - b:a": (True, []),
    "1 - (a + b)": (True, []),
    "a + b - (a + b)": (True, []),
    "a * b": (True, ["a", "b", ("a", "b")]),
    "a * b * a": (True, ["a", "b", ("a", "b")]),
    "a * (b + c)": (True, ["a", "b", "c", ("a", "b"), ("a", "c")]),
    "(a + b) * c": (True, ["a", "b", "c", ("a", "c"), ("b", "c")]),
    "a * (b - c)": (True, ["a", "b", ("a", "b")]),
    "c + a:c + a * (b - c)": (True, ["c", ("a", "c"), "a", "b", ("a", "b")]),
    "(a - b) * c": (True, ["a", "c", ("a", "c")]),
    "b + b:c + (a - b) * c": (True, ["b", ("b", "c"), "a", "c", ("a", "c")]),
    "a/b": (True, ["a", ("a", "b")]),
    "(a + b)/c": (True, ["a", "b", ("a", "b", "c")]),
    "b + b:c + (a - b)/c": (True, ["b", ("b", "c"), "a", ("a", "c")]),
    "a/(b + c)": (True, ["a", ("a", "b"), ("a", "c")]),
    "a ** 2": (True, ["a"]),
    "(a + b + c + d) ** 2": (
        True,
        [
            "a",
            "b",
            "c",
            "d",
            ("a", "b"),
            ("a", "c"),
            ("a", "d"),
            ("b", "c"),
            ("b", "d"),
            ("c", "d"),
        ],
    ),
    "(a + b + c + d) ** 3": (
        True,
        [
            "a",
            "b",
            "c",
            "d",
            ("a", "b"),
            ("a", "c"),
            ("a", "d"),
            ("b", "c"),
            ("b", "d"),
            ("c", "d"),
            ("a", "b", "c"),
            ("a", "b", "d"),
            ("a", "c", "d"),
            ("b", "c", "d"),
        ],
    ),
    "a + +a": (True, ["a"]),
    "~ a + b": (True, ["a", "b"]),
    "~ a*b": (True, ["a", "b", ("a", "b")]),
    "~ a*b + 0": (False, ["a", "b", ("a", "b")]),
    "~ -1": (False, []),
    "0 ~ a + b": (True, ["a", "b"]),
    "1 ~ a + b": (True, [], True, ["a", "b"]),
    "y ~ a + b": (False, ["y"], True, ["a", "b"]),
    "0 + y ~ a + b": (False, ["y"], True, ["a", "b"]),
    "0 + y * z ~ a + b": (False, ["y", "z", ("y", "z")], True, ["a", "b"]),
    "-1 ~ 1": (False, [], True, []),
    "1 + y ~ a + b": (True, ["y"], True, ["a", "b"]),
    # Check precedence:
    "a + b * c": (True, ["a", "b", "c", ("b", "c")]),
    "a * b + c": (True, ["a", "b", ("a", "b"), "c"]),
    "a * b - a": (True, ["b", ("a", "b")]),
    "a + b / c": (True, ["a", "b", ("b", "c")]),
    "a / b + c": (True, ["a", ("a", "b"), "c"]),
    "a*b:c": (True, ["a", ("b", "c"), ("a", "b", "c")]),
    "a:b*c": (True, [("a", "b"), "c", ("a", "b", "c")]),
    # Intercept handling:
    "~ 1 + 1 + 0 + 1": (True, []),
    "~ 0 + 1 + 0": (False, []),
    "~ 0 - 1 - 1 + 0 + 1": (True, []),
    "~ 1 - 1": (False, []),
    "~ 0 + a + 1": (True, ["a"]),
    "~ 1 + (a + 0)": (True, ["a"]),  # This is correct, but perhaps surprising!
    "~ 0 + (a + 1)": (True, ["a"]),  # Also correct!
    "~ 1 - (a + 1)": (False, []),
}

# Formulas that must fail to evaluate, checked by
# test_eval_formula_error_reporting() via _parsing_error_test.
# In each string, "<" and ">" mark off the span where the error should be
# reported. Presumably _parsing_error_test strips these markers before
# parsing -- see patsy.parse_formula to confirm.
_eval_error_tests = [
    "a <+>",
    "a + <(>",
    "b + <(-a)>",
    "a:<1>",
    "(a + <1>)*b",
    "a + <2>",
    "a + <1.0>",
    # eh, catching this is a hassle, we'll just leave the user some rope if
    # they really want it:
    # "a + <0x1>",
    "a ** <b>",
    "a ** <(1 + 1)>",
    "a ** <1.5>",
    "a + b <# asdf>",
    "<)>",
    "a + <)>",
    "<*> a",
    "a + <*>",
    "a + <foo[bar>",
    "a + <foo{bar>",
    "a + <foo(bar>",
    "a + <[bar>",
    "a + <{bar>",
    "a + <{bar[]>",
    "a + foo<]>bar",
    "a + foo[]<]>bar",
    "a + foo{}<}>bar",
    "a + foo<)>bar",
    "a + b<)>",
    "(a) <.>",
    "<(>a + b",
    "<y ~ a> ~ b",
    "y ~ <(a ~ b)>",
    "<~ a> ~ b",
    "~ <(a ~ b)>",
    "1 + <-(a + b)>",
    "<- a>",
    "a + <-a**2>",
]


def _assert_terms_match(terms, expected_intercept, expecteds):  # pragma: no cover
    if expected_intercept:
        expecteds = [()] + expecteds
    assert len(terms) == len(expecteds)
    for term, expected in zip(terms, expecteds):
        if isinstance(term, Term):
            if isinstance(expected, str):
                expected = (expected,)
            assert term.factors == tuple([EvalFactor(s) for s in expected])
        else:
            assert term == expected


def _do_eval_formula_tests(tests):  # pragma: no cover
    """Evaluate every formula in *tests* and check it against its spec."""
    for formula, expected in tests.items():
        if len(expected) == 2:
            # RHS-only spec: the LHS is empty, with no intercept.
            expected = (False, []) + expected
        desc = ModelDesc.from_formula(formula)
        for item in (repr(formula), expected, desc):
            print(item)
        lhs_intercept, lhs_terms, rhs_intercept, rhs_terms = expected
        _assert_terms_match(desc.lhs_termlist, lhs_intercept, lhs_terms)
        _assert_terms_match(desc.rhs_termlist, rhs_intercept, rhs_terms)


def test_eval_formula():
    """Check every formula in _eval_tests evaluates to its expected terms."""
    _do_eval_formula_tests(tests=_eval_tests)


def test_eval_formula_error_reporting():
    """Check each bad formula raises PatsyError at the marked position."""
    from patsy.parse_formula import _parsing_error_test

    _parsing_error_test(ModelDesc.from_formula, _eval_error_tests)


def test_formula_factor_origin():
    from patsy.origin import Origin

    formula = "a + b"
    rhs = ModelDesc.from_formula(formula).rhs_termlist
    # rhs[0] is the implicit intercept; the real factors start at index 1.
    assert rhs[1].factors[0].origin == Origin(formula, 0, 1)
    assert rhs[2].factors[0].origin == Origin(formula, 4, 5)
