| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- from rope.base import (change, taskhandle, evaluate,
- exceptions, pyobjects, pynames, ast)
- from rope.refactor import restructure, sourceutils, similarfinder, importutils
class UseFunction(object):
    """Try to use a function wherever possible

    The body of a module-level function is turned into a restructuring
    pattern (its parameters and the names it assigns become wildcards) and
    every match is replaced by a call to that function.
    """

    # Wildcard names used inside the generated pattern/goal templates.
    _rope_result = '_rope__result'
    _rope_returned = '_rope__returned'

    def __init__(self, project, resource, offset):
        """Resolve the function whose name is at `offset` in `resource`.

        Raises `exceptions.RefactoringError` if nothing resolves at
        `offset`, if the object is not a function defined directly in a
        module, or if `_check_returns` rejects its body.
        """
        self.project = project
        self.offset = offset
        this_pymodule = project.pycore.resource_to_pyobject(resource)
        pyname = evaluate.eval_location(this_pymodule, offset)
        if pyname is None:
            raise exceptions.RefactoringError('Unresolvable name selected')
        self.pyfunction = pyname.get_object()
        if not isinstance(self.pyfunction, pyobjects.PyFunction) or \
                not isinstance(self.pyfunction.parent, pyobjects.PyModule):
            raise exceptions.RefactoringError(
                'Use function works for global functions, only.')
        # The selected name may refer to a function imported from another
        # module; work on the module that actually defines it.
        self.resource = self.pyfunction.get_module().get_resource()
        self._check_returns()

    def _check_returns(self):
        """Raise `exceptions.RefactoringError` for unsupported bodies.

        Rejected are generators, functions with more than one ``return``
        statement, and functions whose single ``return`` is not the last
        statement of the body.
        """
        node = self.pyfunction.get_ast()
        if _yield_count(node):
            raise exceptions.RefactoringError('Use function should not '
                                              'be used on generators.')
        returns = _return_count(node)
        if returns > 1:
            raise exceptions.RefactoringError('usefunction: Function has more '
                                              'than one return statement.')
        if returns == 1 and not _returns_last(node):
            raise exceptions.RefactoringError('usefunction: return should '
                                              'be the last statement.')

    def get_changes(self, resources=None,
                    task_handle=taskhandle.NullTaskHandle()):
        """Return a `change.ChangeSet` that replaces matches with calls.

        `resources` defaults to all python files of the project. Files other
        than the defining module get ``module.function(...)`` calls and an
        ``import module``. The defining module, when it is among
        `resources`, gets unqualified ``function(...)`` calls.
        """
        if resources is None:
            resources = self.project.pycore.get_python_files()
        changes = change.ChangeSet('Using function <%s>' %
                                   self.pyfunction.get_name())
        # Work on a private copy so the caller's sequence is never mutated.
        # `others` is always bound, even when the defining module was not
        # requested (otherwise the loop below would hit an unbound name).
        others = list(resources)
        includes_self = self.resource in others
        if includes_self:
            others.remove(self.resource)
        for c in self._restructure(others, task_handle).changes:
            changes.add_change(c)
        if includes_self:
            for c in self._restructure([self.resource], task_handle,
                                       others=False).changes:
                changes.add_change(c)
        return changes

    def get_function_name(self):
        """Return the name of the function being used."""
        return self.pyfunction.get_name()

    def _restructure(self, resources, task_handle, others=True):
        """Run a `restructure.Restructure` over `resources`.

        With `others` true the goal is module-qualified and an
        ``import <module>`` is requested; otherwise unqualified calls are
        produced and no import is added. The function's own body region in
        its module is passed to the restructuring as the ``skip`` argument.
        """
        pattern = self._make_pattern()
        goal = self._make_goal(import_=others)
        imports = None
        if others:
            imports = ['import %s' % self._module_name()]
        body_region = sourceutils.get_body_region(self.pyfunction)
        args_value = {'skip': (self.resource, body_region)}
        args = {'': args_value}
        restructuring = restructure.Restructure(
            self.project, pattern, goal, args=args, imports=imports)
        return restructuring.get_changes(resources=resources,
                                         task_handle=task_handle)

    def _find_temps(self):
        """Return the names assigned inside the function body."""
        return find_temps(self.project, self._get_body())

    def _module_name(self):
        """Return the dotted module name of the defining resource."""
        return self.project.pycore.modname(self.resource)

    def _make_pattern(self):
        """Build the search pattern from the function body.

        The pattern ``return`` is replaced by ``pass``. If the body returns
        a value, ``return <expr>`` becomes either the bare expression (for a
        single-statement body) or an assignment to the result wildcard.
        Parameters and assigned temporaries become wildcards.
        """
        params = self.pyfunction.get_param_names()
        body = self._get_body()
        body = restructure.replace(body, 'return', 'pass')
        wildcards = list(params)
        wildcards.extend(self._find_temps())
        if self._does_return():
            if self._is_expression():
                replacement = '${%s}' % self._rope_returned
            else:
                replacement = '%s = ${%s}' % (self._rope_result,
                                              self._rope_returned)
            body = restructure.replace(
                body, 'return ${%s}' % self._rope_returned,
                replacement)
            wildcards.append(self._rope_result)
        return similarfinder.make_pattern(body, wildcards)

    def _get_body(self):
        """Return the source text of the function body."""
        return sourceutils.get_body(self.pyfunction)

    def _make_goal(self, import_=False):
        """Build the replacement call template.

        With `import_` the call is qualified with the module name. When the
        body returns a value and is not a single expression, the call result
        is assigned to the result wildcard.
        """
        params = self.pyfunction.get_param_names()
        function_name = self.pyfunction.get_name()
        if import_:
            function_name = self._module_name() + '.' + function_name
        goal = '%s(%s)' % (function_name,
                           ', '.join(('${%s}' % p) for p in params))
        if self._does_return() and not self._is_expression():
            goal = '${%s} = %s' % (self._rope_result, goal)
        return goal

    def _does_return(self):
        """Tell whether the body contains a ``return <expr>`` statement."""
        body = self._get_body()
        removed_return = restructure.replace(body, 'return ${result}', '')
        return removed_return != body

    def _is_expression(self):
        """Tell whether the function's AST body has exactly one statement."""
        return len(self.pyfunction.get_ast().body) == 1
