
from __future__ import print_function
from ast import literal_eval
import sys
import os
import semmle.python.parser
from semmle.python.parser.ast import copy_location, decode_str, split_string
from semmle.python import ast
import subprocess
from itertools import groupby
# Module-wide switch for diagnostic output: debug_print() only forwards its
# arguments to print() when this is true.
DEBUG = False

def debug_print(*args, **kwargs):
    """Forward all arguments to ``print`` when the module-level DEBUG flag is set.

    Does nothing (and returns None) while DEBUG is false.
    """
    if not DEBUG:
        return
    print(*args, **kwargs)

class Node(object):
    """Reference to a graph node, identified by its numeric id."""

    def __init__(self, id):
        # Identifier of the referenced graph node.
        self.id = id

    def __repr__(self):
        template = 'Node({})'
        return template.format(self.id)

class Comment(object):
    """Placeholder object for a comment; carries only the comment text."""

    def __init__(self, text):
        # The comment's text as handed to the constructor.
        self.text = text

    def __repr__(self):
        template = 'Comment({})'
        return template.format(self.text)

class SyntaxErrorNode(object):
    """Marker object for a syntax error; carries the offending source."""

    def __init__(self, source):
        # Whatever source value was attached to the error node.
        self.source = source

    def __repr__(self):
        template = 'SyntaxErrorNode({})'
        return template.format(self.source)
# Map from class name to class for every class in `semmle.python.ast` that has
# `ast.AstBase` in its MRO. The two helper classes defined in this file are
# added by hand. `parse` uses this table to look up a node's `_kind`.
tsg_to_ast = {name: cls for (name, cls) in semmle.python.ast.__dict__.items() if (isinstance(cls, type) and (ast.AstBase in cls.__mro__))}
tsg_to_ast['Comment'] = Comment
tsg_to_ast['SyntaxErrorNode'] = SyntaxErrorNode
# Map from class to the tuple of field names that class is expected to carry.
# Seeded with entries that the slot scan below would not produce (classes from
# this file, and classes listed with no fields).
ast_fields = {ast.Module: ('body',), Comment: ('text',), SyntaxErrorNode: ('source',), ast.Continue: (), ast.Break: (), ast.Pass: (), ast.Ellipsis: (), ast.MatchWildcardPattern: ()}
# Slots declared on the common base class are not treated as per-class fields.
ignored_fields = semmle.python.ast.AstBase.__slots__
# Derive the remaining entries from `__slots__` of every public name in
# `semmle.python.ast`. A seeded entry is overwritten if the class has any
# non-ignored slots; classes whose remaining slots are empty are skipped.
for (name, cls) in semmle.python.ast.__dict__.items():
    if name.startswith('_'):
        continue
    if (not hasattr(cls, '__slots__')):
        continue
    slots = tuple((field for field in cls.__slots__ if (field not in ignored_fields)))
    if (not slots):
        continue
    ast_fields[cls] = slots
# Map from a textual token to the AST class it stands for. Elsewhere in this
# file the looked-up class is instantiated with no arguments (contexts,
# operators). Extended below with the operator tables from the parser module.
locationless = {'and': ast.And, 'or': ast.Or, 'not': ast.Not, 'uadd': ast.UAdd, 'usub': ast.USub, '+': ast.Add, '-': ast.Sub, '~': ast.Invert, '**': ast.Pow, '<<': ast.LShift, '>>': ast.RShift, '&': ast.BitAnd, '|': ast.BitOr, '^': ast.BitXor, 'load': ast.Load, 'store': ast.Store, 'del': ast.Del, 'param': ast.Param}
locationless.update(semmle.python.parser.ast.TERM_OP_CLASSES)
locationless.update(semmle.python.parser.ast.COMP_OP_CLASSES)
locationless.update(semmle.python.parser.ast.AUG_ASSIGN_OPS)
# Decide how to invoke tsg-python; the result is the argv prefix `tsg_command`,
# to which read_tsg_python_output appends the file path.
if ('CODEQL_EXTRACTOR_PYTHON_ROOT' in os.environ):
    # Use the executable at <root>/tools/<platform>/tsg-python, with an
    # `.exe` suffix on win64.
    # NOTE: CODEQL_PLATFORM is read unconditionally here, so a missing
    # variable raises KeyError at import time.
    platform = os.environ['CODEQL_PLATFORM']
    ext = ('.exe' if (platform == 'win64') else '')
    tools = os.path.join(os.environ['CODEQL_EXTRACTOR_PYTHON_ROOT'], 'tools', platform)
    tsg_command = [os.path.join(tools, ('tsg-python' + ext))]
