
from __future__ import absolute_import
import sys
from semmle.python import extractor, finder, modules
from semmle import util
from semmle.extractors.base import BaseExtractor
from semmle.logging import WARN
# Module-level flag forwarded as the last argument to
# ast_transformer.transform_ast() when a template is parsed.
# NOTE(review): its exact effect is defined by the transformer - confirm there.
strict_resolution = False

class LoadedSpitfireTemplate(modules.PythonSourceModule):
    """A Spitfire template file loaded and parsed as a source module.

    Construction reads the template from disk, keeps its decoded text, its
    lines and a digest of the bytes read, then parses it with Spitfire and
    converts the result with ``ast_transformer.transform_ast``.
    """

    kind = 'Spitfire template'

    def __init__(self, name, path, logger):
        """Read and parse the template at *path*.

        Args:
            name: Module name, forwarded to ``PythonSourceModule.__init__``.
            path: Path of the template file, forwarded to the base class and
                handed to Spitfire's ``parse_file``.
            logger: Forwarded to the base class; also used to report parse
                failures.

        Raises:
            Exception: Any error raised while parsing or transforming the
                template is logged and then re-raised unchanged.
        """
        # Imported here rather than at module level, so importing this module
        # does not itself require Spitfire to be importable.
        import spitfire.compiler.util
        import semmle.extractors.spitfire.ast_transformer
        import spitfire.compiler.ast
        modules.PythonSourceModule.__init__(self, name, path, logger)
        # The 'U' flag was removed from open() in Python 3.11 and raises
        # ValueError there. On Python 3 it never affected binary reads, so
        # plain 'rb' is equivalent; it is kept only for Python 2, where it
        # enables universal-newline translation.
        read_mode = 'rbU' if sys.version_info[0] < 3 else 'rb'
        with open(self.path, read_mode) as src:
            code = src.read()
        self._source = code.decode('utf8')
        self._lines = self._source.split('\n')
        self._secure_hash = modules.base64digest(code)
        # Comments are collected lazily on the first get_comments() call.
        self._comments = None
        try:
            spitfire_ast = spitfire.compiler.util.parse_file(path)
            # NOTE(review): self.lines / self.trap_name are not defined in this
            # class; presumably provided by PythonSourceModule - confirm.
            self._ast = semmle.extractors.spitfire.ast_transformer.transform_ast(
                spitfire_ast, self.lines, self.name, strict_resolution)
            self._ast.trap_name = self.trap_name
        except Exception as ex:
            logger.error('Failed to parse template %s: %r', path, ex)
            logger.traceback(WARN)
            # Bare raise re-raises the active exception with its original
            # traceback intact.
            raise

    @property
    def ast(self):
        """The transformed AST built in ``__init__`` (``None`` after close())."""
        return self._ast

    @property
    def source(self):
        """The decoded template text (``None`` after close())."""
        return self._source

    def get_encoding(self):
        """Return the encoding used to decode the template; always 'utf8'."""
        return 'utf8'

    @property
    def py_ast(self):
        """Always ``None`` for a Spitfire template."""
        return None

    def get_comments(self):
        """Return the list of ``##`` comments, scanning the lines on first use.

        Each entry is ``(text, (row, start_col), (row, end_col))`` as built by
        ``_lexical``.
        """
        if self._comments is None:
            self._lexical()
        return self._comments

    def _lexical(self):
        """Populate ``self._comments`` with one entry per line containing '##'.

        The comment text is everything after the first '##' on the line.
        """
        # NOTE(review): the row is the 0-based index from enumerate() while the
        # start column is the '##' offset plus 3; confirm this matches what
        # consumers of get_comments() expect.
        self._comments = []
        for (i, line) in enumerate(self.lines):
            col = line.find('##')
            if col < 0:
                continue
            self._comments.append(
                (line[(col + 2):], (i, (col + 3)), (i, len(line))))

    def close(self):
        """Drop the references to the AST and the source text."""
        self._ast = None
        self._source = None

class SpitfireExtractor(BaseExtractor):
    """Extractor that handles Spitfire template files ('.tmpl' / '.spt')."""

    name = 'spitfire extractor'

    # File suffixes recognised as Spitfire templates.
    _TEMPLATE_SUFFIXES = ('.tmpl', '.spt')

    def __init__(self, options, trap_folder, src_archive, logger):
        """Set up the module extractor and finder from *options*.

        If ``options.spitfire_path`` is set, it is added to the end of
        ``sys.path`` unless it is already present.
        """
        super(SpitfireExtractor, self).__init__(options, trap_folder, src_archive, logger)
        self.module_extractor = extractor.Extractor.from_options(options, trap_folder, src_archive, logger)
        self.finder = finder.Finder.from_options_and_env(options, logger)
        spitfire_path = options.spitfire_path
        # Skip the addition when the entry is already there, so constructing
        # several extractors does not pile up duplicate sys.path entries.
        # A new list is assigned (not mutated in place), so any saved
        # reference to the previous sys.path list is left untouched.
        if spitfire_path and spitfire_path not in sys.path:
            sys.path = (sys.path + [spitfire_path])

    def process(self, unit):
        """Extract *unit* if it is a Spitfire template file.

        Returns:
            ``NotImplemented`` when *unit* is not a ``util.FileExtractable``
            or its path has neither template suffix. Otherwise a one-element
            list holding the extractable of ``spitfire.runtime.template``,
            returned after the template itself has been processed.
        """
        if not isinstance(unit, util.FileExtractable):
            return NotImplemented
        if not unit.path.endswith(self._TEMPLATE_SUFFIXES):
            return NotImplemented
        name = self.finder.name_from_path(unit.path, self._TEMPLATE_SUFFIXES)
        assert name, unit.path
        module = LoadedSpitfireTemplate(name, unit.path, self.logger)
        self.module_extractor.process_source_module(module)
        # NOTE(review): assumes finder.find() yields a module object here; if
        # it can return None the next line raises AttributeError - confirm.
        template_module = self.finder.find('spitfire.runtime.template')
        return [template_module.get_extractable()]

    def close(self):
        """Close the underlying module extractor."""
        self.module_extractor.close()
