
from collections import defaultdict
from semmle.python import ast
from semmle.python.passes.ast_pass import iter_fields
from operator import itemgetter
# Upper bound on the number of split points that choose_split_points() returns.
MAX_SPLITS = 2

def do_split(ast_root, graph):
    """Label `ast_root`, pick split points in `graph` and apply them.

    The AST labels are projected onto the flow graph nodes, split points are
    chosen from the resulting flow-graph labels, and the chosen points are
    passed to ``graph.split``.  Nothing is returned; `graph` is modified
    through its own ``split`` method.
    """
    cfg_labels = label_cfg(graph, label_ast(ast_root))
    graph.split(choose_split_points(graph, cfg_labels))

class ScopedAstLabellingVisitor(object):
    """Base visitor that walks an AST without entering nested scopes.

    Dispatch is by class name: an object of class ``Foo`` is handed to
    ``visit_Foo`` when the visitor defines it and to ``generic_visit``
    otherwise.  ``Class`` and ``Function`` nodes are not traversed.
    """

    def __init__(self, labels):
        # Caller-supplied container that subclasses fill in.
        self.labels = labels
        # Counter starting at zero; subclasses bump it as they visit nodes.
        self.priority = 0

    def visit(self, node):
        """Pass `node` to its class-specific handler, or to generic_visit."""
        handler = getattr(
            self, 'visit_' + node.__class__.__name__, self.generic_visit)
        handler(node)

    def generic_visit(self, node):
        """Visit every field value of an AST node; ignore anything else."""
        if not isinstance(node, ast.AstBase):
            return
        for (_, _, child) in iter_fields(node):
            self.visit(child)

    def visit_Class(self, node):
        """Do nothing, so the contents of the node are never visited."""
        return None
    # Function nodes are skipped in exactly the same way.
    visit_Function = visit_Class

    def visit_list(self, the_list):
        """Dispatch each element of `the_list` to its handler, in order."""
        for item in the_list:
            handler = getattr(
                self, 'visit_' + item.__class__.__name__, self.generic_visit)
            handler(item)

    @staticmethod
    def get_variable(expr):
        """Return ``expr.variable``, or None when there is no such attribute."""
        return getattr(expr, 'variable', None)

    @staticmethod
    def is_const(expr):
        """Tell whether `expr` is a constant, looking through unary operators.

        A Name counts when its variable id is 'None', 'True' or 'False';
        otherwise the innermost operand must be a Num or a Str.
        """
        node = expr
        while not isinstance(node, ast.Name):
            if not isinstance(node, ast.UnaryOp):
                return isinstance(node, (ast.Num, ast.Str))
            # Peel one unary operator and examine its operand.
            node = node.operand
        return node.variable.id in ('None', 'True', 'False')