else:
    # Otherwise run tsg-python through cargo, using the Cargo.toml found at
    # ../../../tsg-python relative to this file's resolved location.
    script_path = os.path.dirname(os.path.realpath(__file__))
    tsg_python_path = os.path.join(script_path, '../../../tsg-python')
    cargo_file = os.path.join(tsg_python_path, 'Cargo.toml')
    tsg_command = ['cargo', 'run', '--quiet', '--release', ('--manifest-path=' + cargo_file)]

def read_tsg_python_output(path, logger):
    """Run tsg-python on `path` and parse the graph it prints on stdout.

    The output is read line by line:

    * a line starting with ``node`` opens the attribute block of a node;
    * a line starting with ``edge`` (``edge A -> B``) opens the attribute
      block of an edge from A to B;
    * any other line is treated as an attribute line of the form
      ``  key: value`` belonging to the most recently opened block.

    Returns a pair ``(node_attr, edge_attr)``:

    * ``node_attr`` maps a node id to a dict ``{key: value}``;
    * ``edge_attr`` maps a start node id to a dict
      ``{key: [(value, end_node_id), ...]}``.

    Values that cannot be evaluated are kept as their raw text and a warning
    is logged. The child process is always closed, terminated and waited
    for, even when parsing raises.
    """
    node_attr = {}
    edge_attr = {}
    command_args = (tsg_command + [path])
    p = subprocess.Popen(command_args, stdout=subprocess.PIPE)
    try:
        for line in p.stdout:
            line = line.decode(sys.getfilesystemencoding())
            line = line.rstrip()
            if line.startswith('node'):
                current_node = int(line.split(' ')[1])
                d = {}
                node_attr[current_node] = d
                in_node = True
            elif line.startswith('edge'):
                (current_start, current_end) = tuple(map(int, line[4:].split('->')))
                # All edges leaving the same start node share one dict.
                d = edge_attr.setdefault(current_start, {})
                in_node = False
            else:
                # Attribute line: drop the two leading characters, then split
                # on the first ': ' only, since the value may contain ': '.
                (key, value) = line[2:].split(': ', 1)
                if value.startswith('[graph node'):
                    # '[graph node N]' -> reference to node N.
                    value = Node(int(value.split(' ')[2][:(- 1)]))
                elif (value == '#true'):
                    value = True
                elif (value == '#false'):
                    value = False
                elif (value == '#null'):
                    value = None
                else:
                    try:
                        if ((key == 's') and (value[0] == '"')):
                            value = evaluate_string(value)
                        else:
                            value = literal_eval(value)
                            if isinstance(value, bytes):
                                try:
                                    value = value.decode(sys.getfilesystemencoding())
                                except UnicodeDecodeError:
                                    # Undecodable bytes are kept as bytes.
                                    pass
                    except Exception as ex:
                        # Best effort: keep the raw text and report where it
                        # happened, using the attributes read so far.
                        loc = ':'.join((str(i) for i in get_location_info(d)))
                        error = (ex.args[0] if ex.args else 'unknown')
                        logger.warning("Error '{}' while parsing value {} at {}:{}\n".format(error, repr(value), path, loc))
                if in_node:
                    d[key] = value
                else:
                    d.setdefault(key, []).append((value, current_end))
    finally:
        # Release the pipe and reap the child on every exit path, so that a
        # malformed line does not leak the process.
        p.stdout.close()
        p.terminate()
        p.wait()
    logger.info('Read {} nodes and {} edges from TSG output'.format(len(node_attr), len(edge_attr)))
    return (node_attr, edge_attr)

def evaluate_string(s):
    """Evaluate a quoted string attribute down to the string value it denotes.

    `s` is first evaluated as a literal; the result is split by
    `split_string` into prefix, quotes and content. The literal is then
    rebuilt without any ``f``/``F`` at the ends of the prefix, and evaluated
    again. If the content ends with the quote character or a backslash, a
    single space is appended before evaluation and removed afterwards, so
    that the rebuilt literal does not end in an escaped or doubled quote.
    Bytes results are passed through `decode_str`.
    """
    unquoted = literal_eval(s)
    (prefix, quotes, content) = split_string(unquoted, None)
    needs_padding = content.endswith((quotes[0], '\\'))
    if needs_padding:
        content = content + ' '
    rebuilt = prefix.strip('fF') + quotes + content + quotes
    result = literal_eval(rebuilt)
    if isinstance(result, bytes):
        result = decode_str(result)
    # Drop the padding character that was added above, if any.
    return result[:-1] if needs_padding else result

