377 lines
13 KiB
Python
377 lines
13 KiB
Python
|
from __future__ import absolute_import
|
||
|
|
||
|
import hashlib
|
||
|
import inspect
|
||
|
import os
|
||
|
import re
|
||
|
import sys
|
||
|
|
||
|
from distutils.core import Distribution, Extension
|
||
|
from distutils.command.build_ext import build_ext
|
||
|
|
||
|
import Cython
|
||
|
from ..Compiler.Main import Context, default_options
|
||
|
|
||
|
from ..Compiler.Visitor import CythonTransform, EnvTransform
|
||
|
from ..Compiler.ParseTreeTransforms import SkipDeclarations
|
||
|
from ..Compiler.TreeFragment import parse_from_strings
|
||
|
from ..Compiler.StringEncoding import _unicode
|
||
|
from .Dependencies import strip_string_literals, cythonize, cached_function
|
||
|
from ..Compiler import Pipeline
|
||
|
from ..Utils import get_cython_cache_dir
|
||
|
import cython as cython_module
|
||
|
|
||
|
|
||
|
IS_PY3 = sys.version_info >= (3,)

# A utility function to convert user-supplied ASCII strings to unicode.
if IS_PY3:
    # Nothing to convert here: the argument is handed back unchanged.
    to_unicode = lambda x: x
else:
    def to_unicode(s):
        """Decode ``bytes`` input as ASCII; return any other value as-is."""
        return s.decode('ascii') if isinstance(s, bytes) else s
|
||
|
|
||
|
if sys.version_info >= (3, 5):
    import importlib.util as _importlib_util

    def load_dynamic(name, module_path):
        """Load the module stored at ``module_path`` under the name ``name``.

        The module is built from a file-location spec and executed by the
        spec's loader.  It is returned to the caller without being added to
        ``sys.modules``.
        """
        module_spec = _importlib_util.spec_from_file_location(name, module_path)
        loaded = _importlib_util.module_from_spec(module_spec)
        module_spec.loader.exec_module(loaded)
        return loaded
else:
    import imp

    def load_dynamic(name, module_path):
        """Load the module at ``module_path`` via ``imp.load_dynamic``."""
        return imp.load_dynamic(name, module_path)
|
||
|
|
||
|
class UnboundSymbols(EnvTransform, SkipDeclarations):
    """Tree transform that collects names which have no entry in scope.

    Calling an instance on a tree runs the transform and returns the set of
    collected names.
    """

    def __init__(self):
        # Initialised through CythonTransform directly, with no context.
        CythonTransform.__init__(self, None)
        self.unbound = set()

    def visit_NameNode(self, node):
        """Record ``node.name`` when the current environment cannot resolve it."""
        scope = self.current_env()
        entry = scope.lookup(node.name)
        if not entry:
            self.unbound.add(node.name)
        return node

    def __call__(self, node):
        """Run the transform over ``node`` and return the set of unbound names."""
        super(UnboundSymbols, self).__call__(node)
        return self.unbound
|
||
|
|
||
|
|
||
|
@cached_function
def unbound_symbols(code, context=None):
    """Return a tuple of the names used in ``code`` that are not bound in it.

    The code is parsed as a tree fragment and pushed through the 'pyx'
    pipeline up to and including declaration analysis.  Names that the
    ``UnboundSymbols`` transform reports are returned, minus everything
    listed in the ``builtins`` module.
    """
    code = to_unicode(code)
    if context is None:
        context = Context([], default_options)

    from ..Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
    try:
        import builtins
    except ImportError:
        import __builtin__ as builtins

    tree = parse_from_strings('(tree fragment)', code)
    for phase in Pipeline.create_pipeline(context, 'pyx'):
        if phase is not None:
            tree = phase(tree)
            # Later pipeline stages are not needed for name resolution.
            if isinstance(phase, AnalyseDeclarationsTransform):
                break

    found = UnboundSymbols()(tree)
    builtin_names = set(dir(builtins))
    return tuple(found - builtin_names)
|
||
|
|
||
|
|
||
|
def unsafe_type(arg, context=None):
    """Return the Cython type name for ``arg``, mapping exact ``int`` to 'long'.

    Every other value is delegated to ``safe_type``.
    """
    if type(arg) is int:
        return 'long'
    return safe_type(arg, context)
