
from __future__ import print_function
import sys
import os
import inspect
import pkgutil
from semmle.python import ast
from semmle.python.passes.exports import ExportsPass
from semmle.python.passes.lexical import LexicalPass
from semmle.python.passes.flow import FlowPass
from semmle.python.passes.ast_pass import ASTPass
from semmle.python.passes.objects import ObjectPass
from semmle.util import VERSION, PY2, uuid, get_analysis_version, get_analysis_major_version
from semmle.util import makedirs, unicode, str_to_unicode, get_source_file_tag, TrapWriter, base64digest
from semmle.cache import Cache
from semmle.logging import WARN
from semmle.profiling import timers
# Names exported by ``from <this module> import *``.
__all__ = ['Extractor', 'CachingExtractor']

# Extra component mixed into cache keys by CachingExtractor.get_cache_key()
# when splitting is disabled; it embeds the extractor VERSION.
UTRAP_KEY = 'utrap%s' % VERSION

# 2 when PY2 is set, otherwise 3.
PY_VERSION = 2 if PY2 else 3

# Exact value types that Extractor.write_interpreter_data() records as flags.
# ``unicode`` and ``long`` are only listed on the PY2 branch.
if not PY2:
    FLAG_SAVE_TYPES = (float, complex, bool, int, bytes, str)
else:
    FLAG_SAVE_TYPES = (float, complex, bool, int, str, unicode, long)