def resolve_node_id(id, node_attr):
    """Follow `_skip_to` references from node `id` and return the final id.

    `node_attr` maps node ids to attribute dicts. As long as the current
    node's attributes contain `_skip_to`, move on to the id of the node it
    points at. The first id without a `_skip_to` attribute is returned.
    """
    current = id
    attrs = node_attr[current]
    while '_skip_to' in attrs:
        current = attrs['_skip_to'].id
        attrs = node_attr[current]
    return current

def get_context(id, node_attr, logger):
    """Return a fresh context object for node `id`.

    Starting at `id`, follow `_inherited_ctx` references until a node with a
    `ctx` attribute is found, then instantiate the class that `locationless`
    maps that attribute to. If a node on the way has neither `ctx` nor
    `_inherited_ctx`, an error is logged and `ast.Load()` is returned.
    """
    attrs = node_attr[id]
    while 'ctx' not in attrs:
        if '_inherited_ctx' in attrs:
            id = attrs['_inherited_ctx'].id
            attrs = node_attr[id]
            continue
        logger.error('No context for node {} with attributes {}\n'.format(id, attrs))
        return ast.Load()
    ctx_class = locationless[attrs['ctx']]
    return ctx_class()

def get_location_info(attrs):
    """Extract ``(start_line, start_column, end_line, end_column)`` from `attrs`.

    Sources are applied in this order, later ones overriding earlier ones:
    `_location` (all four values), `_location_start` / `_location_end`
    (line and column pairs), then the individual `_start_line`,
    `_start_column`, `_end_line` and `_end_column` attributes. Components
    that no attribute provides are reported as the string ``'???'``.
    Line numbers that are known are incremented by one; columns are
    returned unchanged.
    """
    unknown = '???'
    names = ('start_line', 'start_column', 'end_line', 'end_column')
    loc = dict.fromkeys(names, unknown)

    if '_location' in attrs:
        (loc['start_line'], loc['start_column'], loc['end_line'], loc['end_column']) = attrs['_location']
    if '_location_start' in attrs:
        (loc['start_line'], loc['start_column']) = attrs['_location_start']
    if '_location_end' in attrs:
        (loc['end_line'], loc['end_column']) = attrs['_location_end']

    # Individual components take precedence over the combined attributes.
    for name in names:
        key = '_' + name
        if key in attrs:
            loc[name] = attrs[key]

    # Shift known line numbers by one; columns are left as they are.
    for name in ('start_line', 'end_line'):
        if loc[name] != unknown:
            loc[name] += 1

    return tuple(loc[name] for name in names)
# For each AST class, the names of the fields that create_placeholder_args
# initialises to an empty list instead of None.
list_fields = {ast.arguments: ('annotations', 'defaults', 'kw_defaults', 'kw_annotations'), ast.Assign: ('targets',), ast.BoolOp: ('values',), ast.Bytes: ('implicitly_concatenated_parts',), ast.Call: ('positional_args', 'named_args'), ast.Case: ('body',), ast.Class: ('body',), ast.ClassExpr: ('bases', 'keywords'), ast.Compare: ('ops', 'comparators'), ast.comprehension: ('ifs',), ast.Delete: ('targets',), ast.Dict: ('items',), ast.ExceptStmt: ('body',), ast.For: ('body',), ast.Function: ('args', 'kwonlyargs', 'body'), ast.Global: ('names',), ast.If: ('body',), ast.Import: ('names',), ast.List: ('elts',), ast.Match: ('cases',), ast.MatchClassPattern: ('positional', 'keyword'), ast.MatchMappingPattern: ('mappings',), ast.MatchOrPattern: ('patterns',), ast.MatchSequencePattern: ('patterns',), ast.Module: ('body',), ast.Nonlocal: ('names',), ast.Print: ('values',), ast.Set: ('elts',), ast.Str: ('implicitly_concatenated_parts',), ast.Try: ('body', 'handlers', 'orelse', 'finalbody'), ast.Tuple: ('elts',), ast.While: ('body',)}