|
||
|
|
||
|
|
||
|
def safe_type(arg, context=None):
    """Return a Cython type name that can safely hold ``arg``.

    * ``list``, ``tuple``, ``dict`` and ``str`` map to their own names.
    * ``complex``, ``float`` and ``bool`` map to 'double complex', 'double'
      and 'bint'.
    * numpy arrays (only checked when numpy is already imported) map to a
      typed ``numpy.ndarray[...]`` buffer declaration.
    * Otherwise the MRO of ``type(arg)`` is walked: the first base whose
      module can be found through ``context`` and whose name resolves to a
      type entry yields 'module.Name'.  Reaching a builtin base, or
      finding nothing, yields 'object'.

    ``context`` may be None, in which case no module lookup is attempted and
    non-special values are typed as 'object'.
    """
    py_type = type(arg)
    if py_type in (list, tuple, dict, str):
        return py_type.__name__
    elif py_type is complex:
        return 'double complex'
    elif py_type is float:
        return 'double'
    elif py_type is bool:
        return 'bint'
    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
    else:
        for base_type in py_type.__mro__:
            if base_type.__module__ in ('__builtin__', 'builtins'):
                return 'object'
            if context is None:
                # No compilation context: modules cannot be searched.
                continue
            module = context.find_module(base_type.__module__, need_pxd=False)
            if module:
                entry = module.lookup(base_type.__name__)
                # lookup() may come back empty for names the module lacks.
                if entry is not None and entry.is_type:
                    return '%s.%s' % (base_type.__module__, base_type.__name__)
        return 'object'
|
||
|
|
||
|
|
||
|
def _get_build_extension():
    """Create a finalized ``build_ext`` command on a fresh ``Distribution``.

    The distribution parses the distutils configuration files it finds, so
    the build respects the user's distutils configuration.
    """
    distribution = Distribution()
    distribution.parse_config_files(distribution.find_config_files())
    command = build_ext(distribution)
    command.finalize_options()
    return command
|
||
|
|
||
|
|
||
|
@cached_function
def _create_context(cython_include_dirs):
    """Build a ``Context`` for the given include directories and default options."""
    include_path = list(cython_include_dirs)
    return Context(include_path, default_options)
|
||
|
|
||
|
|
||
|
# In-process cache filled by cython_inline(): it maps a source string to its
# tuple of unbound symbols, and (code, arg_sigs, key_hash) to the compiled
# callable.
_cython_inline_cache = {}
# Context used when cython_inline() is called without cython_include_dirs.
_cython_inline_default_context = _create_context(('.',))
|
||
|
|
||
|
|
||
|
def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
|
||
|
for symbol in unbound_symbols:
|
||
|
if symbol not in kwds:
|
||
|
if locals is None or globals is None:
|
||
|
calling_frame = inspect.currentframe().f_back.f_back.f_back
|
||
|
if locals is None:
|
||
|
locals = calling_frame.f_locals
|
||
|
if globals is None:
|
||
|
globals = calling_frame.f_globals
|
||
|
if symbol in locals:
|
||
|
kwds[symbol] = locals[symbol]
|
||
|
elif symbol in globals:
|
||
|
kwds[symbol] = globals[symbol]
|
||
|
else:
|
||
|
print("Couldn't find %r" % symbol)
|
||
|
|
||
|
|
||
|
def _inline_key(orig_code, arg_sigs, language_level):
    """Return a SHA-1 hex digest identifying one inline compilation.

    The digest covers the code, the argument signatures, the interpreter
    version and executable, the language level and the Cython version.
    """
    key = (
        orig_code,
        arg_sigs,
        sys.version_info,
        sys.executable,
        language_level,
        Cython.__version__,
    )
    encoded = _unicode(key).encode('utf-8')
    return hashlib.sha1(encoded).hexdigest()