class Extractor(object):
    """Turn Python modules into trap files and archive their source.

    Each module is run through the AST, exports and flow passes, then the
    lexical and object passes. The resulting trap is written to
    ``trap_folder`` and the module's bytes are copied into ``src_archive``.
    """

    def __init__(self, trap_folder, src_archive, options, logger):
        """Create the extraction passes.

        trap_folder -- destination for trap data; must be truthy and provide
            ``write_trap``.
        src_archive -- archive providing ``get_virtual_path`` and ``write``.
        options -- options object; ``split``, ``prune``, ``unroll``,
            ``omit_hash`` and ``no_syntax_errors`` are read here.
        logger -- logger providing ``debug``, ``error`` and ``traceback``.
        """
        assert trap_folder
        self.trap_folder = trap_folder
        self.src_archive = src_archive
        self.object_pass = ObjectPass()
        self.passes = [ASTPass(), ExportsPass(), FlowPass(options.split, options.prune, options.unroll)]
        self.lexical = LexicalPass()
        # Maps module tag -> virtual path of every source file archived so far.
        self.files = {}
        self.options = options
        self.omit_hash = options.omit_hash
        self.handle_syntax_errors = not options.no_syntax_errors
        self.logger = logger

    @staticmethod
    def _syntax_error_position(module, ex):
        """Return the ``(line, col)`` at which syntax error `ex` is reported.

        A missing line number or offset counts as 0. A line number past the
        end of ``module.lines`` is clamped to the last line, with the column
        set to that line's length minus one. If the module has no lines at
        all, ``(0, 0)`` is returned.
        """
        line = ex.lineno if ex.lineno else 0
        line_count = len(module.lines)
        if line <= line_count:
            return line, (ex.offset if ex.offset else 0)
        if not line_count:
            # Nothing to clamp to: ``module.lines[-1]`` would raise IndexError.
            return 0, 0
        return line_count, len(module.lines[-1]) - 1

    def _handle_syntax_error(self, module, ex):
        """Write a 'syntax-error' trap for `module` and return an empty AST.

        The trap holds a zeroed location for the module, a location for the
        error itself and a ``py_syntax_error_versioned`` tuple carrying
        ``ex.msg``. The returned ``ast.Module([])`` lets extraction continue
        as if the module were empty.
        """
        self.logger.debug('Emitting trap for syntax error in %s', module.path)
        writer = TrapWriter()
        module_id = writer.get_node_id(module)
        line, col = self._syntax_error_position(module, ex)
        loc_id = writer.get_unique_id()
        writer.write_tuple('locations_ast', 'rrdddd', loc_id, module_id, 0, 0, 0, 0)
        syntax_id = 'syntax%d:%d' % (line, col)
        # Start and end of the error location coincide, at column ``col + 1``.
        writer.write_tuple('locations_ast', 'nrdddd', syntax_id, module_id, line, col + 1, line, col + 1)
        writer.write_tuple('py_syntax_error_versioned', 'nss', syntax_id, ex.msg, get_analysis_major_version())
        trap = writer.get_compressed()
        self.trap_folder.write_trap('syntax-error', module.path, trap)
        return ast.Module([])

    def _extract_trap_file(self, ast, comments, path):
        """Run every pass over `ast` and return the compressed trap data.

        Returns None if any pass raises; the failure is logged rather than
        propagated.
        """
        writer = TrapWriter()
        file_tag = get_source_file_tag(self.src_archive.get_virtual_path(path))
        writer.write_tuple('py_Modules', 'g', ast.trap_name)
        writer.write_tuple('py_module_path', 'gg', ast.trap_name, file_tag)
        try:
            for extraction_pass in self.passes:
                with timers[extraction_pass.name]:
                    extraction_pass.extract(ast, writer)
            with timers['lexical']:
                self.lexical.extract(ast, comments, writer)
            with timers['object']:
                self.object_pass.extract(ast, path, writer)
        except Exception as ex:
            self.logger.error('Exception extracting module %s: %s', path, ex)
            self.logger.traceback(WARN)
            return None
        return writer.get_compressed()

    def process_source_module(self, module):
        """Extract `module` and return its compressed trap data.

        If reading ``module.ast`` raises SyntaxError, either a syntax-error
        trap is emitted and an empty AST is extracted instead, or, when
        syntax-error handling is disabled, None is returned. None is also
        returned when extraction itself fails.
        """
        try:
            ast = module.ast
        except SyntaxError as ex:
            self.logger.debug('handle syntax errors is %s', self.handle_syntax_errors)
            if not self.handle_syntax_errors:
                return None
            ast = self._handle_syntax_error(module, ex)
        ast.name = module.name
        ast.kind = module.kind
        ast.trap_name = module.trap_name
        return self.process_module(ast, module.trap_name, module.bytes_source, module.path, module.comments)

    def process_module(self, ast, module_tag, bytes_source, path, comments):
        """Write the trap for `ast` and archive `bytes_source`.

        Returns the compressed trap data, or None if extraction failed (in
        which case nothing is written). A failure while archiving the source
        is logged and does not affect the return value.
        """
        self.logger.debug('Populating trap file for %s', path)
        ast.trap_name = module_tag
        trap = self._extract_trap_file(ast, comments, path)
        if trap is None:
            return None
        with timers['trap']:
            self.trap_folder.write_trap('python', path, trap)
        try:
            with timers['archive']:
                self.copy_source(bytes_source, module_tag, path)
        except Exception:
            # Archiving is best-effort; report through the logger, as the
            # other error paths in this file do.
            self.logger.traceback(WARN)
        return trap

    def copy_source(self, bytes_source, module_tag, path):
        """Record `path` under `module_tag` and write it to the archive.

        Does nothing when `bytes_source` is None.
        """
        if bytes_source is None:
            return
        self.files[module_tag] = self.src_archive.get_virtual_path(path)
        self.src_archive.write(path, bytes_source)

    def write_interpreter_data(self, options):
        """Write the 'flags', 'interpreter' and 'stdlib' trap files.

        The flags trap records the running interpreter's version, the public
        members of ``sys.flags``, ``sys.float_info`` and ``self.options``
        whose type is in FLAG_SAVE_TYPES, and assorted path, platform and
        version settings taken from `options`.
        """
        writer = TrapWriter()

        def write_flag(name, value):
            writer.write_tuple('py_flags_versioned', 'uus', name, value, get_analysis_major_version())

        def write_flags(obj, prefix):
            pre = prefix + '.'
            for name, value in inspect.getmembers(obj):
                if name.startswith('_'):
                    continue
                # Exact type match: subclasses of the listed types are skipped.
                if type(value) in FLAG_SAVE_TYPES:
                    write_flag(pre + str_to_unicode(name), unicode(value))

        for index, name in enumerate(('major', 'minor', 'micro', 'releaselevel', 'serial')):
            writer.write_tuple('py_flags_versioned', 'sss', 'version.' + name, str(sys.version_info[index]), get_analysis_major_version())
        write_flags(sys.flags, 'flags')
        write_flags(sys.float_info, 'float')
        write_flags(self.options, 'options')
        write_flag('sys.prefix', sys.prefix)
        write_flag('sys.path', os.pathsep.join(os.path.abspath(p) for p in options.sys_path))
        if options.path is None:
            extractor_path = ''
        else:
            extractor_path = os.pathsep.join(self.src_archive.get_virtual_path(p) for p in options.path)
        if options.language_version:
            # The last entry of the list is the one recorded.
            write_flag('language.version', options.language_version[-1])
        else:
            write_flag('language.version', get_analysis_version())
        write_flag('extractor.path', extractor_path)
        write_flag('sys.platform', sys.platform)
        write_flag('os.sep', os.sep)
        write_flag('os.pathsep', os.pathsep)
        write_flag('extractor.version', VERSION)
        if options.context_cost is not None:
            write_flag('context.cost', options.context_cost)
        self.trap_folder.write_trap('flags', '$flags', writer.get_compressed())

        if get_analysis_major_version() == 2:
            # Version 2 analysis uses the interpreter trap shipped as package data.
            builtins_trap_data = pkgutil.get_data('semmle.data', 'interpreter2.trap')
            self.trap_folder.write_trap('interpreter', '$interpreter2', builtins_trap_data, extension='.trap')
            stdlib_trap_name = '$stdlib_27.trap'
        else:
            special_writer = TrapWriter()
            self.object_pass.write_special_objects(special_writer)
            self.trap_folder.write_trap('interpreter', '$interpreter3', special_writer.get_compressed())
            stdlib_trap_name = '$stdlib_33.trap'
        stdlib_trap_data = pkgutil.get_data('semmle.data', stdlib_trap_name)
        # Strip the trailing '.trap'; it is supplied again via ``extension``.
        self.trap_folder.write_trap('stdlib', stdlib_trap_name[:-5], stdlib_trap_data, extension='.trap')

    @staticmethod
    def from_options(options, trap_dir, archive, logger):
        """Build an extractor, wrapped in a CachingExtractor when possible.

        If the caching extractor cannot be created, a plain Extractor is
        returned. The failure is printed only when ``options.verbose`` is set
        and a trap cache directory was actually configured.
        """
        # Read outside the ``try`` so the handler below can always refer to it.
        trap_copy_dir = getattr(options, 'trap_cache', None)
        try:
            caching_extractor = CachingExtractor(trap_copy_dir, options, logger)
        except Exception as ex:
            if options.verbose and trap_copy_dir is not None:
                print('Failed to create caching extractor: ' + str(ex))
            caching_extractor = None
        worker = Extractor(trap_dir, archive, options, logger)
        if not caching_extractor:
            return worker
        caching_extractor.set_worker(worker)
        return caching_extractor

    def stop(self):
        """Do nothing; present so both extractor classes share an interface."""
        pass

    def close(self):
        """Write the list of archived files and log the per-pass timings."""
        if self.files:
            trapwriter = TrapWriter()
            for filepath in self.files.values():
                trapwriter.write_file(filepath)
            self.trap_folder.write_trap('folders', uuid('python') + '/$files', trapwriter.get_compressed())
            # Reset to a dict (not a set) so copy_source() still works afterwards.
            self.files = {}
        for name, timer in sorted(timers.items()):
            self.logger.debug("Total time for pass '%s': %0.0fms", name, timer.elapsed)

