
import sys
import os.path
__all__ = ['transform_ast']
from semmle.python.passes.ast_pass import iter_fields, ASTVisitor
from spitfire.compiler import util
from spitfire.compiler.ast import PlaceholderSubstitutionNode
from semmle.python import ast
import bisect
# Shared expression-context instances used for every generated node; several
# visitors compare against them by identity (``ctx is LOAD``).
(LOAD, PARAM, STORE) = (ast.Load(), ast.Param(), ast.Store())

def zero_location(node):
    """Stamp *node* with the sentinel "no source position" location.

    Sets ``lineno`` to 0, ``col_offset`` to -1 and ``_end`` to ``(0, 0)``.
    """
    node.lineno, node.col_offset, node._end = 0, -1, (0, 0)

class AstTranslator(object):
    """Translate a Spitfire template AST into a semmle Python AST.

    The template becomes an ``ast.Module`` holding imports and one class
    (assigned to the last dotted component of the template name).  The
    top-level template content becomes the body of a method on that class,
    named ``main`` unless ``visit_ImplementsNode`` recorded another name; no
    method is emitted when that name is ``'library'``.

    ``originals`` maps generated AST nodes back to the Spitfire nodes they
    came from, so that ``Locator`` can later assign source positions.
    """

    def __init__(self, strict_resolution):
        # Cache: Spitfire node type -> bound visit method (see visit_method).
        self._visit_methods = {}
        # Generated AST node -> originating Spitfire node.
        self.originals = {}
        self.filter_function = None
        # When False, UDN lookups become TemplateDottedNotation nodes instead
        # of plain attribute accesses (see visit_GetUDNNode).
        self.strict_resolution = strict_resolution
        self.module_members = []
        self.class_members = []
        # Method name recorded by visit_ImplementsNode; None means 'main'.
        self.template_method = None
        self.bases = []

    def visit_method(self, node):
        """Return (and cache) the ``visit_<TypeName>`` handler for *node*.

        Falls back to ``visit_generic`` when no specific handler exists.
        """
        if hasattr(self, ('visit_' + type(node).__name__)):
            visit_method = getattr(self, ('visit_' + type(node).__name__))
        else:
            visit_method = self.visit_generic
        self._visit_methods[type(node)] = visit_method
        return visit_method

    def visit_expr_list(self, node, ctx):
        """Translate each expression in *node*, dropping ``None`` results."""
        result = []
        for item in node:
            new_node = self.visit_expr(item, ctx)
            if (new_node is not None):
                assert isinstance(new_node, (ast.expr, ast.keyword, list)), new_node
                result.append(new_node)
        return result

    def visit_stmt_list(self, node):
        """Translate each statement in *node*, dropping ``None`` results.

        Elements may themselves be lists; they are flattened later by
        ``flatten_lists``.
        """
        result = []
        for item in node:
            new_node = self.visit_stmt(item)
            if (new_node is not None):
                assert isinstance(new_node, (ast.stmt, list)), new_node
                result.append(new_node)
        return result

    def visit_stmt(self, node):
        """Translate a statement node (or list of them).

        Returns an AST node, a list, or ``None`` for nodes that only have
        side effects (e.g. class-level definitions).  The first Spitfire node
        that produced a given AST node is recorded in ``originals``.
        """
        if isinstance(node, list):
            return self.visit_stmt_list(node)
        try:
            visit_method = self._visit_methods[type(node)]
        except KeyError:
            visit_method = self.visit_method(node)
        result = visit_method(node)
        if (isinstance(result, ast.AstBase) and (result not in self.originals)):
            self.originals[result] = node
        return result

    def visit_expr(self, node, ctx):
        """Translate an expression node (or list of them) in context *ctx*.

        Unlike ``visit_stmt``, a ``None`` result is treated as an error.
        """
        if isinstance(node, list):
            return self.visit_expr_list(node, ctx)
        try:
            visit_method = self._visit_methods[type(node)]
        except KeyError:
            visit_method = self.visit_method(node)
        result = visit_method(node, ctx)
        if (isinstance(result, ast.AstBase) and (result not in self.originals)):
            self.originals[result] = node
        assert (result is not None), (type(node), node)
        return result

    def transform(self, name, node):
        """Translate the root template *node*; *name* is the template name.

        The root is dispatched through ``visit_expr`` so that *name* reaches
        ``visit_TemplateNode`` in the ``ctx`` slot.  Nested statement lists
        are flattened in place before returning.
        """
        result = self.visit_expr(node, name)
        flatten_lists(result)
        return result

    def visit_generic(self, node, ctx=None):
        """Fallback handler: reject node types with no translation."""
        raise NotImplementedError(('Unimplemented visitor for ' + type(node).__name__))

    def visit_TemplateNode(self, node, name):
        """Build the ``ast.Module`` for the whole template.

        Emits imports of ``spitfire.runtime`` and ``spitfire.runtime.template``,
        a class definition and, unless the method name is ``'library'``, a
        method containing the translated child nodes.  When no base class was
        recorded, ``spitfire.runtime.template.SpitfireTemplate`` is used.
        """
        assert (not node.import_nodes)
        assert (not node.from_nodes)
        assert (not node.extends_nodes)
        assert (not node.attr_nodes)
        node.pos = (- 1)
        self.module_scope = ast.Module(self.module_members)
        zero_location(self.module_scope)
        self._make_import(node, ['spitfire', 'runtime'])
        self._make_import(node, ['spitfire', 'runtime', 'template'])
        # self.bases / self.class_members are passed by reference and filled in
        # below (presumably the class node keeps these same list objects).
        clsdef = self.create_class_def(name.split('.')[(- 1)], self.bases, None, self.class_members, self.module_scope)
        self.class_scope = clsdef.value.inner_scope
        args = []
        stmts = []
        funcexpr = self.create_func_expr('', args, ast.arguments([], [], [], None, None, []), None, None, None, None, stmts)
        self.function_scope = funcexpr.inner_scope
        args.append(self.make_name('self', PARAM))
        stmts.extend(self.visit_stmt(node.child_nodes))
        if (not self.bases):
            spitfire = self.make_name('spitfire', LOAD, self.module_scope)
            zero_location(spitfire)
            self.runtime = ast.Attribute(spitfire, 'runtime', LOAD)
            zero_location(self.runtime)
            template = ast.Attribute(self.runtime, 'template', LOAD)
            zero_location(template)
            base = ast.Attribute(template, 'SpitfireTemplate', LOAD)
            zero_location(base)
            self.bases.append(base)
        self.module_members.append(clsdef)
        zero_location(clsdef)
        zero_location(clsdef.value)
        if self.template_method:
            func_name = self.template_method
        else:
            func_name = 'main'
        if (func_name != 'library'):
            funcexpr.inner_scope.name = func_name
            funcexpr.name = func_name
            top_level_defn = ast.Assign(funcexpr, [self.make_name(func_name, STORE, self.class_scope)])
            self.function_scope.name = func_name
            zero_location(top_level_defn)
            zero_location(top_level_defn.value)
            self.class_members.append(top_level_defn)
        return self.module_scope

    def visit_ExtendsNode(self, node):
        """Record the dotted module path of *node* as a base-class expression.

        Produces no statement (returns ``None``).
        """
        node_list = node.module_name_list
        ast_node = self.make_name(node_list[0], LOAD, self.module_scope)
        for n in node_list[1:]:
            ast_node = ast.Attribute(ast_node, n.name, LOAD)
            self.originals[ast_node] = n
        self.bases.append(ast_node)
    visit_AbsoluteExtendsNode = visit_ExtendsNode

    def visit_LooseResolutionNode(self, node):
        """Switch dotted lookups to non-strict resolution for the rest of the
        translation.

        NOTE(review): the child statements are translated but the result is
        discarded, unlike visit_StripLinesNode; confirm this node never has
        children that should be emitted.
        """
        self.strict_resolution = False
        self.visit_stmt(node.child_nodes)

    def visit_ImplementsNode(self, node):
        """Record the name of the method the template body should become."""
        self.template_method = node.name

    def _make_import(self, node, name_list):
        """Append ``import a.b.c`` (bound to ``a`` in module scope)."""
        imp = self.make_import_expr(node, name_list)
        self.originals[imp] = node
        alias = ast.alias(imp, self.make_name(name_list[0], STORE, self.module_scope))
        result = ast.Import([alias])
        self.originals[alias] = node
        self.originals[result] = node
        self.module_members.append(result)

    def visit_ImportNode(self, node):
        """Translate an import node into a module-level ``import``."""
        self._make_import(node, [ident.name for ident in node.module_name_list])

    def visit_FromNode(self, node):
        """Translate a from-import node into a module-level ``from ... import``."""
        imp = self.make_import_expr(node, [ident.name for ident in node.module_name_list])
        self.originals[imp] = node
        expr = ast.ImportMember(imp, node.identifier.name)
        asname = self.make_name(node.identifier, STORE, self.module_scope)
        alias = ast.alias(expr, asname)
        result = ast.Import([alias])
        self.originals[alias] = node
        self.originals[result] = node
        self.module_members.append(result)

    def visit_GlobalNode(self, node):
        """Translate a global declaration into ``ast.Global``."""
        glob = ast.Global([node.name])
        self.originals[glob] = node
        return glob

    def make_import_expr(self, node, name_list):
        """Return an absolute-import expression for the dotted *name_list*."""
        module_name = '.'.join(name_list)
        imp = ast.ImportExpr(0, module_name, True)
        return imp

    def visit_ForNode(self, node):
        """Translate a for loop (no ``else`` clause)."""
        target = self.visit_expr(node.target_list, STORE)
        iter = self.visit_expr(node.expression_list, LOAD)
        body = self.visit_stmt(node.child_nodes)
        forstmt = ast.For(target, iter, body, None)
        return forstmt

    def visit_FilterAttributeNode(self, node):
        """Add a ``staticmethod``-wrapped class attribute and remember the
        filter's default expression node.

        Returns ``None``; the assignment goes into ``class_members`` only.
        """
        asgn = self._make_attribute(node, True)
        self.filter_function = node.default
        return asgn

    def visit_AttributeNode(self, node):
        """Add a class attribute assignment; returns ``None``."""
        return self._make_attribute(node, False)

    def _make_attribute(self, node, is_filter):
        """Append ``name = default`` to the class body.

        Deliberately returns ``None`` so nothing is emitted into the template
        method at the point of the directive.
        """
        lhs = self.make_name(node, STORE, self.class_scope)
        rhs = self.visit_expr(node.default, LOAD)
        if is_filter:
            rhs = self.wrap_in_call(rhs, node, 'staticmethod', self.module_scope)
        asgn = ast.Assign(rhs, [lhs])
        self.originals[asgn] = node
        self.class_members.append(asgn)

    def visit_CallFunctionNode(self, node, ctx):
        """Translate a call; the argument list yields (positional, keyword)."""
        assert (ctx == LOAD), ctx
        func = self.visit_expr(node.expression, ctx)
        args = self.visit_expr(node.arg_list, ctx)
        call = make_call(func, *args)
        return call

    def visit_ArgListNode(self, node, ctx):
        """Return a ``(positional_args, keyword_args)`` pair of lists."""
        args = self.visit_expr(node.parg_list, LOAD)
        kwargs = self.visit_expr(node.karg_list, LOAD)
        return (args, kwargs)

    def visit_GetUDNNode(self, node, ctx):
        """Translate a dotted lookup, strict or template-style."""
        assert (ctx == LOAD), ctx
        expr = self.visit_expr(node.expression, ctx)
        if self.strict_resolution:
            udn = ast.Attribute(expr, node.name, LOAD)
        else:
            udn = ast.TemplateDottedNotation(expr, node.name, LOAD)
        return udn

    def visit_GetAttrNode(self, node, ctx):
        """Translate an attribute access (always in load context)."""
        expr = self.visit_expr(node.expression, ctx)
        attr = ast.Attribute(expr, node.name, LOAD)
        return attr

    def visit_PlaceholderNode(self, node, ctx):
        """Translate a ``$name`` placeholder into an ``ast.PlaceHolder``."""
        assert (ctx is LOAD)
        return self.make_placeholder(node, ctx)

    def visit_PlaceholderSubstitutionNode(self, node):
        """Translate a placeholder substitution into a filtered write.

        A ``raw`` parameter suppresses filtering entirely.

        NOTE(review): when a ``filter`` parameter is present and no ``raw``
        parameter follows it, the loop finishes without ``break``, so the
        ``else`` clause wraps the already-filtered expression in a second
        ``Filter(..., None)``.  Confirm this double wrap is intended.
        """
        expr = self.visit_expr(node.expression, LOAD)
        for n in node.parameter_list:
            if (n.name == 'raw'):
                break
            if (n.name == 'filter'):
                the_filter = self.make_name(n.default.name, LOAD)
                expr = ast.Filter(expr, the_filter)
        else:
            expr = ast.Filter(expr, None)
        stmt = ast.TemplateWrite(expr)
        self.originals[stmt] = node
        return stmt

    def visit_TextNode(self, node):
        """Translate literal template text into a write of a unicode string."""
        text = ast.Str(node.value, 'u', None)
        self.originals[text] = node
        return ast.TemplateWrite(text)

    def visit_ReturnNode(self, node):
        """Translate a return node into ``ast.Return``."""
        # visit_expr returns a single node; the previous ``(expr, end) = ...``
        # unpacking could not succeed for an AST node result.
        expr = self.visit_expr(node.expression, LOAD)
        return ast.Return(expr)

    def visit_SliceNode(self, node, ctx):
        """Translate ``expr[index]`` into ``ast.Subscript``."""
        val = self.visit_expr(node.expression, ctx)
        index = self.visit_expr(node.slice_expression, ctx)
        subscr = ast.Subscript(val, index, ctx)
        return subscr

    def visit_IfNode(self, node):
        """Translate an if node; ``orelse`` is ``None`` when there is no else."""
        test = self.visit_expr(node.test_expression, LOAD)
        body = self.visit_stmt(node.child_nodes)
        if node.else_.child_nodes:
            orelse = self.visit_stmt(node.else_)
        else:
            orelse = None
        if (not node.pos):
            node.pos = node.test_expression.pos
        return ast.If(test, body, orelse)

    def visit_ElseNode(self, node):
        """Translate the statements of an else branch."""
        return self.visit_stmt(node.child_nodes)

    def visit_DictLiteralNode(self, node, ctx):
        """Translate a dict literal from its (key, value) child pairs."""
        items = []
        for (k, v) in node.child_nodes:
            key = self.visit_expr(k, LOAD)
            value = self.visit_expr(v, LOAD)
            items.append(ast.KeyValuePair(key, value))
        return ast.Dict(items)

    def visit_ListLiteralNode(self, node, ctx):
        """Translate a list literal in context *ctx*."""
        assert ctx
        elts = self.visit_expr(node.child_nodes, ctx)
        result = ast.List(elts, ctx)
        return result

    def visit_TupleLiteralNode(self, node, ctx):
        """Translate a tuple literal in context *ctx*."""
        assert ctx
        elts = self.visit_expr(node.child_nodes, ctx)
        result = ast.Tuple(elts, ctx)
        self.originals[result] = node
        return result

    def visit_MacroNode(self, node):
        """Re-parse the macro text with the ``i18n_goal`` rule and emit writes.

        Placeholder substitutions are translated normally; every other child
        is written as literal unicode text.
        """
        macro_ast = util.parse(node.value, 'i18n_goal')
        stmts = []
        for n in macro_ast.child_nodes:
            if isinstance(n, PlaceholderSubstitutionNode):
                stmts.append(self.visit_stmt(n))
            else:
                text = ast.Str(n.value, 'u', None)
                self.originals[text] = n
                stmts.append(ast.TemplateWrite(text))
        return stmts

    def visit_BinOpExpressionNode(self, node, ctx):
        """Translate a binary operator into BinOp, BoolOp or Compare.

        Raises SyntaxError for operators missing from the operator tables.
        """
        assert (ctx == LOAD)
        left = self.visit_expr(node.left, ctx)
        right = self.visit_expr(node.right, ctx)
        op = node.operator
        if (op in ARITHMETIC_OPS):
            expr = ast.BinOp(left, ARITHMETIC_OPS[op](), right)
        elif (op in BOOL_OPS):
            expr = ast.BoolOp(BOOL_OPS[op](), [left, right])
        elif (op in COMPARE_OPS):
            expr = ast.Compare(left, [COMPARE_OPS[op]()], [right])
        else:
            raise SyntaxError(('Unrecognised operator ' + op))
        self.originals[expr] = node
        return expr
    visit_BinOpNode = visit_BinOpExpressionNode

    def visit_AssignNode(self, node):
        """Translate an assignment; the target inherits the node's position."""
        if (node.left.pos is None):
            node.left.pos = node.pos
        left = self.visit_expr(node.left, STORE)
        right = self.visit_expr(node.right, LOAD)
        asgn = ast.Assign(right, [left])
        return asgn

    def visit_UnaryOpNode(self, node, ctx):
        """Translate a unary operator; raises SyntaxError if unknown."""
        assert (ctx == LOAD)
        operand = self.visit_expr(node.expression, ctx)
        op = node.operator
        if (op in UNARY_OPS):
            unary = ast.UnaryOp(UNARY_OPS[op](), operand)
        else:
            raise SyntaxError(('Unrecognised operator ' + op))
        return unary

    def visit_FunctionNode(self, node):
        """Function nodes are not supported."""
        raise NotImplementedError()

    def visit_BufferWrite(self, node):
        """Translate a buffer write into ``ast.TemplateWrite``."""
        # As in visit_ReturnNode, visit_expr yields one node, not a pair.
        expr = self.visit_expr(node.expression, LOAD)
        buf = ast.TemplateWrite(expr)
        return buf

    def visit_IdentifierNode(self, node, ctx):
        """Translate a bare identifier into a Name."""
        assert node.name
        assert ctx
        return self.make_name(node, ctx)

    def visit_LiteralNode(self, node, ctx):
        """Translate a bool, byte-string, unicode or numeric literal.

        Booleans become names (``True``/``False``) in module scope.  Raises
        TypeError for any other literal type.
        """
        assert (ctx == LOAD), ctx
        val = node.value
        if isinstance(val, bool):
            ast_node = self.make_name(repr(val), LOAD, self.module_scope)
        elif isinstance(val, bytes):
            ast_node = ast.Str(val, 'b', None)
        elif isinstance(val, unicode):
            ast_node = ast.Str(val, 'u', None)
        elif isinstance(val, (float, int)):
            ast_node = ast.Num(val, str(val))
        else:
            raise TypeError('Unknown literal type')
        return ast_node

    def visit_EchoNode(self, node):
        """Translate an echo node into a write, optionally guarded by ``if``.

        With no test expression, the write of the true expression is returned
        directly.  Otherwise an ``If`` writes the true expression, and writes
        the false expression in its ``else`` branch when one is present.
        """
        expr = self.visit_expr(node.true_expression, LOAD)
        body = ast.TemplateWrite(expr)
        if (node.test_expression is None):
            return body
        test = self.visit_expr(node.test_expression, LOAD)
        if (node.false_expression is None):
            # No else branch: use None (as visit_IfNode does) rather than a
            # list containing None.
            orelse = None
        else:
            false_expr = self.visit_expr(node.false_expression, LOAD)
            orelse = [ast.TemplateWrite(false_expr)]
        return ast.If(test, [body], orelse)

    def visit_CommentNode(self, node):
        """Comments produce no output."""
        pass

    def visit_BreakNode(self, node):
        return ast.Break()

    def visit_ContinueNode(self, node):
        return ast.Continue()

    def visit_StripLinesNode(self, node):
        """Translate the children of a strip-lines node in place."""
        return self.visit_stmt(node.child_nodes)

    def visit_DefNode(self, node):
        """Translate a def node into a method assigned in the class body.

        The method (with an explicit ``self`` parameter) is appended to
        ``class_members``; returns ``None`` so nothing is emitted at the
        definition site.
        """
        params = []
        body = []
        func = ast.Function(node.name, params, None, None, None, body, False)
        # Names visited in the body belong to this function, not the caller.
        saved_scope = self.function_scope
        self.function_scope = func
        args = self.visit_stmt(node.parameter_list)
        self_arg = self.make_name('self', PARAM)
        body.extend(self.visit_stmt(node.child_nodes))
        params.append(self_arg)
        for p in node.parameter_list:
            param = self.make_name(p, PARAM)
            params.append(param)
        self.function_scope = saved_scope
        self.originals[func] = node
        funcexpr = ast.FunctionExpr(node.name, args, None, func)
        self.originals[funcexpr] = node
        name = self.make_name(node, STORE, self.class_scope)
        self.originals[name] = node
        funcdef = ast.Assign(funcexpr, [name])
        self.originals[funcdef] = node
        self.class_members.append(funcdef)
        return None

    def visit_ParameterListNode(self, node):
        """Build ``ast.arguments`` holding one default (or None) per parameter."""
        args = self.visit_expr_list(node, LOAD)
        defaults = [(kw.value if isinstance(kw, ast.keyword) else None) for kw in args]
        annotations = [None for _ in defaults]
        return ast.arguments(defaults, [], annotations, None, None, [])

    def visit_ParameterNode(self, node, ctx):
        """Translate a parameter: a keyword when it has a default, else a Name."""
        assert (ctx is LOAD)
        if node.default:
            default = self.visit_expr(node.default, LOAD)
            result = ast.keyword(node.name, default)
        else:
            result = self.make_name(node, ctx)
        return result

    def visit_expr_or_tuple(self, node, ctx):
        """Translate a single child directly, or several children as a Tuple."""
        items = node.child_nodes
        if (len(items) == 1):
            return self.visit_expr(items[0], ctx)
        else:
            t = ast.Tuple(self.visit_expr(items, ctx), ctx)
            return t

    def visit_TargetNode(self, node, ctx):
        """Translate an assignment/loop target into a stored Name."""
        assert (ctx is STORE)
        return self.make_name(node, ctx)

    def visit_BlockNode(self, node):
        """Define the block as a method and write a filtered call to it here."""
        self.visit_DefNode(node)
        ph = self.make_placeholder(node, LOAD)
        call = make_call(ph)
        filter = ast.Filter(call, None)
        self.originals[filter] = node
        return ast.TemplateWrite(filter)

    def _make_name(self, cls, node, ctx, scope):
        """Create a *cls* node (Name or PlaceHolder) for *node*.

        *node* is either a plain string or a Spitfire node with a ``name``;
        only the latter is recorded in ``originals``.  *scope* may be None,
        in which case Variablizer resolves it later.
        """
        if isinstance(node, (str, unicode)):
            variable = ast.Variable(node, scope)
            result = cls(variable, ctx)
        else:
            variable = ast.Variable(node.name, scope)
            result = cls(variable, ctx)
            self.originals[result] = node
        return result

    def make_name(self, node, ctx, scope=None):
        return self._make_name(ast.Name, node, ctx, scope)

    def make_placeholder(self, node, ctx):
        return self._make_name(ast.PlaceHolder, node, ctx, None)
    visit_TargetListNode = visit_expr_or_tuple
    visit_ExpressionListNode = visit_expr_or_tuple
    visit_WhitespaceNode = visit_TextNode
    visit_OptionalWhitespaceNode = visit_TextNode
    visit_NewlineNode = visit_TextNode

    def wrap_in_call(self, expr, node, funcname, scope):
        """Return ``funcname(expr)``, attributing the new nodes to *node*."""
        func = self.make_name(funcname, LOAD, scope)
        call = make_call(func, [expr])
        self.originals[func] = self.originals[call] = node
        return call

    def create_class_expr(self, name, bases, keywords, body):
        cls = ast.Class(name, body)
        clsexpr = ast.ClassExpr(name, bases, keywords, cls)
        return clsexpr

    def create_class_def(self, name, bases, keywords, body, scope):
        """Return ``name = <class expression>`` with the target in *scope*."""
        clsexpr = self.create_class_expr(name, bases, keywords, body)
        asgn = ast.Assign(clsexpr, [self.make_name(name, STORE, scope)])
        return asgn

    def create_func_expr(self, name, params, arguments, vararg, kwonlyargs, kwarg, returns_annotation, body):
        func = ast.Function(name, params, vararg, kwonlyargs, kwarg, body, False)
        funcexpr = ast.FunctionExpr(name, arguments, returns_annotation, func)
        return funcexpr