|
||
|
|
||
|
|
||
|
def cython_inline(code, get_type=unsafe_type,
                  lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
                  cython_include_dirs=None, cython_compiler_directives=None,
                  force=False, quiet=False, locals=None, globals=None, language_level=None, **kwds):
    """Compile ``code`` as the body of a Cython function, call it, return its result.

    The generated function ends with ``return locals()`` unless ``code``
    returns earlier.

    Parameters:
        code: source text of the function body.  Unindented module-level
            statements (cimports, ``cdef extern``/``cdef class``) are moved
            out of the function.
        get_type: callable ``(value, context)`` giving the Cython type name
            for each argument.  ``None`` types every argument as 'object'.
        lib_dir: directory receiving the generated ``.pyx`` file and the
            built extension module.
        cython_include_dirs: include path for the compilation context and
            for ``cythonize`` (defaults to ``['.']``).
        cython_compiler_directives: mapping of directives.  It is copied, and
            ``language_level`` is added to the copy.
        force: rebuild even if the extension file already exists.
        quiet: forwarded to ``cythonize``.  Also silences the message printed
            when unbound symbols cannot be extracted.
        locals, globals: namespaces used to resolve names that ``code`` uses
            but does not bind.  Each defaults to the matching namespace of
            the frame two levels above this function.
        language_level: defaults to '3str' unless the directives already
            contain a ``language_level``.
        **kwds: explicit values for names used by ``code``.
    """
    if get_type is None:
        # get_type is always called as get_type(value, context), so the
        # fallback has to accept the context argument as well.
        get_type = lambda x, context=None: 'object'
    ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context

    cython_compiler_directives = dict(cython_compiler_directives) if cython_compiler_directives else {}
    if language_level is None and 'language_level' not in cython_compiler_directives:
        language_level = '3str'
    if language_level is not None:
        cython_compiler_directives['language_level'] = language_level

    # Fast path if this has been called in this session.
    _unbound_symbols = _cython_inline_cache.get(code)
    if _unbound_symbols is not None:
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
        args = sorted(kwds.items())
        arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
        key_hash = _inline_key(code, arg_sigs, language_level)
        invoke = _cython_inline_cache.get((code, arg_sigs, key_hash))
        if invoke is not None:
            arg_list = [arg[1] for arg in args]
            return invoke(*arg_list)

    orig_code = code
    code = to_unicode(code)
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    # The frame depth is counted from this function, so these lookups must
    # stay inline rather than move into a helper.
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        _cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code)
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")

    # Arguments that are the cython module itself become cimports instead of
    # function parameters.
    cimports = []
    for name, arg in list(kwds.items()):
        if arg is cython_module:
            cimports.append('\ncimport cython as %s' % name)
            del kwds[name]
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key_hash = _inline_key(orig_code, arg_sigs, language_level)
    module_name = "_cython_inline_" + key_hash

    if module_name in sys.modules:
        module = sys.modules[module_name]

    else:
        build_extension = None
        if cython_inline.so_ext is None:
            # Figure out and cache current extension suffix
            build_extension = _get_build_extension()
            cython_inline.so_ext = build_extension.get_ext_filename('')

        module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)

        if not os.path.isdir(lib_dir):
            try:
                os.makedirs(lib_dir)
            except OSError:
                # Another process may have created the directory between
                # the check and the call; only fail if it is still missing.
                if not os.path.isdir(lib_dir):
                    raise
        if force or not os.path.isfile(module_path):
            cflags = []
            c_include_dirs = []
            qualified = re.compile(r'([.\w]+)[.]')
            for type_name, _ in arg_sigs:
                m = qualified.match(type_name)
                if m:
                    package = m.group(1)
                    cimports.append('\ncimport %s' % package)
                    # one special case
                    if package == 'numpy':
                        import numpy
                        c_include_dirs.append(numpy.get_include())
            # The final return statement is passed through extract_func_code
            # together with the user code, so it always gets exactly the
            # same indentation as the rest of the function body.
            module_body, func_body = extract_func_code(code + '\nreturn locals()')
            params = ', '.join(['%s %s' % a for a in arg_sigs])
            module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