def create_placeholder_args(cls):
    """Build the keyword arguments used to instantiate `cls` with placeholders.

    `ast.Raise` and `ast.Ellipsis` get no arguments at all. For any other
    class, every field listed in `ast_fields` except `is_async` is set to
    None, and fields listed in `list_fields` are set to a fresh empty list.
    For the comprehension classes, the `function` and `iterable` entries are
    removed again.
    """
    if cls in (ast.Raise, ast.Ellipsis):
        return {}
    placeholders = dict.fromkeys(
        name for name in ast_fields[cls] if name != 'is_async'
    )
    # Each list-valued field needs its own list object.
    placeholders.update((name, []) for name in list_fields.get(cls, ()))
    if cls in (ast.GeneratorExp, ast.ListComp, ast.SetComp, ast.DictComp):
        for dropped in ('function', 'iterable'):
            del placeholders[dropped]
    return placeholders

def parse(path, logger):
    """Build an AST for the file at `path` from the tsg-python graph output.

    The graph is read with read_tsg_python_output and processed in passes:

    1. create one object per graph node (literals are stored as-is);
    2. fill in locations, contexts and attribute values from node attributes;
    3. attach child lists from the edges;
    4. post-process string nodes that carry a `_fixup` attribute.

    Returns the object for node id 0 (used as the module), or an empty
    `ast.Module` if the graph is empty and the file has size zero.

    Raises SyntaxError if a SyntaxErrorNode is present in the graph, or if
    the graph is empty while the file is not.
    """
    (node_attr, edge_attr) = read_tsg_python_output(path, logger)
    debug_print('node_attr:', node_attr)
    debug_print('edge_attr:', edge_attr)
    # nodes: node id -> constructed object (or literal value).
    nodes = {}
    # fixups: node id -> object, for nodes carrying a `_fixup` attribute.
    fixups = {}
    # node_id: reverse map from constructed object to its node id.
    node_id = {}
    # Pass 1: instantiate an object for every node that describes one.
    for (id, attrs) in node_attr.items():
        if ('_is_literal' in attrs):
            # Literal nodes stand for a plain value rather than an AST object.
            nodes[id] = attrs['_is_literal']
            continue
        if ('_kind' not in attrs):
            logger.error('Error: Graph node {} with attributes {} has no `_kind`!\n'.format(id, attrs))
            continue
        if ('_skip_to' in attrs):
            # No object is created for such nodes; references to them are
            # redirected by resolve_node_id.
            continue
        cls = tsg_to_ast[attrs['_kind']]
        # NOTE(review): `args` is not used below; the lookup only has the
        # effect of raising KeyError for a class missing from ast_fields.
        args = ast_fields[cls]
        obj = cls(**create_placeholder_args(cls))
        nodes[id] = obj
        node_id[obj] = id
        if ('_fixup' in attrs):
            fixups[id] = obj
    # Pass 2: set location, context and scalar/reference attributes.
    for (id, node) in nodes.items():
        attrs = node_attr[id]
        if ('_is_literal' in attrs):
            continue
        expected_fields = ast_fields[type(node)]
        (node.lineno, node.col_offset, end_line, end_column) = get_location_info(attrs)
        node._end = (end_line, end_column)
        if isinstance(node, SyntaxErrorNode):
            # Any syntax error node in the graph aborts the whole parse.
            exc = SyntaxError('Syntax Error')
            exc.lineno = node.lineno
            exc.offset = node.col_offset
            raise exc
        if ('ctx' in expected_fields):
            node.ctx = get_context(id, node_attr, logger)
        for (field, val) in attrs.items():
            # Underscore-prefixed attributes are not copied onto the node.
            if field.startswith('_'):
                continue
            # `ctx` was already handled via get_context above.
            if (field == 'ctx'):
                continue
            # Unknown fields are reported but still set on the node below.
            if ((field != 'parenthesised') and (field not in expected_fields)):
                logger.warning('Unknown field {} found among {} in node {}\n'.format(field, attrs, id))
            if isinstance(val, Node):
                # Reference to another graph node: store the resolved object.
                val = resolve_node_id(val.id, node_attr)
                setattr(node, field, nodes[val])
            elif (isinstance(node, ast.Num) and (field == 'n')):
                # Trailing 'l'/'L' characters are stripped before evaluating
                # the number text.
                node.n = literal_eval(val.rstrip('lL'))
            elif (isinstance(node, ast.Name) and (field == 'variable')):
                node.variable = ast.Variable(val)
            elif ((field == 'op') and (val in locationless.keys())):
                setattr(node, field, locationless[val]())
            else:
                setattr(node, field, val)
    # Pass 3: turn edges into child lists on their start node.
    for (start, field_map) in edge_attr.items():
        start = resolve_node_id(start, node_attr)
        parent = nodes[start]
        extra_fields = {}
        for (field_name, value_end) in field_map.items():
            # Each entry is (edge attribute value, end node id); sorting the
            # pairs orders the children by that value.
            children = [nodes[resolve_node_id(end, node_attr)] for (_index, end) in sorted(value_end)]
            # Comments are dropped from child lists.
            children = [child for child in children if (not isinstance(child, Comment))]
            if (isinstance(parent, ast.Compare) and (field_name == 'ops')):
                # Here the children are keys of `locationless`; replace each
                # with a new instance of the mapped class.
                parent.ops = [locationless[v]() for v in children]
            elif field_name.startswith('_'):
                # Underscore-prefixed child lists are kept aside and merged
                # into the node's attributes (read by fix_strings).
                extra_fields[field_name] = children
            else:
                setattr(parent, field_name, children)
        if extra_fields:
            node_attr[start].update(extra_fields)
    # Pass 4: post-process string nodes that were flagged with `_fixup`.
    for (id, node) in fixups.items():
        if isinstance(node, (ast.JoinedStr, ast.Str)):
            fix_strings(id, node, node_attr, node_id, logger)
    debug_print('nodes:', nodes)
    if (not nodes):
        # No nodes at all: a zero-size file yields an empty module located at
        # line 1, column 0; anything else is reported as a syntax error.
        if (os.path.getsize(path) == 0):
            module = ast.Module([])
            module.lineno = 1
            module.col_offset = 0
            module._end = (1, 0)
            return module
        else:
            raise SyntaxError('Syntax Error')
    # Node 0 is taken to be the module.
    module = nodes[0]
    # The module's line number is that of its first statement, or its end
    # line when the body is empty.
    if module.body:
        module.lineno = module.body[0].lineno
    else:
        module.lineno = module._end[0]
    return module

