usefunction.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. from rope.base import (change, taskhandle, evaluate,
  2. exceptions, pyobjects, pynames, ast)
  3. from rope.refactor import restructure, sourceutils, similarfinder, importutils
  class UseFunction(object):
      """Try to use a function wherever possible.

      The function selected in the constructor is turned into a
      restructuring pattern built from its body (its parameters and the
      names assigned in its body act as wildcards), and every match of
      that pattern is replaced by a call to the function.
      """

      def __init__(self, project, resource, offset):
          """Resolve and validate the function named at ``offset``.

          :param project: the rope project; its ``pycore`` loads modules.
          :param resource: the resource in which ``offset`` is given.
          :param offset: character offset of the name to resolve.
          :raises exceptions.RefactoringError: when the name cannot be
              resolved, is not a module-level function, is a generator,
              has more than one ``return``, or its ``return`` is not the
              last statement.
          """
          self.project = project
          self.offset = offset
          this_pymodule = project.pycore.resource_to_pyobject(resource)
          pyname = evaluate.eval_location(this_pymodule, offset)
          if pyname is None:
              raise exceptions.RefactoringError('Unresolvable name selected')
          self.pyfunction = pyname.get_object()
          if (not isinstance(self.pyfunction, pyobjects.PyFunction) or
                  not isinstance(self.pyfunction.parent, pyobjects.PyModule)):
              raise exceptions.RefactoringError(
                  'Use function works for global functions, only.')
          # Work against the module that defines the function, which may
          # differ from the resource the offset was taken in.
          self.resource = self.pyfunction.get_module().get_resource()
          self._check_returns()

      def _check_returns(self):
          """Reject generators and bodies whose ``return`` layout is unusable.

          The pattern logic only supports at most one ``return``, placed as
          the last statement of the body.
          """
          node = self.pyfunction.get_ast()
          if _yield_count(node):
              raise exceptions.RefactoringError('Use function should not '
                                                'be used on generators.')
          returns = _return_count(node)
          if returns > 1:
              raise exceptions.RefactoringError('usefunction: Function has more '
                                                'than one return statement.')
          if returns == 1 and not _returns_last(node):
              raise exceptions.RefactoringError('usefunction: return should '
                                                'be the last statement.')

      def get_changes(self, resources=None,
                      task_handle=taskhandle.NullTaskHandle()):
          """Return a ``change.ChangeSet`` replacing body matches with calls.

          :param resources: resources to search; defaults to
              ``project.pycore.get_python_files()``.
          :param task_handle: forwarded to the underlying restructurings.

          Resources other than the defining module are restructured with a
          module-qualified call and an ``import`` of the defining module.
          The defining module itself (when included) is restructured
          separately, without the import.
          """
          if resources is None:
              resources = self.project.pycore.get_python_files()
          changes = change.ChangeSet('Using function <%s>' %
                                     self.pyfunction.get_name())
          # Always bind the list of other resources: it is needed even when
          # the defining module is not among ``resources``.
          other_resources = list(resources)
          includes_defining_module = self.resource in other_resources
          if includes_defining_module:
              other_resources.remove(self.resource)
          for c in self._restructure(other_resources, task_handle).changes:
              changes.add_change(c)
          if includes_defining_module:
              for c in self._restructure([self.resource], task_handle,
                                         others=False).changes:
                  changes.add_change(c)
          return changes

      def get_function_name(self):
          """Return the name of the function being used."""
          return self.pyfunction.get_name()

      def _restructure(self, resources, task_handle, others=True):
          """Run a ``restructure.Restructure`` from body pattern to call.

          :param others: True for modules other than the defining one; the
              goal call is then module-qualified and an ``import`` of the
              defining module is requested.
          """
          pattern = self._make_pattern()
          goal = self._make_goal(import_=others)
          imports = None
          if others:
              imports = ['import %s' % self._module_name()]
          body_region = sourceutils.get_body_region(self.pyfunction)
          # NOTE(review): presumably the '' key applies these args to all
          # wildcards and 'skip' keeps the function's own body from being
          # rewritten into a call to itself -- confirm against Restructure.
          args_value = {'skip': (self.resource, body_region)}
          args = {'': args_value}
          restructuring = restructure.Restructure(
              self.project, pattern, goal, args=args, imports=imports)
          return restructuring.get_changes(resources=resources,
                                           task_handle=task_handle)

      def _find_temps(self):
          """Return names assigned inside the function body."""
          return find_temps(self.project, self._get_body())

      def _module_name(self):
          """Return the dotted module name of the defining resource."""
          return self.project.pycore.modname(self.resource)

      def _make_pattern(self):
          """Build the similarfinder pattern matching the function body.

          Parameters and body temporaries become wildcards. A bare
          ``return`` is rewritten to ``pass``. A ``return <expr>`` becomes
          the returned wildcard itself when the body is a single statement,
          or an assignment of it to the result wildcard otherwise.
          """
          params = self.pyfunction.get_param_names()
          body = self._get_body()
          body = restructure.replace(body, 'return', 'pass')
          wildcards = list(params)
          wildcards.extend(self._find_temps())
          if self._does_return():
              if self._is_expression():
                  replacement = '${%s}' % self._rope_returned
              else:
                  replacement = '%s = ${%s}' % (self._rope_result,
                                                self._rope_returned)
              body = restructure.replace(
                  body, 'return ${%s}' % self._rope_returned,
                  replacement)
              wildcards.append(self._rope_result)
          return similarfinder.make_pattern(body, wildcards)

      def _get_body(self):
          """Return the function body source via ``sourceutils.get_body``."""
          return sourceutils.get_body(self.pyfunction)

      def _make_goal(self, import_=False):
          """Build the replacement call, e.g. ``f(${a}, ${b})``.

          :param import_: qualify the call with the defining module name.
          When the function returns a value from a multi-statement body,
          the call is assigned to the result wildcard.
          """
          params = self.pyfunction.get_param_names()
          function_name = self.pyfunction.get_name()
          if import_:
              function_name = self._module_name() + '.' + function_name
          goal = '%s(%s)' % (function_name,
                             ', '.join(('${%s}' % p) for p in params))
          if self._does_return() and not self._is_expression():
              goal = '${%s} = %s' % (self._rope_result, goal)
          return goal

      def _does_return(self):
          """Tell whether the body contains a ``return <expr>`` statement."""
          body = self._get_body()
          removed_return = restructure.replace(body, 'return ${result}', '')
          return removed_return != body

      def _is_expression(self):
          """Tell whether the function's AST body has exactly one statement."""
          return len(self.pyfunction.get_ast().body) == 1

      # Wildcard names used in patterns/goals for the assigned result and
      # the returned expression.
      _rope_result = '_rope__result'
      _rope_returned = '_rope__returned'
  def find_temps(project, code):
      """Return the names assigned in ``code`` when used as a function body.

      ``code`` is wrapped in ``def f():`` (indented by 4), parsed as a
      string module, and the names of the first nested scope (presumably
      ``f``'s scope) that are ``pynames.AssignedName`` are returned.
      """
      wrapped_source = 'def f():\n' + sourceutils.indent_lines(code, 4)
      string_module = project.pycore.get_string_module(wrapped_source)
      scope_of_f = string_module.get_scope().get_scopes()[0]
      return [temp_name
              for temp_name, temp_pyname in scope_of_f.get_names().items()
              if isinstance(temp_pyname, pynames.AssignedName)]
  def _returns_last(node):
      """Tell whether the last statement of ``node.body`` is a ``Return``.

      An empty body is returned unchanged (and is therefore falsy).
      """
      statements = node.body
      if not statements:
          return statements
      return isinstance(statements[-1], ast.Return)
  def _yield_count(node):
      """Return the yield count a ``_ReturnOrYieldFinder`` walk of node finds."""
      finder = _ReturnOrYieldFinder()
      finder.start_walking(node)
      yield_total = finder.yields
      return yield_total
  def _return_count(node):
      """Return the return count a ``_ReturnOrYieldFinder`` walk of node finds."""
      finder = _ReturnOrYieldFinder()
      finder.start_walking(node)
      return_total = finder.returns
      return return_total
  class _ReturnOrYieldFinder(object):
      """Count ``return`` and yield nodes belonging to a single function.

      Instances are passed to ``ast.walk`` as the walker. NOTE(review):
      this relies on rope's ``ast.walk`` calling a ``_<NodeClass>`` method
      when one exists and not descending into that node's children; the
      no-op handlers for nested (async) functions and classes use that to
      ignore statements that belong to an inner scope.
      """

      def __init__(self):
          self.returns = 0
          self.yields = 0

      def _Return(self, node):
          self.returns += 1

      def _Yield(self, node):
          self.yields += 1

      def _YieldFrom(self, node):
          # ``yield from`` also makes the enclosing function a generator.
          self.yields += 1

      def _FunctionDef(self, node):
          pass

      def _AsyncFunctionDef(self, node):
          # Nested ``async def`` bodies belong to their own scope.
          pass

      def _ClassDef(self, node):
          pass

      def start_walking(self, node):
          """Walk ``node``; for a (async) function, walk its children only.

          Unwrapping the top-level function keeps its own ``_FunctionDef``
          (or ``_AsyncFunctionDef``) handler from swallowing the whole body.
          """
          # AsyncFunctionDef only exists on Python 3.5+ ASTs.
          function_types = tuple(getattr(ast, type_name)
                                 for type_name in ('FunctionDef',
                                                   'AsyncFunctionDef')
                                 if hasattr(ast, type_name))
          nodes = [node]
          if isinstance(node, function_types):
              nodes = ast.get_child_nodes(node)
          for child in nodes:
              ast.walk(child, self)