similarfinder.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. """This module can be used for finding similar code"""
  2. import re
  3. import rope.refactor.wildcards
  4. from rope.base import codeanalyze, evaluate, exceptions, ast, builtins
  5. from rope.refactor import (patchedast, sourceutils, occurrences,
  6. wildcards, importutils)
  7. class BadNameInCheckError(exceptions.RefactoringError):
  8. pass
  9. class SimilarFinder(object):
  10. """`SimilarFinder` can be used to find similar pieces of code
  11. See the notes in the `rope.refactor.restructure` module for more
  12. info.
  13. """
  14. def __init__(self, pymodule, wildcards=None):
  15. """Construct a SimilarFinder"""
  16. self.source = pymodule.source_code
  17. self.raw_finder = RawSimilarFinder(
  18. pymodule.source_code, pymodule.get_ast(), self._does_match)
  19. self.pymodule = pymodule
  20. if wildcards is None:
  21. self.wildcards = {}
  22. for wildcard in [rope.refactor.wildcards.
  23. DefaultWildcard(pymodule.pycore.project)]:
  24. self.wildcards[wildcard.get_name()] = wildcard
  25. else:
  26. self.wildcards = wildcards
  27. def get_matches(self, code, args={}, start=0, end=None):
  28. self.args = args
  29. if end is None:
  30. end = len(self.source)
  31. skip_region = None
  32. if 'skip' in args.get('', {}):
  33. resource, region = args['']['skip']
  34. if resource == self.pymodule.get_resource():
  35. skip_region = region
  36. return self.raw_finder.get_matches(code, start=start, end=end,
  37. skip=skip_region)
  38. def get_match_regions(self, *args, **kwds):
  39. for match in self.get_matches(*args, **kwds):
  40. yield match.get_region()
  41. def _does_match(self, node, name):
  42. arg = self.args.get(name, '')
  43. kind = 'default'
  44. if isinstance(arg, (tuple, list)):
  45. kind = arg[0]
  46. arg = arg[1]
  47. suspect = wildcards.Suspect(self.pymodule, node, name)
  48. return self.wildcards[kind].matches(suspect, arg)
  49. class RawSimilarFinder(object):
  50. """A class for finding similar expressions and statements"""
  51. def __init__(self, source, node=None, does_match=None):
  52. if node is None:
  53. node = ast.parse(source)
  54. if does_match is None:
  55. self.does_match = self._simple_does_match
  56. else:
  57. self.does_match = does_match
  58. self._init_using_ast(node, source)
  59. def _simple_does_match(self, node, name):
  60. return isinstance(node, (ast.expr, ast.Name))
  61. def _init_using_ast(self, node, source):
  62. self.source = source
  63. self._matched_asts = {}
  64. if not hasattr(node, 'region'):
  65. patchedast.patch_ast(node, source)
  66. self.ast = node
  67. def get_matches(self, code, start=0, end=None, skip=None):
  68. """Search for `code` in source and return a list of `Match`\es
  69. `code` can contain wildcards. ``${name}`` matches normal
  70. names and ``${?name} can match any expression. You can use
  71. `Match.get_ast()` for getting the node that has matched a
  72. given pattern.
  73. """
  74. if end is None:
  75. end = len(self.source)
  76. for match in self._get_matched_asts(code):
  77. match_start, match_end = match.get_region()
  78. if start <= match_start and match_end <= end:
  79. if skip is not None and (skip[0] < match_end and
  80. skip[1] > match_start):
  81. continue
  82. yield match
  83. def _get_matched_asts(self, code):
  84. if code not in self._matched_asts:
  85. wanted = self._create_pattern(code)
  86. matches = _ASTMatcher(self.ast, wanted,
  87. self.does_match).find_matches()
  88. self._matched_asts[code] = matches
  89. return self._matched_asts[code]
  90. def _create_pattern(self, expression):
  91. expression = self._replace_wildcards(expression)
  92. node = ast.parse(expression)
  93. # Getting Module.Stmt.nodes
  94. nodes = node.body
  95. if len(nodes) == 1 and isinstance(nodes[0], ast.Expr):
  96. # Getting Discard.expr
  97. wanted = nodes[0].value
  98. else:
  99. wanted = nodes
  100. return wanted
  101. def _replace_wildcards(self, expression):
  102. ropevar = _RopeVariable()
  103. template = CodeTemplate(expression)
  104. mapping = {}
  105. for name in template.get_names():
  106. mapping[name] = ropevar.get_var(name)
  107. return template.substitute(mapping)
  108. class _ASTMatcher(object):
  109. def __init__(self, body, pattern, does_match):
  110. """Searches the given pattern in the body AST.
  111. body is an AST node and pattern can be either an AST node or
  112. a list of ASTs nodes
  113. """
  114. self.body = body
  115. self.pattern = pattern
  116. self.matches = None
  117. self.ropevar = _RopeVariable()
  118. self.matches_callback = does_match
  119. def find_matches(self):
  120. if self.matches is None:
  121. self.matches = []
  122. ast.call_for_nodes(self.body, self._check_node, recursive=True)
  123. return self.matches
  124. def _check_node(self, node):
  125. if isinstance(self.pattern, list):
  126. self._check_statements(node)
  127. else:
  128. self._check_expression(node)
  129. def _check_expression(self, node):
  130. mapping = {}
  131. if self._match_nodes(self.pattern, node, mapping):
  132. self.matches.append(ExpressionMatch(node, mapping))
  133. def _check_statements(self, node):
  134. for child in ast.get_children(node):
  135. if isinstance(child, (list, tuple)):
  136. self.__check_stmt_list(child)
  137. def __check_stmt_list(self, nodes):
  138. for index in range(len(nodes)):
  139. if len(nodes) - index >= len(self.pattern):
  140. current_stmts = nodes[index:index + len(self.pattern)]
  141. mapping = {}
  142. if self._match_stmts(current_stmts, mapping):
  143. self.matches.append(StatementMatch(current_stmts, mapping))
  144. def _match_nodes(self, expected, node, mapping):
  145. if isinstance(expected, ast.Name):
  146. if self.ropevar.is_var(expected.id):
  147. return self._match_wildcard(expected, node, mapping)
  148. if not isinstance(expected, ast.AST):
  149. return expected == node
  150. if expected.__class__ != node.__class__:
  151. return False
  152. children1 = self._get_children(expected)
  153. children2 = self._get_children(node)
  154. if len(children1) != len(children2):
  155. return False
  156. for child1, child2 in zip(children1, children2):
  157. if isinstance(child1, ast.AST):
  158. if not self._match_nodes(child1, child2, mapping):
  159. return False
  160. elif isinstance(child1, (list, tuple)):
  161. if not isinstance(child2, (list, tuple)) or \
  162. len(child1) != len(child2):
  163. return False
  164. for c1, c2 in zip(child1, child2):
  165. if not self._match_nodes(c1, c2, mapping):
  166. return False
  167. else:
  168. if child1 != child2:
  169. return False
  170. return True
  171. def _get_children(self, node):
  172. """Return not `ast.expr_context` children of `node`"""
  173. children = ast.get_children(node)
  174. return [child for child in children
  175. if not isinstance(child, ast.expr_context)]
  176. def _match_stmts(self, current_stmts, mapping):
  177. if len(current_stmts) != len(self.pattern):
  178. return False
  179. for stmt, expected in zip(current_stmts, self.pattern):
  180. if not self._match_nodes(expected, stmt, mapping):
  181. return False
  182. return True
  183. def _match_wildcard(self, node1, node2, mapping):
  184. name = self.ropevar.get_base(node1.id)
  185. if name not in mapping:
  186. if self.matches_callback(node2, name):
  187. mapping[name] = node2
  188. return True
  189. return False
  190. else:
  191. return self._match_nodes(mapping[name], node2, {})
  192. class Match(object):
  193. def __init__(self, mapping):
  194. self.mapping = mapping
  195. def get_region(self):
  196. """Returns match region"""
  197. def get_ast(self, name):
  198. """Return the ast node that has matched rope variables"""
  199. return self.mapping.get(name, None)
  200. class ExpressionMatch(Match):
  201. def __init__(self, ast, mapping):
  202. super(ExpressionMatch, self).__init__(mapping)
  203. self.ast = ast
  204. def get_region(self):
  205. return self.ast.region
  206. class StatementMatch(Match):
  207. def __init__(self, ast_list, mapping):
  208. super(StatementMatch, self).__init__(mapping)
  209. self.ast_list = ast_list
  210. def get_region(self):
  211. return self.ast_list[0].region[0], self.ast_list[-1].region[1]
  212. class CodeTemplate(object):
  213. def __init__(self, template):
  214. self.template = template
  215. self._find_names()
  216. def _find_names(self):
  217. self.names = {}
  218. for match in CodeTemplate._get_pattern().finditer(self.template):
  219. if 'name' in match.groupdict() and \
  220. match.group('name') is not None:
  221. start, end = match.span('name')
  222. name = self.template[start + 2:end - 1]
  223. if name not in self.names:
  224. self.names[name] = []
  225. self.names[name].append((start, end))
  226. def get_names(self):
  227. return self.names.keys()
  228. def substitute(self, mapping):
  229. collector = codeanalyze.ChangeCollector(self.template)
  230. for name, occurrences in self.names.items():
  231. for region in occurrences:
  232. collector.add_change(region[0], region[1], mapping[name])
  233. result = collector.get_changed()
  234. if result is None:
  235. return self.template
  236. return result
  237. _match_pattern = None
  238. @classmethod
  239. def _get_pattern(cls):
  240. if cls._match_pattern is None:
  241. pattern = codeanalyze.get_comment_pattern() + '|' + \
  242. codeanalyze.get_string_pattern() + '|' + \
  243. r'(?P<name>\$\{[^\s\$\}]*\})'
  244. cls._match_pattern = re.compile(pattern)
  245. return cls._match_pattern
  246. class _RopeVariable(object):
  247. """Transform and identify rope inserted wildcards"""
  248. _normal_prefix = '__rope__variable_normal_'
  249. _any_prefix = '__rope__variable_any_'
  250. def get_var(self, name):
  251. if name.startswith('?'):
  252. return self._get_any(name)
  253. else:
  254. return self._get_normal(name)
  255. def is_var(self, name):
  256. return self._is_normal(name) or self._is_var(name)
  257. def get_base(self, name):
  258. if self._is_normal(name):
  259. return name[len(self._normal_prefix):]
  260. if self._is_var(name):
  261. return '?' + name[len(self._any_prefix):]
  262. def _get_normal(self, name):
  263. return self._normal_prefix + name
  264. def _get_any(self, name):
  265. return self._any_prefix + name[1:]
  266. def _is_normal(self, name):
  267. return name.startswith(self._normal_prefix)
  268. def _is_var(self, name):
  269. return name.startswith(self._any_prefix)
  270. def make_pattern(code, variables):
  271. variables = set(variables)
  272. collector = codeanalyze.ChangeCollector(code)
  273. def does_match(node, name):
  274. return isinstance(node, ast.Name) and node.id == name
  275. finder = RawSimilarFinder(code, does_match=does_match)
  276. for variable in variables:
  277. for match in finder.get_matches('${%s}' % variable):
  278. start, end = match.get_region()
  279. collector.add_change(start, end, '${%s}' % variable)
  280. result = collector.get_changed()
  281. return result if result is not None else code
  282. def _pydefined_to_str(pydefined):
  283. address = []
  284. if isinstance(pydefined, (builtins.BuiltinClass, builtins.BuiltinFunction)):
  285. return '__builtins__.' + pydefined.get_name()
  286. else:
  287. while pydefined.parent is not None:
  288. address.insert(0, pydefined.get_name())
  289. pydefined = pydefined.parent
  290. module_name = pydefined.pycore.modname(pydefined.resource)
  291. return '.'.join(module_name.split('.') + address)