""" % {'cimports': '\n'.join(cimports),
       'module_body': module_body,
       'params': params,
       'func_body': func_body}
            # Put back the string literals that strip_string_literals
            # replaced with placeholder labels.
            for key, value in literals.items():
                module_code = module_code.replace(key, value)
            pyx_file = os.path.join(lib_dir, module_name + '.pyx')
            with open(pyx_file, 'w') as fh:
                fh.write(module_code)
            extension = Extension(
                name=module_name,
                sources=[pyx_file],
                include_dirs=c_include_dirs,
                extra_compile_args=cflags)
            if build_extension is None:
                build_extension = _get_build_extension()
            build_extension.extensions = cythonize(
                [extension],
                include_path=cython_include_dirs or ['.'],
                compiler_directives=cython_compiler_directives,
                quiet=quiet)
            build_extension.build_temp = os.path.dirname(pyx_file)
            build_extension.build_lib = lib_dir
            build_extension.run()

        module = load_dynamic(module_name, module_path)

    _cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke
    arg_list = [kwds[arg] for arg in arg_names]
    return module.__invoke(*arg_list)
|
||
|
|
||
|
|
||
|
# Cached extension-module filename suffix used by cython_inline above.
# It starts as None and is replaced with the actual value on the first
# cython_inline invocation that needs it.
cython_inline.so_ext = None
|
||
|
|
||
|
_find_non_space = re.compile('[^ ]').search


def strip_common_indent(code):
    """Remove the indentation shared by all code lines of ``code``.

    The common indent is the smallest number of leading spaces on any line
    that is neither blank nor a comment.  That many characters are removed
    from every code line.  Blank lines and comment lines are left exactly as
    they are.  The result is joined with '\\n', so a trailing newline in the
    input is dropped.
    """
    min_indent = None
    lines = code.splitlines()
    for line in lines:
        match = _find_non_space(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or min_indent > indent:
            min_indent = indent
    if not min_indent:
        # No code lines at all, or the code already starts in column 0.
        return '\n'.join(lines)
    for ix, line in enumerate(lines):
        match = _find_non_space(line)
        # Use this line's own first non-space character to detect comments;
        # a comment indented less than the code must not be sliced into.
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)
|
||
|
|
||
|
|
||
|
module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')


def extract_func_code(code):
    """Split ``code`` into module-level text and indented function-body text.

    Tabs are replaced by a space first.  Each line that does not start with
    a space selects the destination for itself and for the indented lines
    that follow it: the module part if it matches ``module_statement``,
    otherwise the function part.  Returns ``(module_text, function_text)``,
    where every function line carries a leading indent prefix.
    """
    module_lines, function_lines = [], []
    target = function_lines
    for line in code.replace('\t', ' ').split('\n'):
        if not line.startswith(' '):
            target = module_lines if module_statement.match(line) else function_lines
        target.append(line)
    # NOTE(review): the indent prefix must agree with whatever template the
    # caller wraps the function text in -- confirm the width against the caller.
    return '\n'.join(module_lines), ' ' + '\n '.join(function_lines)
|
||
|
|
||
|
|
||
|
try:
    from inspect import getcallargs
except ImportError:
    def getcallargs(func, *arg_values, **kwd_values):
        """Fallback for ``inspect.getcallargs``: map call arguments to names.

        Returns a dict binding each parameter of ``func`` to the value it
        would receive for the given positional and keyword arguments.
        Raises TypeError for duplicate, unexpected or missing arguments.
        """
        bound = {}
        arg_names, star_name, star_star_name, defaults = inspect.getargspec(func)
        if star_name is not None:
            # Surplus positional values go to the *args parameter.
            bound[star_name] = arg_values[len(arg_names):]
        bound.update(zip(arg_names, arg_values))
        for key in list(kwd_values):
            if key not in arg_names:
                continue
            if key in bound:
                raise TypeError("Duplicate argument %s" % key)
            bound[key] = kwd_values.pop(key)
        # Whatever is left in kwd_values did not match a named parameter.
        if star_star_name is not None:
            bound[star_star_name] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))
        defaults = defaults or ()
        first_default = len(arg_names) - len(defaults)
        for position, key in enumerate(arg_names):
            if key in bound:
                continue
            if position < first_default:
                raise TypeError("Missing argument: %s" % key)
            bound[key] = defaults[position - first_default]
        return bound
|
||
|
|
||
|
|
||
|
def get_body(source):
    """Return the body part of a function or lambda source string.

    The body is everything after the first ':' in ``source``.  For source
    that starts with ``lambda`` the expression is turned into a ``return``
    statement.  Raises ValueError if ``source`` contains no ':'.
    """
    ix = source.index(':')
    body = source[ix+1:]
    # 'lambda' is six characters long, so a five-character slice can never
    # match it; compare against the whole keyword.
    if source.startswith('lambda'):
        return "return %s" % body
    return body
|
||
|
|
||
|
|
||
|
# Lots to be done here... It would be especially cool if compiled functions
|
||
|
# could invoke each other quickly.
|
||
|
class RuntimeCompiledFunction(object):
    """Callable wrapper that runs a function's body through ``cython_inline``."""

    def __init__(self, f):
        # Keep the original function for argument binding and its globals,
        # and the body text extracted from its source for compilation.
        self._f = f
        self._body = get_body(inspect.getsource(f))

    def __call__(self, *args, **kwds):
        """Bind the call arguments to parameter names and run the compiled body."""
        bound_args = getcallargs(self._f, *args, **kwds)
        # The function's globals serve as both namespaces for name lookup.
        namespace = self._f.__globals__ if IS_PY3 else self._f.func_globals
        return cython_inline(self._body, locals=namespace, globals=namespace, **bound_args)
|