class Locator(object):
    """Assign ``lineno``/``col_offset``/``_end`` to nodes of a translated tree.

    Positions come from the ``pos`` (character offset) or ``start``/``end``
    attributes of the Spitfire node recorded in *originals*.  Failing that,
    they come from searching the template source *lines* for the node's text.
    Nodes that cannot be placed get line 0, column -1, end ``(0, 0)``.
    """

    def __init__(self, lines, originals):
        self._visit_methods = {}
        self.originals = originals
        self.start = (1, 1)
        self.lines = lines
        # 1-based line that free-standing text is matched against; 0 means
        # tracking was lost.
        self.last_line = 1
        # Character offset at which each line begins.  The +1 assumes the
        # lines were produced by splitting on single '\n' characters.
        line_start = 0
        self.line_starts = []
        for line in self.lines:
            self.line_starts.append(line_start)
            line_start += (len(line) + 1)

    def line_col_from_offset(self, pos):
        """Convert a 0-based offset to (1-based line, 0-based column).

        Negative offsets map to the sentinel ``(0, -1)``.
        """
        if (pos < 0):
            return (0, (- 1))
        line = bisect.bisect_right(self.line_starts, pos)
        col = (pos - self.line_starts[(line - 1)])
        return (line, col)

    def _get_start(self, node, start):
        """Return *node*'s start from its original's ``pos``, else *start*."""
        if (node not in self.originals):
            return start
        orig = self.originals[node]
        if (not hasattr(orig, 'pos')):
            return start
        if orig.pos:
            return self.line_col_from_offset(orig.pos)
        return start

    def traverse(self, node, start=(1, 1)):
        """Locate *node*'s children, then *node*; return *node*'s end.

        The provisional end is the largest end among the children, or one
        column past the start.
        """
        start = self._get_start(node, start)
        end = (start[0], (start[1] + 1))
        for (_, _, value) in iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AstBase):
                        item_end = self.traverse(item, start)
                        if (item_end > end):
                            end = item_end
            elif isinstance(value, ast.AstBase):
                item_end = self.traverse(value, start)
                if (item_end > end):
                    end = item_end
        end = self.visit(node, start, end)
        return end

    def visit(self, node, start, end):
        """Dispatch to ``visit_<Type>``; already-located nodes keep their end."""
        try:
            visit_method = self._visit_methods[type(node)]
        except KeyError:
            if hasattr(self, ('visit_' + type(node).__name__)):
                visit_method = getattr(self, ('visit_' + type(node).__name__))
            else:
                visit_method = self.visit_generic
            self._visit_methods[type(node)] = visit_method
        if hasattr(node, 'lineno'):
            return node._end
        return visit_method(node, start, end)

    def visit_TemplateWrite(self, node, start, end):
        """A write spans exactly its value."""
        (node.lineno, node.col_offset) = (node.value.lineno, node.value.col_offset)
        node._end = node.value._end
        return node._end

    def visit_Filter(self, node, start, end):
        """A filter spans exactly its value."""
        (node.lineno, node.col_offset) = (node.value.lineno, node.value.col_offset)
        node._end = node.value._end
        return node._end

    def visit_generic(self, node, start, end):
        """Use the provisional *start*/*end* computed by ``traverse``."""
        if hasattr(node, 'lineno'):
            return node._end
        else:
            (node.lineno, node.col_offset) = start
            node._end = end
            return end

    def visit_Num(self, node, start, end):
        """A number spans the length of its source text from *start*."""
        (node.lineno, node.col_offset) = start
        node._end = (start[0], ((start[1] + len(node.text)) - 1))
        return node._end

    def visit_Str(self, node, start, end):
        """Locate literal text, advancing ``last_line`` across newlines.

        NOTE(review): when ``last_line`` is 0 the newline branch indexes
        ``self.lines[-1]``; confirm whether that case can arise.
        """
        if (self.originals[node].pos is not None):
            start = self.line_col_from_offset(self.originals[node].pos)
            self.last_line = start[0]
        elif (node.s == '\n'):
            line = self.last_line
            if (line < len(self.lines)):
                self.last_line += 1
            col = len(self.lines[(line - 1)])
            (node.lineno, node.col_offset) = (line, col)
            node._end = (line, (col + 1))
            return node._end
        else:
            start = (self.last_line, 1)
        return self.locate_text(node, start, node.s)

    def visit_Name(self, node, start, end):
        return self.locate_text(node, start, node.id)

    def visit_PlaceHolder(self, node, start, end):
        return self.locate_text(node, start, node.id)

    def locate_text(self, node, start, text):
        """Find *text* in the source and record it as *node*'s location.

        If the originating Spitfire node has a non-None ``start`` offset, that
        offset and its ``end`` are used directly.  Otherwise *text* (with
        quotes backslash-escaped) is searched for on line ``start[0]`` at
        column ``start[1] - 1``, then anywhere on that line.  If that fails,
        the search is repeated on the line held in ``self.last_line`` before
        the first attempt.  Returns ``node._end``.
        """
        if (node in self.originals):
            orig = self.originals[node]
            if (getattr(orig, 'start', None) is not None):
                (node.lineno, node.col_offset) = self.line_col_from_offset(orig.start)
                node._end = self.line_col_from_offset(orig.end)
                return node._end
        match = text.replace("'", "\\'").replace('"', '\\"')
        length = len(match)

        def give_up():
            # Sentinel location; also marks text tracking as lost.
            node.lineno = 0
            node.col_offset = (- 1)
            node._end = (0, 0)
            self.last_line = 0
            return False

        def locate(lineno, col):
            # Line 0 is the "unknown" sentinel.  Without this guard it would
            # index self.lines[-1] and search the last line of the file, and
            # a line past the end would raise IndexError.
            if (not (1 <= lineno <= len(self.lines))):
                return give_up()
            line = self.lines[(lineno - 1)]
            node.lineno = lineno
            if (line[col:(col + length)] == match):
                node.col_offset = col
                node._end = (lineno, (col + length))
                return True
            if (line[(col - length):col] == match):
                node.col_offset = (col - length)
                node._end = (lineno, col)
                return True
            found = line.find(match)
            if (found < 0):
                return give_up()
            node.col_offset = found
            node._end = (lineno, (found + length))
            return True

        # A failed attempt resets self.last_line to 0, so capture the fallback
        # line first; otherwise the retry would always target line 0.
        fallback_line = self.last_line
        if (not locate(start[0], (start[1] - 1))):
            locate(fallback_line, 0)
        return node._end

    def visit_Attribute(self, node, start, end):
        """An attribute starts with its value and extends past ``.attr``."""
        (node.lineno, node.col_offset) = (node.value.lineno, node.value.col_offset)
        val_end = node.value._end
        if (val_end == (0, 0)):
            node._end = val_end
        else:
            node._end = (val_end[0], ((val_end[1] + len(node.attr)) + 1))
        return node._end
    visit_TemplateDottedNotation = visit_Attribute