class AstLabeller(ScopedAstLabellingVisitor):
    """Attaches ``(variable, value, priority)`` labels to AST nodes.

    ``labels`` must map a node to a list that can be appended to
    (``label_ast`` supplies a ``defaultdict(list)``).  ``priority`` is
    incremented as nodes are visited, so a larger priority means the label
    was produced later in the traversal.
    """

    def __init__(self, *args):
        ScopedAstLabellingVisitor.__init__(self, *args)
        # Non-zero only while the test expression of an `if` is being visited.
        self.in_test = 0

    def _label_for_compare(self, cmp):
        """Return a label for a single-operator variable/constant comparison.

        The variable may be on either side.  Returns None, leaving
        ``priority`` untouched, when the comparison has more than one
        operator, involves no variable, or the other side is not constant.
        """
        if len(cmp.ops) != 1:
            return None
        right = cmp.comparators[0]
        var = self.get_variable(cmp.left)
        if var is not None:
            const = right
        else:
            # No variable on the left: try the operands the other way round.
            var = self.get_variable(right)
            const = cmp.left
        if var is None or not self.is_const(const):
            return None
        self.priority += 1
        return (var, const, self.priority)

    def visit_Compare(self, cmp):
        """Label `cmp` when it compares a variable with a constant."""
        label = self._label_for_compare(cmp)
        if label:
            self.labels[cmp].append(label)

    def visit_Name(self, name):
        """Label stores as 'assign' and, inside a test, other names as None."""
        # The priority advances for every Name, labelled or not.
        self.priority += 1
        if isinstance(name.ctx, ast.Store):
            kind = 'assign'
        elif self.in_test:
            kind = None
        else:
            return
        self.labels[name].append((name.variable, kind, self.priority))

    def _label_for_unary_operand(self, op):
        """Return a label for a chain of ``not`` operators, or None.

        The chain is followed inwards; any operator other than ``not``
        yields None.  The innermost operand gives a ``(variable, None,
        priority)`` label if it is a Name, a comparison label if it is a
        Compare, and None otherwise.
        """
        current = op
        while isinstance(current.op, ast.Not):
            operand = current.operand
            if isinstance(operand, ast.UnaryOp):
                current = operand
                continue
            if isinstance(operand, ast.Name):
                self.priority += 1
                return (operand.variable, None, self.priority)
            if isinstance(operand, ast.Compare):
                return self._label_for_compare(operand)
            return None
        return None

    def visit_UnaryOp(self, op):
        """Inside a test, label `op` or else fall back to visiting its operand."""
        if not self.in_test:
            return
        label = self._label_for_unary_operand(op)
        if label:
            self.labels[op].append(label)
        else:
            self.visit(op.operand)

    def _label_constant_assignments(self, labelled, first, second):
        """Label `labelled` with the constants assigned in `first` or `second`.

        One label is added per variable that is assigned a constant in
        either subtree; when both assign it, the value from `first` is used.
        All labels added by one call share a single new priority.
        """
        in_first = {}
        ConstantAssignmentVisitor(in_first).visit(first)
        in_second = {}
        ConstantAssignmentVisitor(in_second).visit(second)
        self.priority += 1
        for var in set(in_first.keys()).union(in_second.keys()):
            val = in_first[var] if var in in_first else in_second[var]
            # Look up the bucket only when there is something to add, so no
            # empty entry is created for `labelled`.
            self.labels[labelled].append((var, val, self.priority))

    def visit_If(self, ifstmt):
        """Visit the test (flagged as a test) and both branches, then label
        the test with the constants assigned in either branch."""
        self.in_test += 1
        self.visit(ifstmt.test)
        self.in_test -= 1
        self.visit(ifstmt.body)
        self.visit(ifstmt.orelse)
        self._label_constant_assignments(
            ifstmt.test, ifstmt.body, ifstmt.orelse)

    def visit_Try(self, stmt):
        """Visit the statement's fields; when there is exactly one handler,
        label the statement with the constants assigned in the body or in
        that handler."""
        self.generic_visit(stmt)
        handlers = stmt.handlers
        if not handlers or len(handlers) > 1:
            return
        self._label_constant_assignments(stmt, stmt.body, handlers[0])

    def visit_ClassExpr(self, node):
        """Mark `node` with a 'define' label that has no variable."""
        self.priority += 1
        self.labels[node].append((None, 'define', self.priority))
    # Function expressions are labelled in exactly the same way.
    visit_FunctionExpr = visit_ClassExpr

class TryBodyAndHandlerVisitor(ScopedAstLabellingVisitor):
    """Collects the AST nodes it reaches into ``labels``, used here as a set.

    Except handlers are recorded themselves but their contents are not
    traversed.
    """

    def generic_visit(self, node):
        """Record `node` if it is an AST node, then visit its field values."""
        if not isinstance(node, ast.AstBase):
            return
        self.labels.add(node)
        for (_, _, child) in iter_fields(node):
            self.visit(child)

    def visit_ExceptStmt(self, node):
        """Record the handler without descending into it."""
        self.labels.add(node)

class ConstantAssignmentVisitor(ScopedAstLabellingVisitor):
    """Records constant assignments in ``labels``, used here as a mapping
    from variable to the assigned value expression."""

    def visit_Assign(self, asgn):
        """Map each target that has a ``variable`` to the constant value.

        Assignments whose value is not constant are ignored, and a later
        assignment to the same variable replaces an earlier entry.
        """
        if self.is_const(asgn.value):
            for target in asgn.targets:
                if hasattr(target, 'variable'):
                    self.labels[target.variable] = asgn.value

def label_ast(ast_root):
    """Run an AstLabeller over `ast_root` and return the labels it collected.

    The result is a ``defaultdict(list)`` keyed by AST node.  Traversal
    starts with ``generic_visit``, so the root's fields are visited without
    the root itself being dispatched to a ``visit_*`` handler.
    """
    labels = defaultdict(list)
    AstLabeller(labels).generic_visit(ast_root)
    return labels

def _is_branch(node, graph):
    if ((len(graph.succ[node]) == 2) or isinstance(node.node, ast.Try)):
        return True
    if (len(graph.succ[node]) != 1):
        return False
    succ = graph.succ[node][0]
    if (not isinstance(succ.node, ast.UnaryOp)):
        return False
    return _is_branch(succ, graph)

