module_imports.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. import rope.base.pynames
  2. from rope.base import ast, utils
  3. from rope.refactor.importutils import importinfo
  4. from rope.refactor.importutils import actions
  5. class ModuleImports(object):
  6. def __init__(self, pycore, pymodule, import_filter=None):
  7. self.pycore = pycore
  8. self.pymodule = pymodule
  9. self.separating_lines = 0
  10. self.filter = import_filter
  11. @property
  12. @utils.saveit
  13. def imports(self):
  14. finder = _GlobalImportFinder(self.pymodule, self.pycore)
  15. result = finder.find_import_statements()
  16. self.separating_lines = finder.get_separating_line_count()
  17. if self.filter is not None:
  18. for import_stmt in result:
  19. if not self.filter(import_stmt):
  20. import_stmt.readonly = True
  21. return result
  22. def _get_unbound_names(self, defined_pyobject):
  23. visitor = _GlobalUnboundNameFinder(self.pymodule, defined_pyobject)
  24. ast.walk(self.pymodule.get_ast(), visitor)
  25. return visitor.unbound
  26. def remove_unused_imports(self):
  27. can_select = _OneTimeSelector(self._get_unbound_names(self.pymodule))
  28. visitor = actions.RemovingVisitor(
  29. self.pycore, self._current_folder(), can_select)
  30. for import_statement in self.imports:
  31. import_statement.accept(visitor)
  32. def get_used_imports(self, defined_pyobject):
  33. result = []
  34. can_select = _OneTimeSelector(self._get_unbound_names(defined_pyobject))
  35. visitor = actions.FilteringVisitor(
  36. self.pycore, self._current_folder(), can_select)
  37. for import_statement in self.imports:
  38. new_import = import_statement.accept(visitor)
  39. if new_import is not None and not new_import.is_empty():
  40. result.append(new_import)
  41. return result
  42. def get_changed_source(self):
  43. imports = self.imports
  44. after_removing = self._remove_imports(imports)
  45. imports = [stmt for stmt in imports
  46. if not stmt.import_info.is_empty()]
  47. first_non_blank = self._first_non_blank_line(after_removing, 0)
  48. first_import = self._first_import_line() - 1
  49. result = []
  50. # Writing module docs
  51. result.extend(after_removing[first_non_blank:first_import])
  52. # Writing imports
  53. sorted_imports = sorted(imports, self._compare_import_locations)
  54. for stmt in sorted_imports:
  55. start = self._get_import_location(stmt)
  56. if stmt != sorted_imports[0]:
  57. result.append('\n' * stmt.blank_lines)
  58. result.append(stmt.get_import_statement() + '\n')
  59. if sorted_imports and first_non_blank < len(after_removing):
  60. result.append('\n' * self.separating_lines)
  61. # Writing the body
  62. first_after_imports = self._first_non_blank_line(after_removing,
  63. first_import)
  64. result.extend(after_removing[first_after_imports:])
  65. return ''.join(result)
  66. def _get_import_location(self, stmt):
  67. start = stmt.get_new_start()
  68. if start is None:
  69. start = stmt.get_old_location()[0]
  70. return start
  71. def _compare_import_locations(self, stmt1, stmt2):
  72. def get_location(stmt):
  73. if stmt.get_new_start() is not None:
  74. return stmt.get_new_start()
  75. else:
  76. return stmt.get_old_location()[0]
  77. return cmp(get_location(stmt1), get_location(stmt2))
  78. def _remove_imports(self, imports):
  79. lines = self.pymodule.source_code.splitlines(True)
  80. after_removing = []
  81. last_index = 0
  82. for stmt in imports:
  83. start, end = stmt.get_old_location()
  84. after_removing.extend(lines[last_index:start - 1])
  85. last_index = end - 1
  86. for i in range(start, end):
  87. after_removing.append('')
  88. after_removing.extend(lines[last_index:])
  89. return after_removing
  90. def _first_non_blank_line(self, lines, lineno):
  91. result = lineno
  92. for line in lines[lineno:]:
  93. if line.strip() == '':
  94. result += 1
  95. else:
  96. break
  97. return result
  98. def add_import(self, import_info):
  99. visitor = actions.AddingVisitor(self.pycore, [import_info])
  100. for import_statement in self.imports:
  101. if import_statement.accept(visitor):
  102. break
  103. else:
  104. lineno = self._get_new_import_lineno()
  105. blanks = self._get_new_import_blanks()
  106. self.imports.append(importinfo.ImportStatement(
  107. import_info, lineno, lineno,
  108. blank_lines=blanks))
  109. def _get_new_import_blanks(self):
  110. return 0
  111. def _get_new_import_lineno(self):
  112. if self.imports:
  113. return self.imports[-1].end_line
  114. return 1
  115. def filter_names(self, can_select):
  116. visitor = actions.RemovingVisitor(
  117. self.pycore, self._current_folder(), can_select)
  118. for import_statement in self.imports:
  119. import_statement.accept(visitor)
  120. def expand_stars(self):
  121. can_select = _OneTimeSelector(self._get_unbound_names(self.pymodule))
  122. visitor = actions.ExpandStarsVisitor(
  123. self.pycore, self._current_folder(), can_select)
  124. for import_statement in self.imports:
  125. import_statement.accept(visitor)
  126. def remove_duplicates(self):
  127. added_imports = []
  128. for import_stmt in self.imports:
  129. visitor = actions.AddingVisitor(self.pycore,
  130. [import_stmt.import_info])
  131. for added_import in added_imports:
  132. if added_import.accept(visitor):
  133. import_stmt.empty_import()
  134. else:
  135. added_imports.append(import_stmt)
  136. def get_relative_to_absolute_list(self):
  137. visitor = rope.refactor.importutils.actions.RelativeToAbsoluteVisitor(
  138. self.pycore, self._current_folder())
  139. for import_stmt in self.imports:
  140. if not import_stmt.readonly:
  141. import_stmt.accept(visitor)
  142. return visitor.to_be_absolute
  143. def get_self_import_fix_and_rename_list(self):
  144. visitor = rope.refactor.importutils.actions.SelfImportVisitor(
  145. self.pycore, self._current_folder(), self.pymodule.get_resource())
  146. for import_stmt in self.imports:
  147. if not import_stmt.readonly:
  148. import_stmt.accept(visitor)
  149. return visitor.to_be_fixed, visitor.to_be_renamed
  150. def _current_folder(self):
  151. return self.pymodule.get_resource().parent
  152. def sort_imports(self):
  153. # IDEA: Sort from import list
  154. visitor = actions.SortingVisitor(self.pycore, self._current_folder())
  155. for import_statement in self.imports:
  156. import_statement.accept(visitor)
  157. in_projects = sorted(visitor.in_project, self._compare_imports)
  158. third_party = sorted(visitor.third_party, self._compare_imports)
  159. standards = sorted(visitor.standard, self._compare_imports)
  160. future = sorted(visitor.future, self._compare_imports)
  161. blank_lines = 0
  162. last_index = self._first_import_line()
  163. last_index = self._move_imports(future, last_index, 0)
  164. last_index = self._move_imports(standards, last_index, 1)
  165. last_index = self._move_imports(third_party, last_index, 1)
  166. last_index = self._move_imports(in_projects, last_index, 1)
  167. self.separating_lines = 2
  168. def _first_import_line(self):
  169. nodes = self.pymodule.get_ast().body
  170. lineno = 0
  171. if self.pymodule.get_doc() is not None:
  172. lineno = 1
  173. if len(nodes) > lineno:
  174. lineno = self.pymodule.logical_lines.logical_line_in(
  175. nodes[lineno].lineno)[0]
  176. else:
  177. lineno = self.pymodule.lines.length()
  178. while lineno > 1:
  179. line = self.pymodule.lines.get_line(lineno - 1)
  180. if line.strip() == '':
  181. lineno -= 1
  182. else:
  183. break
  184. return lineno
  185. def _compare_imports(self, stmt1, stmt2):
  186. str1 = stmt1.get_import_statement()
  187. str2 = stmt2.get_import_statement()
  188. if str1.startswith('from ') and not str2.startswith('from '):
  189. return 1
  190. if not str1.startswith('from ') and str2.startswith('from '):
  191. return -1
  192. return cmp(str1, str2)
  193. def _move_imports(self, imports, index, blank_lines):
  194. if imports:
  195. imports[0].move(index, blank_lines)
  196. index += 1
  197. if len(imports) > 1:
  198. for stmt in imports[1:]:
  199. stmt.move(index)
  200. index += 1
  201. return index
  202. def handle_long_imports(self, maxdots, maxlength):
  203. visitor = actions.LongImportVisitor(
  204. self._current_folder(), self.pycore, maxdots, maxlength)
  205. for import_statement in self.imports:
  206. if not import_statement.readonly:
  207. import_statement.accept(visitor)
  208. for import_info in visitor.new_imports:
  209. self.add_import(import_info)
  210. return visitor.to_be_renamed
  211. def remove_pyname(self, pyname):
  212. """Removes pyname when imported in ``from mod import x``"""
  213. visitor = actions.RemovePyNameVisitor(self.pycore, self.pymodule,
  214. pyname, self._current_folder())
  215. for import_stmt in self.imports:
  216. import_stmt.accept(visitor)
  217. class _OneTimeSelector(object):
  218. def __init__(self, names):
  219. self.names = names
  220. self.selected_names = set()
  221. def __call__(self, imported_primary):
  222. if self._can_name_be_added(imported_primary):
  223. for name in self._get_dotted_tokens(imported_primary):
  224. self.selected_names.add(name)
  225. return True
  226. return False
  227. def _get_dotted_tokens(self, imported_primary):
  228. tokens = imported_primary.split('.')
  229. for i in range(len(tokens)):
  230. yield '.'.join(tokens[:i + 1])
  231. def _can_name_be_added(self, imported_primary):
  232. for name in self._get_dotted_tokens(imported_primary):
  233. if name in self.names and name not in self.selected_names:
  234. return True
  235. return False
  236. class _UnboundNameFinder(object):
  237. def __init__(self, pyobject):
  238. self.pyobject = pyobject
  239. def _visit_child_scope(self, node):
  240. pyobject = self.pyobject.get_module().get_scope().\
  241. get_inner_scope_for_line(node.lineno).pyobject
  242. visitor = _LocalUnboundNameFinder(pyobject, self)
  243. for child in ast.get_child_nodes(node):
  244. ast.walk(child, visitor)
  245. def _FunctionDef(self, node):
  246. self._visit_child_scope(node)
  247. def _ClassDef(self, node):
  248. self._visit_child_scope(node)
  249. def _Name(self, node):
  250. if self._get_root()._is_node_interesting(node) and \
  251. not self.is_bound(node.id):
  252. self.add_unbound(node.id)
  253. def _Attribute(self, node):
  254. result = []
  255. while isinstance(node, ast.Attribute):
  256. result.append(node.attr)
  257. node = node.value
  258. if isinstance(node, ast.Name):
  259. result.append(node.id)
  260. primary = '.'.join(reversed(result))
  261. if self._get_root()._is_node_interesting(node) and \
  262. not self.is_bound(primary):
  263. self.add_unbound(primary)
  264. else:
  265. ast.walk(node, self)
  266. def _get_root(self):
  267. pass
  268. def is_bound(self, name, propagated=False):
  269. pass
  270. def add_unbound(self, name):
  271. pass
  272. class _GlobalUnboundNameFinder(_UnboundNameFinder):
  273. def __init__(self, pymodule, wanted_pyobject):
  274. super(_GlobalUnboundNameFinder, self).__init__(pymodule)
  275. self.unbound = set()
  276. self.names = set()
  277. for name, pyname in pymodule._get_structural_attributes().items():
  278. if not isinstance(pyname, (rope.base.pynames.ImportedName,
  279. rope.base.pynames.ImportedModule)):
  280. self.names.add(name)
  281. wanted_scope = wanted_pyobject.get_scope()
  282. self.start = wanted_scope.get_start()
  283. self.end = wanted_scope.get_end() + 1
  284. def _get_root(self):
  285. return self
  286. def is_bound(self, primary, propagated=False):
  287. name = primary.split('.')[0]
  288. if name in self.names:
  289. return True
  290. return False
  291. def add_unbound(self, name):
  292. names = name.split('.')
  293. for i in range(len(names)):
  294. self.unbound.add('.'.join(names[:i + 1]))
  295. def _is_node_interesting(self, node):
  296. return self.start <= node.lineno < self.end
  297. class _LocalUnboundNameFinder(_UnboundNameFinder):
  298. def __init__(self, pyobject, parent):
  299. super(_LocalUnboundNameFinder, self).__init__(pyobject)
  300. self.parent = parent
  301. def _get_root(self):
  302. return self.parent._get_root()
  303. def is_bound(self, primary, propagated=False):
  304. name = primary.split('.')[0]
  305. if propagated:
  306. names = self.pyobject.get_scope().get_propagated_names()
  307. else:
  308. names = self.pyobject.get_scope().get_names()
  309. if name in names or self.parent.is_bound(name, propagated=True):
  310. return True
  311. return False
  312. def add_unbound(self, name):
  313. self.parent.add_unbound(name)
  314. class _GlobalImportFinder(object):
  315. def __init__(self, pymodule, pycore):
  316. self.current_folder = None
  317. if pymodule.get_resource():
  318. self.current_folder = pymodule.get_resource().parent
  319. self.pymodule = pymodule
  320. self.pycore = pycore
  321. self.imports = []
  322. self.pymodule = pymodule
  323. self.lines = self.pymodule.lines
  324. def visit_import(self, node, end_line):
  325. start_line = node.lineno
  326. import_statement = importinfo.ImportStatement(
  327. importinfo.NormalImport(self._get_names(node.names)),
  328. start_line, end_line, self._get_text(start_line, end_line),
  329. blank_lines=self._count_empty_lines_before(start_line))
  330. self.imports.append(import_statement)
  331. def _count_empty_lines_before(self, lineno):
  332. result = 0
  333. for current in range(lineno - 1, 0, -1):
  334. line = self.lines.get_line(current)
  335. if line.strip() == '':
  336. result += 1
  337. else:
  338. break
  339. return result
  340. def _count_empty_lines_after(self, lineno):
  341. result = 0
  342. for current in range(lineno + 1, self.lines.length()):
  343. line = self.lines.get_line(current)
  344. if line.strip() == '':
  345. result += 1
  346. else:
  347. break
  348. return result
  349. def get_separating_line_count(self):
  350. if not self.imports:
  351. return 0
  352. return self._count_empty_lines_after(self.imports[-1].end_line - 1)
  353. def _get_text(self, start_line, end_line):
  354. result = []
  355. for index in range(start_line, end_line):
  356. result.append(self.lines.get_line(index))
  357. return '\n'.join(result)
  358. def visit_from(self, node, end_line):
  359. level = 0
  360. if node.level:
  361. level = node.level
  362. import_info = importinfo.FromImport(
  363. node.module or '', # see comment at rope.base.ast.walk
  364. level, self._get_names(node.names))
  365. start_line = node.lineno
  366. self.imports.append(importinfo.ImportStatement(
  367. import_info, node.lineno, end_line,
  368. self._get_text(start_line, end_line),
  369. blank_lines=self._count_empty_lines_before(start_line)))
  370. def _get_names(self, alias_names):
  371. result = []
  372. for alias in alias_names:
  373. result.append((alias.name, alias.asname))
  374. return result
  375. def find_import_statements(self):
  376. nodes = self.pymodule.get_ast().body
  377. for index, node in enumerate(nodes):
  378. if isinstance(node, (ast.Import, ast.ImportFrom)):
  379. lines = self.pymodule.logical_lines
  380. end_line = lines.logical_line_in(node.lineno)[1] + 1
  381. if isinstance(node, ast.Import):
  382. self.visit_import(node, end_line)
  383. if isinstance(node, ast.ImportFrom):
  384. self.visit_from(node, end_line)
  385. return self.imports