def make_call(func, pos=None, named=None):
    """Build a call node for *func*.

    *pos* and *named* are the positional and keyword argument lists; a
    missing or empty value becomes a fresh empty list.
    """
    positional = pos or []
    keywords = named or []
    return ast.Call(func, positional, keywords)

class ASTVisitor(object):
    """Minimal visitor dispatching on a node's class name.

    ``visit(node)`` calls ``visit_<ClassName>`` when defined, else
    ``generic_visit``, which recurses into every AST-valued field.  Note that
    this class shadows the ``ASTVisitor`` imported at the top of the file.
    """

    def __init__(self):
        # Not read elsewhere in this file; kept for compatibility.
        self.method = {}

    def visit(self, node):
        handler_name = 'visit_%s' % (node.__class__.__name__,)
        handler = getattr(self, handler_name, self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        for (_field, _, value) in iter_fields(node):
            # Treat a single value like a one-element list of children.
            children = value if isinstance(value, list) else (value,)
            for child in children:
                if isinstance(child, ast.AstBase):
                    self.visit(child)

class Symbolizer(ASTVisitor):
    """Collect, for every scope node, the set of names bound in it.

    A name counts as bound when its context is Store or Param.  It is
    attributed to its explicit variable scope if set, else to the innermost
    enclosing Module/Class/Function.  Results land in ``defined_variables``.
    """

    def __init__(self):
        ASTVisitor.__init__(self)
        self.defined_variables = {}
        self.scope_stack = []

    def visit_Name(self, node):
        if type(node.ctx) not in (ast.Store, ast.Param):
            return None
        owner = node.variable.scope
        if owner is None:
            owner = self.scope_stack[-1]
        self.defined_variables[owner].add(node.id)
        return None

    def visit_scope(self, node):
        """Register *node* as a scope with no names yet, then visit inside it."""
        self.scope_stack.append(node)
        self.defined_variables[node] = set()
        self.generic_visit(node)
        self.scope_stack.pop()
    visit_Function = visit_scope
    visit_Class = visit_scope
    visit_Module = visit_scope
    visit_PlaceHolder = visit_Name

class Variablizer(ASTVisitor):
    """Fill in the scope of every name whose variable scope is still None.

    A name bound (per ``defined_variables`` from Symbolizer) in the innermost
    enclosing scope resolves there; otherwise it resolves to the outermost
    scope on the stack.

    NOTE(review): Module nodes are not pushed here (only Function, Class and
    Template), so for the translator's tree the outermost scope is the
    template class rather than the module.  Confirm that is intended.
    """

    def __init__(self, defined_variables):
        ASTVisitor.__init__(self)
        self.scope_stack = []
        self.defined_variables = defined_variables

    def visit_Name(self, node):
        variable = node.variable
        if variable.scope is None:
            innermost = self.scope_stack[-1]
            if variable.id in self.defined_variables[innermost]:
                resolved = innermost
            else:
                resolved = self.scope_stack[0]
            assert resolved is not None
            variable.scope = resolved

    def visit_scope(self, node):
        """Visit *node*'s children with *node* as the innermost scope."""
        self.scope_stack.append(node)
        self.generic_visit(node)
        self.scope_stack.pop()
    visit_Function = visit_scope
    visit_Class = visit_scope
    visit_Template = visit_scope
    visit_PlaceHolder = visit_Name

class NodeTransformer(ASTVisitor):
    """Visitor whose ``visit_*`` results replace the visited nodes.

    For list fields a child's result may be None (drop it), an AST node
    (keep it), or a non-AST iterable (splice its items in).  For single AST
    fields, None deletes the attribute and anything else replaces it.
    """

    def generic_visit(self, node):
        for (field, _, _ignored) in iter_fields(node):
            current = getattr(node, field, None)
            if isinstance(current, list):
                current[:] = self._transform_list(current)
            elif isinstance(current, ast.AstBase):
                replacement = self.visit(current)
                if replacement is None:
                    delattr(node, field)
                else:
                    setattr(node, field, replacement)
        return node

    def _transform_list(self, items):
        """Return the visited replacement for a list field's items."""
        transformed = []
        for item in items:
            if not isinstance(item, ast.AstBase):
                transformed.append(item)
                continue
            replacement = self.visit(item)
            if replacement is None:
                continue
            if isinstance(replacement, ast.AstBase):
                transformed.append(replacement)
            else:
                transformed.extend(replacement)
        return transformed

class TextMerger(NodeTransformer):
    """Merge runs of contiguous literal-string ``TemplateWrite`` nodes.

    Writes of empty strings are dropped.  A literal write joins the current
    run when it starts exactly where the previous write ended, or at column
    0 of the next line when the previous write was a lone newline.  A
    non-literal write, or descending into any other node, ends the run.
    Each run is collapsed into its first write by ``merge_nodes``.
    """

    def __init__(self):
        NodeTransformer.__init__(self)
        self.working_node = None

    def merge(self, tree):
        """Return *tree* with contiguous literal writes merged."""
        # first write of a run -> every write in that run (first included)
        self.node_lists = {}
        merged_tree = self.visit(tree)
        for (head, run) in self.node_lists.items():
            merge_nodes(head, run)
        del self.node_lists
        return merged_tree

    def generic_visit(self, node):
        if node is None:
            return None
        # A run never continues into or out of a nested node.
        self.working_node = None
        visited = NodeTransformer.generic_visit(self, node)
        self.working_node = None
        return visited

    @staticmethod
    def _continues_run(previous, node):
        """Whether literal write *node* directly follows *previous*."""
        previous_end = previous._end
        if previous.value.s == '\n':
            return (previous_end[0] + 1 == node.lineno) and (node.col_offset == 0)
        return (previous_end[0] == node.lineno) and (previous_end[1] == node.col_offset)

    def visit_TemplateWrite(self, node):
        if not isinstance(node.value, ast.Str):
            self.working_node = None
            return node
        if not node.value.s:
            return None
        if self.working_node is not None:
            run = self.node_lists[self.working_node]
            if self._continues_run(run[-1], node):
                run.append(node)
                return None
        self.working_node = node
        self.node_lists[node] = [node]
        return node

def merge_nodes(node, nodes):
    """Fold the text of every write in *nodes* into *node*.

    *node*'s string becomes the concatenation of all the strings, its end
    becomes the end of the last write, and its value is given *node*'s
    start and end positions.
    """
    last = nodes[-1]
    node.value.s = ''.join(n.value.s for n in nodes)
    node._end = last._end
    value = node.value
    value.lineno = node.lineno
    value.col_offset = node.col_offset
    value._end = node._end

def flatten_lists(node):
    """Recursively splice nested lists into their parent lists, in place.

    Lists are rewritten via slice assignment, so references to them stay
    valid; AST nodes have each of their fields flattened.
    """
    if isinstance(node, list):
        flattened = []
        for element in node:
            flatten_lists(element)
            if isinstance(element, list):
                flattened.extend(element)
            else:
                flattened.append(element)
        node[:] = flattened
        return None
    if isinstance(node, ast.AstBase):
        for (_, _, child) in iter_fields(node):
            flatten_lists(child)
    return None

def print_tree(node, indent=''):
    """Pretty-print *node* (an AST node, a list, or a plain value) to stdout.

    Located AST nodes print as ``Type@line:col/end_line:end_col:``.  Each
    field name is printed one level deeper and its value two levels deeper.
    ``None`` prints nothing.
    """
    if node is None:
        return
    if isinstance(node, list):
        print(indent + '[')
        for child in node:
            print_tree(child, indent)
        print(indent + ']')
        return
    if not isinstance(node, ast.AstBase):
        print('%s%r' % (indent, node))
        return
    type_name = type(node).__name__
    if hasattr(node, 'lineno'):
        print('%s%s@%d:%d/%d:%d:' % (indent, type_name, node.lineno, node.col_offset, node._end[0], node._end[1]))
    else:
        print('%s%s:' % (indent, type_name))
    field_indent = indent + '    '
    for (field_name, _, value) in iter_fields(node):
        print('%s%s' % (field_indent, field_name))
        print_tree(value, field_indent + '    ')
# Spitfire operator spelling -> semmle AST operator class.  Classes (not
# instances) are stored; the translator instantiates one per use.
ARITHMETIC_OPS = {
    '+': ast.Add,
    '-': ast.Sub,
    '*': ast.Mult,
    '/': ast.Div,
    '%': ast.Mod,
    '**': ast.Pow,
    '<<': ast.LShift,
    '>>': ast.RShift,
    '&': ast.BitAnd,
    '|': ast.BitOr,
}
BOOL_OPS = {
    'and': ast.And,
    'or': ast.Or,
}
# Both '<>' and '!=' spell inequality.
COMPARE_OPS = {
    '<': ast.Lt,
    '>': ast.Gt,
    '<=': ast.LtE,
    '>=': ast.GtE,
    '==': ast.Eq,
    '<>': ast.NotEq,
    '!=': ast.NotEq,
    'in': ast.In,
    'is': ast.Is,
    'is not': ast.IsNot,
}
UNARY_OPS = {
    '+': ast.UAdd,
    '-': ast.USub,
    '~': ast.Invert,
    'not': ast.Not,
}

def transform_ast(node, lines, name, strict_resolution):
    """Convert Spitfire template AST *node* into a located semmle Python AST.

    The passes are: translation, symbol collection, scope resolution,
    source location (against *lines*), then merging of adjacent text writes.
    *name* is the template name; *strict_resolution* selects plain attribute
    access instead of template dotted-notation lookups.
    """
    translator = AstTranslator(strict_resolution)
    module_tree = translator.transform(name, node)

    collector = Symbolizer()
    collector.visit(module_tree)
    Variablizer(collector.defined_variables).visit(module_tree)

    Locator(lines, translator.originals).traverse(module_tree)
    return TextMerger().merge(module_tree)
if (__name__ == '__main__'):
    # Debug entry point: translate the template named on the command line
    # and dump the resulting tree to stdout.
    from spitfire.compiler import util
    template_path = sys.argv[1]
    # Read via a context manager so the file handle is closed promptly.
    with open(template_path) as template_file:
        src = template_file.read()
    (_, filename) = os.path.split(template_path)
    the_ast = util.parse_file(template_path)
    # The template name is everything before the first '.'.  The former
    # ``(name, _) = filename.split('.')`` raised ValueError for file names
    # containing more than one dot (e.g. 'page.en.spt').
    name = filename.partition('.')[0]
    py_ast = transform_ast(the_ast, src.split('\n'), name, False)
    print_tree(py_ast)