def find_temps(project, code):
    """Return the names that `code` binds by assignment.

    `code` is indented and wrapped in a throwaway ``def f():`` so that it
    can be analysed as a function body. The names in that function's scope
    whose pynames are `pynames.AssignedName` are returned.
    """
    source = 'def f():\n' + sourceutils.indent_lines(code, 4)
    module_scope = project.pycore.get_string_module(source).get_scope()
    wrapper_scope = module_scope.get_scopes()[0]
    return [name
            for name, pyname in wrapper_scope.get_names().items()
            if isinstance(pyname, pynames.AssignedName)]
def _returns_last(node):
    """Tell whether the final statement of `node`'s body is a ``return``.

    An empty body is returned as-is, which is falsy.
    """
    statements = node.body
    if not statements:
        return statements
    return isinstance(statements[-1], ast.Return)
def _yield_count(node):
    """Return the number of yields `_ReturnOrYieldFinder` finds in `node`."""
    finder = _ReturnOrYieldFinder()
    finder.start_walking(node)
    return finder.yields
def _return_count(node):
    """Return the number of returns `_ReturnOrYieldFinder` finds in `node`."""
    counter = _ReturnOrYieldFinder()
    counter.start_walking(node)
    return counter.returns
class _ReturnOrYieldFinder(object):
    """Count the ``return`` and ``yield`` nodes that belong to one function.

    Instances are passed as the walker to ``ast.walk``, which looks up
    methods named ``_<NodeType>``. The handlers for nested definitions
    (functions, async functions, classes, lambdas) do nothing. Presumably
    this keeps statements of those inner scopes from being counted, which
    relies on ``ast.walk`` not descending into a node that has a handler.
    """

    # Function node types that `start_walking` unwraps instead of skipping.
    # ``AsyncFunctionDef`` only exists on Python 3.5+.
    _FUNCTION_NODES = tuple(
        getattr(ast, _name) for _name in ('FunctionDef', 'AsyncFunctionDef')
        if hasattr(ast, _name))

    def __init__(self):
        self.returns = 0
        self.yields = 0

    def _Return(self, node):
        self.returns += 1
        # ``return (yield x)`` (legal on Python 3) makes the function a
        # generator, so the returned expression is searched for yields too.
        if node.value is not None:
            ast.walk(node.value, self)

    def _Yield(self, node):
        self.yields += 1

    def _YieldFrom(self, node):
        # ``yield from`` (Python 3.3+) also makes the function a generator.
        self.yields += 1

    def _FunctionDef(self, node):
        # Nested scope: its returns and yields belong to another function.
        pass

    _AsyncFunctionDef = _FunctionDef
    _ClassDef = _FunctionDef
    _Lambda = _FunctionDef

    def start_walking(self, node):
        """Walk `node`, accumulating counts into `returns` and `yields`.

        A function node is unwrapped first (its children are walked) so the
        nested-definition handlers do not swallow the function itself.
        """
        nodes = [node]
        if isinstance(node, self._FUNCTION_NODES):
            nodes = ast.get_child_nodes(node)
        for child in nodes:
            ast.walk(child, self)
|