def hash_combine(x, y):
    """Return the base64 digest of `x` and `y` joined by a colon."""
    joined = x + ':' + y
    return base64digest(joined)

class CachingExtractor(object):
    """Extractor front-end that reuses trap data stored in an on-disk cache.

    Real extraction is delegated to a worker supplied through set_worker().
    process_source_module() consults the cache first and only calls the
    worker when no cached trap exists for the module.
    """

    def __init__(self, cachedir, options, logger):
        """Open (creating it if needed) the cache in `cachedir`.

        Raises IOError if `cachedir` is None.
        """
        if cachedir is None:
            raise IOError('No cache directory')
        makedirs(cachedir)
        # Must be supplied via set_worker() before modules are processed.
        self.worker = None
        self.cache = Cache.for_directory(cachedir, options.verbose)
        self.logger = logger
        self.split = options.split

    def set_worker(self, worker):
        """Set the extractor that performs the actual work."""
        self.worker = worker

    def get_cache_key(self, module):
        """Return the cache key for `module`.

        The key is derived from the module's path and source. When splitting
        is disabled, UTRAP_KEY is mixed in as well, so those entries get
        different keys from the ones produced with splitting enabled.
        """
        key = hash_combine(module.path, module.source)
        if not self.split:
            key = hash_combine(UTRAP_KEY, key)
        # NOTE(review): module.source is folded in a second time here. Left
        # as is, since altering the derivation changes every cache key.
        return hash_combine(key, module.source)

    def process_source_module(self, module):
        """Return the trap for `module`, from the cache when available.

        On a cache miss the worker extracts the module and a non-None result
        is stored in the cache. On a hit the cached trap is written to the
        worker's trap folder and the source is archived; an archiving failure
        is logged and otherwise ignored.

        Raises Exception if no worker has been set.
        """
        if self.worker is None:
            raise Exception('worker is not set')
        key = self.get_cache_key(module)
        trap = self.cache.get(key)
        if trap is not None:
            self.logger.debug('Found cached trap file for %s', module.path)
            self.worker.trap_folder.write_trap('python', module.path, trap)
            try:
                self.worker.copy_source(module.bytes_source, module.trap_name, module.path)
            except Exception:
                self.logger.traceback(WARN)
            return trap
        trap = self.worker.process_source_module(module)
        if trap is not None:
            self.cache.set(key, trap)
        return trap

    def process_module(self, ast, module_tag, source_code, path, comments):
        """Delegate to the worker, bypassing the cache, and return its result."""
        return self.worker.process_module(ast, module_tag, source_code, path, comments)

    def close(self):
        """Close the worker."""
        self.worker.close()

    def write_interpreter_data(self, sys_path):
        """Delegate to the worker.

        Despite its name, `sys_path` is passed straight through as the
        worker's ``options`` argument.
        """
        self.worker.write_interpreter_data(sys_path)

    def stop(self):
        """Stop the worker."""
        self.worker.stop()