def get_JoinedStr_children(children):
    """Yield the flattened parts of a sequence of string children.

    An `ast.JoinedStr` child contributes each element of its `values`; an
    `ast.StringPart` child is yielded as is. Any other child raises
    ValueError when the generator reaches it.
    """
    for item in children:
        if isinstance(item, ast.JoinedStr):
            for part in item.values:
                yield part
            continue
        if not isinstance(item, ast.StringPart):
            raise ValueError('Unexpected node type: {}'.format(type(item)))
        yield item

def concatenate_stringparts(stringparts, logger):
    """Return the concatenation of the decoded `s` values of `stringparts`.

    Each part's `s` attribute is passed through `decode_str` and the results
    are joined. If anything goes wrong, the error is logged and the `s`
    attribute of the first part is returned unchanged instead.
    """
    try:
        decoded = [decode_str(part.s) for part in stringparts]
        joined = ''.join(decoded)
    except Exception as ex:
        logger.error('Unable to concatenate string %s getting error %s', stringparts, ex)
        return stringparts[0].s
    return joined

def fix_strings(id, node, node_attr, node_id, logger):
    """Fill in the string-related fields of a Str or JoinedStr node.

    The parts are taken from the node's `_children` attribute in
    `node_attr`, defaulting to the node itself.

    For an `ast.Str`, the parts become `implicitly_concatenated_parts`, their
    concatenation becomes `s`, and the prefix is copied from the first part.

    For any other node (a JoinedStr, as called from `parse`), the parts are
    flattened and grouped into runs of adjacent StringParts and runs of other
    values. Each StringPart run is merged into a single `ast.Str` spanning
    from the first to the last part of the run; other runs are kept as they
    are. The resulting list is stored in `node.values`.

    `node_id` is accepted but not used here.
    """
    children = node_attr[id].get('_children', [node])

    if isinstance(node, ast.Str):
        node.implicitly_concatenated_parts = children
        node.s = concatenate_stringparts(children, logger)
        node.prefix = children[0].prefix
        return

    def is_string_part(candidate):
        return isinstance(candidate, ast.StringPart)

    # Materialise every run before building anything, so that an unexpected
    # child is reported before any node is created.
    flattened = get_JoinedStr_children(children)
    runs = [list(members) for (_, members) in groupby(flattened, key=is_string_part)]

    values = []
    for run in runs:
        head = run[0]
        if isinstance(head, ast.expr):
            values.extend(run)
            continue
        merged = ast.Str(concatenate_stringparts(run, logger), head.prefix, None)
        copy_location(head, merged)
        merged._end = run[-1]._end
        if len(run) > 1:
            merged.implicitly_concatenated_parts = run
        values.append(merged)
    node.values = values