def label_cfg(graph, ast_labels):
    """Carry AST labels over to the flow graph nodes that can use them.

    A flow node keeps the labels of its AST node when it is a branch point,
    or when the value of its first label is 'assign', 'define' or 'loop'.
    Nodes whose AST node has no labels, or an empty label list, are skipped.
    Returns a dict mapping flow node to label list.
    """
    always_kept = ('assign', 'define', 'loop')
    cfg_labels = {}
    for (node, _) in graph.nodes():
        # .get() avoids creating an entry when ast_labels is a defaultdict.
        labels = ast_labels.get(node.node)
        if not labels:
            continue
        if _is_branch(node, graph) or labels[0][1] in always_kept:
            cfg_labels[node] = labels
    return cfg_labels

def usefully_comparable_types(o1, o2):
    """Return True if either argument is None or both have exactly the same type.

    The type check is an identity test on ``type()``, so an instance of a
    subclass does not match an instance of its base class.
    """
    return o1 is None or o2 is None or type(o1) is type(o2)

def exits_from_subtree(head, subtree, graph):
    """Find the flow nodes at which control leaves `subtree`.

    Starting at `head`, successors are followed only while their AST node
    (``succ.node``) is in `subtree`.  A visited node is an exit when it has
    at least one successor and none of its successors is in `subtree`.
    Nodes without successors are never exits.  Returns a set of flow nodes.
    """
    exits = set()
    visited = set()
    pending = [head]
    while pending:
        node = pending.pop()
        if node in visited:
            continue
        visited.add(node)
        successors = graph.succ[node]
        if not successors:
            continue
        inside = [succ for succ in successors if succ.node in subtree]
        if inside:
            pending.extend(inside)
        else:
            exits.add(node)
    return exits

def get_split_heads(head, graph):
    """Return the flow nodes that start the branches of a split at `head`.

    For anything but a Try statement these are simply the successors of
    `head`.  For a Try, the AST nodes collected by TryBodyAndHandlerVisitor
    (plus the first handler, if any) form a subtree, and the exits from that
    subtree are returned instead.
    """
    stmt = head.node
    if not isinstance(stmt, ast.Try):
        return graph.succ[head]
    covered = set()
    TryBodyAndHandlerVisitor(covered).visit(stmt)
    if stmt.handlers:
        covered.add(stmt.handlers[0])
    return exits_from_subtree(head, covered, graph)

def choose_split_points(graph, cfg_labels):
    """Pick up to MAX_SPLITS places at which the flow graph should be split.

    Every label is paired with every later label for the same variable.  The
    earlier node becomes a candidate when it strictly dominates the later
    one, the two label values are usefully comparable, the earlier node has
    exactly two split heads, and each head is strictly dominated by the
    earlier node and reaches the later node while dominated by it.

    Returns a list of ``(node, split_heads)`` pairs: the first MAX_SPLITS
    candidates in priority order, reversed.
    """
    # Flatten into (node, variable, value, priority) records, by priority.
    records = [
        (node, label[0], label[1], label[2])
        for (node, label_list) in cfg_labels.items()
        for label in label_list
    ]
    records.sort(key=itemgetter(3))

    candidates = []
    for (first_node, first_var, first_type, first_priority) in records:
        if first_type in ('assign', 'define'):
            continue
        later_kinds = [
            kind for (_, _, kind, prio) in records if prio > first_priority
        ]
        # A later 'define' label ends the search altogether.
        if 'define' in later_kinds:
            break
        for (second_node, second_var, second_type, second_priority) in records:
            if second_var != first_var or first_priority >= second_priority:
                continue
            # Labels after a reassignment of the variable are not considered.
            if second_type == 'assign':
                break
            if not (graph.strictly_dominates(first_node, second_node)
                    and usefully_comparable_types(first_type, second_type)):
                continue
            split_heads = get_split_heads(first_node, graph)
            if len(split_heads) != 2:
                continue
            if all(
                graph.strictly_dominates(first_node, head)
                and graph.reaches_while_dominated(head, second_node, first_node)
                for head in split_heads
            ):
                candidates.append(
                    (first_node, split_heads, first_var, first_priority))

    # Keep one candidate per node; if that is still too many and more than
    # one variable is involved, keep one candidate per variable.
    candidates = deduplicate(candidates, 0, 3)
    if len(candidates) > MAX_SPLITS and len({c[2] for c in candidates}) > 1:
        candidates = deduplicate(candidates, 2, 3)
    return [c[:2] for c in candidates[(MAX_SPLITS - 1)::(-1)]]

def deduplicate(lst, col, sort_col):
    """Keep one tuple per distinct value of column `col`, sorted by `sort_col`.

    When several tuples share a value in `col`, the one that comes first in
    `lst` is kept.  Returns a new list; `lst` is not modified.
    """
    # Walking backwards lets earlier tuples overwrite later ones.
    unique = {row[col]: row for row in reversed(lst)}
    result = list(unique.values())
    result.sort(key=itemgetter(sort_col))
    return result
