| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789 |
- import re
- from rope.base import ast, codeanalyze
- from rope.base.change import ChangeSet, ChangeContents
- from rope.base.exceptions import RefactoringError
- from rope.refactor import (sourceutils, similarfinder,
- patchedast, suites, usefunction)
- # Extract refactoring has lots of special cases. I tried to split it
- # into smaller parts to make it more manageable:
- #
- # _ExtractInfo: holds information about the refactoring; it is passed
- # to the parts that need to have information about the refactoring
- #
- # _ExtractCollector: merely saves all of the information necessary for
- # performing the refactoring.
- #
- # _DefinitionLocationFinder: finds where to insert the definition.
- #
- # _ExceptionalConditionChecker: checks for exceptional conditions in
- # which the refactoring cannot be applied.
- #
- # _ExtractMethodParts: generates the pieces of code (like definition)
- # needed for performing extract method.
- #
- # _ExtractVariableParts: like _ExtractMethodParts for variables.
- #
- # _ExtractPerformer: Uses above classes to collect refactoring
- # changes.
- #
- # There are a few more helper functions and classes used by above
- # classes.
- class _ExtractRefactoring(object):
- def __init__(self, project, resource, start_offset, end_offset,
- variable=False):
- self.project = project
- self.pycore = project.pycore
- self.resource = resource
- self.start_offset = self._fix_start(resource.read(), start_offset)
- self.end_offset = self._fix_end(resource.read(), end_offset)
- def _fix_start(self, source, offset):
- while offset < len(source) and source[offset].isspace():
- offset += 1
- return offset
- def _fix_end(self, source, offset):
- while offset > 0 and source[offset - 1].isspace():
- offset -= 1
- return offset
- def get_changes(self, extracted_name, similar=False, global_=False):
- """Get the changes this refactoring makes
- :parameters:
- - `similar`: if `True`, similar expressions/statements are also
- replaced.
- - `global_`: if `True`, the extracted method/variable will
- be global.
- """
- info = _ExtractInfo(
- self.project, self.resource, self.start_offset, self.end_offset,
- extracted_name, variable=self.kind == 'variable',
- similar=similar, make_global=global_)
- new_contents = _ExtractPerformer(info).extract()
- changes = ChangeSet('Extract %s <%s>' % (self.kind,
- extracted_name))
- changes.add_change(ChangeContents(self.resource, new_contents))
- return changes
class ExtractMethod(_ExtractRefactoring):
    """Extract refactoring whose ``kind`` is ``'method'``."""

    # Read by the inherited `get_changes`.
    kind = 'method'

    def __init__(self, *args, **kwds):
        # All arguments are forwarded untouched to the base initialiser.
        super(ExtractMethod, self).__init__(*args, **kwds)
class ExtractVariable(_ExtractRefactoring):
    """Extract refactoring whose ``kind`` is ``'variable'``."""

    # Read by the inherited `get_changes`.
    kind = 'variable'

    def __init__(self, *args, **kwds):
        # Force ``variable=True`` regardless of what the caller passed as
        # a keyword, then defer to the base initialiser.
        forwarded = dict(kwds, variable=True)
        super(ExtractVariable, self).__init__(*args, **forwarded)
- class _ExtractInfo(object):
- """Holds information about the extract to be performed"""
- def __init__(self, project, resource, start, end, new_name,
- variable, similar, make_global):
- self.pycore = project.pycore
- self.resource = resource
- self.pymodule = self.pycore.resource_to_pyobject(resource)
- self.global_scope = self.pymodule.get_scope()
- self.source = self.pymodule.source_code
- self.lines = self.pymodule.lines
- self.new_name = new_name
- self.variable = variable
- self.similar = similar
- self._init_parts(start, end)
- self._init_scope()
- self.make_global = make_global
- def _init_parts(self, start, end):
- self.region = (self._choose_closest_line_end(start),
- self._choose_closest_line_end(end, end=True))
- start = self.logical_lines.logical_line_in(
- self.lines.get_line_number(self.region[0]))[0]
- end = self.logical_lines.logical_line_in(
- self.lines.get_line_number(self.region[1]))[1]
- self.region_lines = (start, end)
- self.lines_region = (self.lines.get_line_start(self.region_lines[0]),
- self.lines.get_line_end(self.region_lines[1]))
- @property
- def logical_lines(self):
- return self.pymodule.logical_lines
- def _init_scope(self):
- start_line = self.region_lines[0]
- scope = self.global_scope.get_inner_scope_for_line(start_line)
- if scope.get_kind() != 'Module' and scope.get_start() == start_line:
- scope = scope.parent
- self.scope = scope
- self.scope_region = self._get_scope_region(self.scope)
- def _get_scope_region(self, scope):
- return (self.lines.get_line_start(scope.get_start()),
- self.lines.get_line_end(scope.get_end()) + 1)
- def _choose_closest_line_end(self, offset, end=False):
- lineno = self.lines.get_line_number(offset)
- line_start = self.lines.get_line_start(lineno)
- line_end = self.lines.get_line_end(lineno)
- if self.source[line_start:offset].strip() == '':
- if end:
- return line_start - 1
- else:
- return line_start
- elif self.source[offset:line_end].strip() == '':
- return min(line_end, len(self.source))
- return offset
- @property
- def one_line(self):
- return self.region != self.lines_region and \
- (self.logical_lines.logical_line_in(self.region_lines[0]) ==
- self.logical_lines.logical_line_in(self.region_lines[1]))
- @property
- def global_(self):
- return self.scope.parent is None
- @property
- def method(self):
- return self.scope.parent is not None and \
- self.scope.parent.get_kind() == 'Class'
- @property
- def indents(self):
- return sourceutils.get_indents(self.pymodule.lines,
- self.region_lines[0])
- @property
- def scope_indents(self):
- if self.global_:
- return 0
- return sourceutils.get_indents(self.pymodule.lines,
- self.scope.get_start())
- @property
- def extracted(self):
- return self.source[self.region[0]:self.region[1]]
- _returned = None
- @property
- def returned(self):
- """Does the extracted piece contain return statement"""
- if self._returned is None:
- node = _parse_text(self.extracted)
- self._returned = usefunction._returns_last(node)
- return self._returned
- class _ExtractCollector(object):
- """Collects information needed for performing the extract"""
- def __init__(self, info):
- self.definition = None
- self.body_pattern = None
- self.checks = {}
- self.replacement_pattern = None
- self.matches = None
- self.replacements = None
- self.definition_location = None
- class _ExtractPerformer(object):
- def __init__(self, info):
- self.info = info
- _ExceptionalConditionChecker()(self.info)
- def extract(self):
- extract_info = self._collect_info()
- content = codeanalyze.ChangeCollector(self.info.source)
- definition = extract_info.definition
- lineno, indents = extract_info.definition_location
- offset = self.info.lines.get_line_start(lineno)
- indented = sourceutils.fix_indentation(definition, indents)
- content.add_change(offset, offset, indented)
- self._replace_occurrences(content, extract_info)
- return content.get_changed()
- def _replace_occurrences(self, content, extract_info):
- for match in extract_info.matches:
- replacement = similarfinder.CodeTemplate(
- extract_info.replacement_pattern)
- mapping = {}
- for name in replacement.get_names():
- node = match.get_ast(name)
- if node:
- start, end = patchedast.node_region(match.get_ast(name))
- mapping[name] = self.info.source[start:end]
- else:
- mapping[name] = name
- region = match.get_region()
- content.add_change(region[0], region[1],
- replacement.substitute(mapping))
- def _collect_info(self):
- extract_collector = _ExtractCollector(self.info)
- self._find_definition(extract_collector)
- self._find_matches(extract_collector)
- self._find_definition_location(extract_collector)
- return extract_collector
- def _find_matches(self, collector):
- regions = self._where_to_search()
- finder = similarfinder.SimilarFinder(self.info.pymodule)
- matches = []
- for start, end in regions:
- matches.extend((finder.get_matches(collector.body_pattern,
- collector.checks, start, end)))
- collector.matches = matches
- def _where_to_search(self):
- if self.info.similar:
- if self.info.make_global or self.info.global_:
- return [(0, len(self.info.pymodule.source_code))]
- if self.info.method and not self.info.variable:
- class_scope = self.info.scope.parent
- regions = []
- method_kind = _get_function_kind(self.info.scope)
- for scope in class_scope.get_scopes():
- if method_kind == 'method' and \
- _get_function_kind(scope) != 'method':
- continue
- start = self.info.lines.get_line_start(scope.get_start())
- end = self.info.lines.get_line_end(scope.get_end())
- regions.append((start, end))
- return regions
- else:
- if self.info.variable:
- return [self.info.scope_region]
- else:
- return [self.info._get_scope_region(self.info.scope.parent)]
- else:
- return [self.info.region]
- def _find_definition_location(self, collector):
- matched_lines = []
- for match in collector.matches:
- start = self.info.lines.get_line_number(match.get_region()[0])
- start_line = self.info.logical_lines.logical_line_in(start)[0]
- matched_lines.append(start_line)
- location_finder = _DefinitionLocationFinder(self.info, matched_lines)
- collector.definition_location = (location_finder.find_lineno(),
- location_finder.find_indents())
- def _find_definition(self, collector):
- if self.info.variable:
- parts = _ExtractVariableParts(self.info)
- else:
- parts = _ExtractMethodParts(self.info)
- collector.definition = parts.get_definition()
- collector.body_pattern = parts.get_body_pattern()
- collector.replacement_pattern = parts.get_replacement_pattern()
- collector.checks = parts.get_checks()
- class _DefinitionLocationFinder(object):
- def __init__(self, info, matched_lines):
- self.info = info
- self.matched_lines = matched_lines
- # This only happens when subexpressions cannot be matched
- if not matched_lines:
- self.matched_lines.append(self.info.region_lines[0])
- def find_lineno(self):
- if self.info.variable and not self.info.make_global:
- return self._get_before_line()
- if self.info.make_global or self.info.global_:
- toplevel = self._find_toplevel(self.info.scope)
- ast = self.info.pymodule.get_ast()
- newlines = sorted(self.matched_lines + [toplevel.get_end() + 1])
- return suites.find_visible(ast, newlines)
- return self._get_after_scope()
- def _find_toplevel(self, scope):
- toplevel = scope
- if toplevel.parent is not None:
- while toplevel.parent.parent is not None:
- toplevel = toplevel.parent
- return toplevel
- def find_indents(self):
- if self.info.variable and not self.info.make_global:
- return sourceutils.get_indents(self.info.lines,
- self._get_before_line())
- else:
- if self.info.global_ or self.info.make_global:
- return 0
- return self.info.scope_indents
- def _get_before_line(self):
- ast = self.info.scope.pyobject.get_ast()
- return suites.find_visible(ast, self.matched_lines)
- def _get_after_scope(self):
- return self.info.scope.get_end() + 1
- class _ExceptionalConditionChecker(object):
- def __call__(self, info):
- self.base_conditions(info)
- if info.one_line:
- self.one_line_conditions(info)
- else:
- self.multi_line_conditions(info)
- def base_conditions(self, info):
- if info.region[1] > info.scope_region[1]:
- raise RefactoringError('Bad region selected for extract method')
- end_line = info.region_lines[1]
- end_scope = info.global_scope.get_inner_scope_for_line(end_line)
- if end_scope != info.scope and end_scope.get_end() != end_line:
- raise RefactoringError('Bad region selected for extract method')
- try:
- extracted = info.source[info.region[0]:info.region[1]]
- if info.one_line:
- extracted = '(%s)' % extracted
- if _UnmatchedBreakOrContinueFinder.has_errors(extracted):
- raise RefactoringError('A break/continue without having a '
- 'matching for/while loop.')
- except SyntaxError:
- raise RefactoringError('Extracted piece should '
- 'contain complete statements.')
- def one_line_conditions(self, info):
- if self._is_region_on_a_word(info):
- raise RefactoringError('Should extract complete statements.')
- if info.variable and not info.one_line:
- raise RefactoringError('Extract variable should not '
- 'span multiple lines.')
- def multi_line_conditions(self, info):
- node = _parse_text(info.source[info.region[0]:info.region[1]])
- count = usefunction._return_count(node)
- if count > 1:
- raise RefactoringError('Extracted piece can have only one '
- 'return statement.')
- if usefunction._yield_count(node):
- raise RefactoringError('Extracted piece cannot '
- 'have yield statements.')
- if count == 1 and not usefunction._returns_last(node):
- raise RefactoringError('Return should be the last statement.')
- if info.region != info.lines_region:
- raise RefactoringError('Extracted piece should '
- 'contain complete statements.')
- def _is_region_on_a_word(self, info):
- if info.region[0] > 0 and self._is_on_a_word(info, info.region[0] - 1) or \
- self._is_on_a_word(info, info.region[1] - 1):
- return True
- def _is_on_a_word(self, info, offset):
- prev = info.source[offset]
- if not (prev.isalnum() or prev == '_') or \
- offset + 1 == len(info.source):
- return False
- next = info.source[offset + 1]
- return next.isalnum() or next == '_'
- class _ExtractMethodParts(object):
- def __init__(self, info):
- self.info = info
- self.info_collector = self._create_info_collector()
- def get_definition(self):
- if self.info.global_:
- return '\n%s\n' % self._get_function_definition()
- else:
- return '\n%s' % self._get_function_definition()
- def get_replacement_pattern(self):
- variables = []
- variables.extend(self._find_function_arguments())
- variables.extend(self._find_function_returns())
- return similarfinder.make_pattern(self._get_call(), variables)
- def get_body_pattern(self):
- variables = []
- variables.extend(self._find_function_arguments())
- variables.extend(self._find_function_returns())
- variables.extend(self._find_temps())
- return similarfinder.make_pattern(self._get_body(), variables)
- def _get_body(self):
- result = sourceutils.fix_indentation(self.info.extracted, 0)
- if self.info.one_line:
- result = '(%s)' % result
- return result
- def _find_temps(self):
- return usefunction.find_temps(self.info.pycore.project,
- self._get_body())
- def get_checks(self):
- if self.info.method and not self.info.make_global:
- if _get_function_kind(self.info.scope) == 'method':
- class_name = similarfinder._pydefined_to_str(
- self.info.scope.parent.pyobject)
- return {self._get_self_name(): 'type=' + class_name}
- return {}
- def _create_info_collector(self):
- zero = self.info.scope.get_start() - 1
- start_line = self.info.region_lines[0] - zero
- end_line = self.info.region_lines[1] - zero
- info_collector = _FunctionInformationCollector(start_line, end_line,
- self.info.global_)
- body = self.info.source[self.info.scope_region[0]:
- self.info.scope_region[1]]
- node = _parse_text(body)
- ast.walk(node, info_collector)
- return info_collector
- def _get_function_definition(self):
- args = self._find_function_arguments()
- returns = self._find_function_returns()
- result = []
- if self.info.method and not self.info.make_global and \
- _get_function_kind(self.info.scope) != 'method':
- result.append('@staticmethod\n')
- result.append('def %s:\n' % self._get_function_signature(args))
- unindented_body = self._get_unindented_function_body(returns)
- indents = sourceutils.get_indent(self.info.pycore)
- function_body = sourceutils.indent_lines(unindented_body, indents)
- result.append(function_body)
- definition = ''.join(result)
- return definition + '\n'
- def _get_function_signature(self, args):
- args = list(args)
- prefix = ''
- if self._extracting_method():
- self_name = self._get_self_name()
- if self_name is None:
- raise RefactoringError('Extracting a method from a function '
- 'with no self argument.')
- if self_name in args:
- args.remove(self_name)
- args.insert(0, self_name)
- return prefix + self.info.new_name + \
- '(%s)' % self._get_comma_form(args)
- def _extracting_method(self):
- return self.info.method and not self.info.make_global and \
- _get_function_kind(self.info.scope) == 'method'
- def _get_self_name(self):
- param_names = self.info.scope.pyobject.get_param_names()
- if param_names:
- return param_names[0]
- def _get_function_call(self, args):
- prefix = ''
- if self.info.method and not self.info.make_global:
- if _get_function_kind(self.info.scope) == 'method':
- self_name = self._get_self_name()
- if self_name in args:
- args.remove(self_name)
- prefix = self_name + '.'
- else:
- prefix = self.info.scope.parent.pyobject.get_name() + '.'
- return prefix + '%s(%s)' % (self.info.new_name,
- self._get_comma_form(args))
- def _get_comma_form(self, names):
- result = ''
- if names:
- result += names[0]
- for name in names[1:]:
- result += ', ' + name
- return result
- def _get_call(self):
- if self.info.one_line:
- args = self._find_function_arguments()
- return self._get_function_call(args)
- args = self._find_function_arguments()
- returns = self._find_function_returns()
- call_prefix = ''
- if returns:
- call_prefix = self._get_comma_form(returns) + ' = '
- if self.info.returned:
- call_prefix = 'return '
- return call_prefix + self._get_function_call(args)
- def _find_function_arguments(self):
- # if not make_global, do not pass any global names; they are
- # all visible.
- if self.info.global_ and not self.info.make_global:
- return ()
- if not self.info.one_line:
- result = (self.info_collector.prewritten &
- self.info_collector.read)
- result |= (self.info_collector.prewritten &
- self.info_collector.postread &
- (self.info_collector.maybe_written -
- self.info_collector.written))
- return list(result)
- start = self.info.region[0]
- if start == self.info.lines_region[0]:
- start = start + re.search('\S', self.info.extracted).start()
- function_definition = self.info.source[start:self.info.region[1]]
- read = _VariableReadsAndWritesFinder.find_reads_for_one_liners(
- function_definition)
- return list(self.info_collector.prewritten.intersection(read))
- def _find_function_returns(self):
- if self.info.one_line or self.info.returned:
- return []
- written = self.info_collector.written | \
- self.info_collector.maybe_written
- return list(written & self.info_collector.postread)
- def _get_unindented_function_body(self, returns):
- if self.info.one_line:
- return 'return ' + _join_lines(self.info.extracted)
- extracted_body = self.info.extracted
- unindented_body = sourceutils.fix_indentation(extracted_body, 0)
- if returns:
- unindented_body += '\nreturn %s' % self._get_comma_form(returns)
- return unindented_body
- class _ExtractVariableParts(object):
- def __init__(self, info):
- self.info = info
- def get_definition(self):
- result = self.info.new_name + ' = ' + \
- _join_lines(self.info.extracted) + '\n'
- return result
- def get_body_pattern(self):
- return '(%s)' % self.info.extracted.strip()
- def get_replacement_pattern(self):
- return self.info.new_name
- def get_checks(self):
- return {}
- class _FunctionInformationCollector(object):
- def __init__(self, start, end, is_global):
- self.start = start
- self.end = end
- self.is_global = is_global
- self.prewritten = set()
- self.maybe_written = set()
- self.written = set()
- self.read = set()
- self.postread = set()
- self.postwritten = set()
- self.host_function = True
- self.conditional = False
- def _read_variable(self, name, lineno):
- if self.start <= lineno <= self.end:
- if name not in self.written:
- self.read.add(name)
- if self.end < lineno:
- if name not in self.postwritten:
- self.postread.add(name)
- def _written_variable(self, name, lineno):
- if self.start <= lineno <= self.end:
- if self.conditional:
- self.maybe_written.add(name)
- else:
- self.written.add(name)
- if self.start > lineno:
- self.prewritten.add(name)
- if self.end < lineno:
- self.postwritten.add(name)
- def _FunctionDef(self, node):
- if not self.is_global and self.host_function:
- self.host_function = False
- for name in _get_argnames(node.args):
- self._written_variable(name, node.lineno)
- for child in node.body:
- ast.walk(child, self)
- else:
- self._written_variable(node.name, node.lineno)
- visitor = _VariableReadsAndWritesFinder()
- for child in node.body:
- ast.walk(child, visitor)
- for name in visitor.read - visitor.written:
- self._read_variable(name, node.lineno)
- def _Name(self, node):
- if isinstance(node.ctx, (ast.Store, ast.AugStore)):
- self._written_variable(node.id, node.lineno)
- if not isinstance(node.ctx, ast.Store):
- self._read_variable(node.id, node.lineno)
- def _Assign(self, node):
- ast.walk(node.value, self)
- for child in node.targets:
- ast.walk(child, self)
- def _ClassDef(self, node):
- self._written_variable(node.name, node.lineno)
- def _handle_conditional_node(self, node):
- self.conditional = True
- try:
- for child in ast.get_child_nodes(node):
- ast.walk(child, self)
- finally:
- self.conditional = False
- def _If(self, node):
- self._handle_conditional_node(node)
- def _While(self, node):
- self._handle_conditional_node(node)
- def _For(self, node):
- self._handle_conditional_node(node)
- def _get_argnames(arguments):
- result = [node.id for node in arguments.args
- if isinstance(node, ast.Name)]
- if arguments.vararg:
- result.append(arguments.vararg)
- if arguments.kwarg:
- result.append(arguments.kwarg)
- return result
- class _VariableReadsAndWritesFinder(object):
- def __init__(self):
- self.written = set()
- self.read = set()
- def _Name(self, node):
- if isinstance(node.ctx, (ast.Store, ast.AugStore)):
- self.written.add(node.id)
- if not isinstance(node, ast.Store):
- self.read.add(node.id)
- def _FunctionDef(self, node):
- self.written.add(node.name)
- visitor = _VariableReadsAndWritesFinder()
- for child in ast.get_child_nodes(node):
- ast.walk(child, visitor)
- self.read.update(visitor.read - visitor.written)
- def _Class(self, node):
- self.written.add(node.name)
- @staticmethod
- def find_reads_and_writes(code):
- if code.strip() == '':
- return set(), set()
- if isinstance(code, unicode):
- code = code.encode('utf-8')
- node = _parse_text(code)
- visitor = _VariableReadsAndWritesFinder()
- ast.walk(node, visitor)
- return visitor.read, visitor.written
- @staticmethod
- def find_reads_for_one_liners(code):
- if code.strip() == '':
- return set(), set()
- node = _parse_text(code)
- visitor = _VariableReadsAndWritesFinder()
- ast.walk(node, visitor)
- return visitor.read
- class _UnmatchedBreakOrContinueFinder(object):
- def __init__(self):
- self.error = False
- self.loop_count = 0
- def _For(self, node):
- self.loop_encountered(node)
- def _While(self, node):
- self.loop_encountered(node)
- def loop_encountered(self, node):
- self.loop_count += 1
- for child in node.body:
- ast.walk(child, self)
- self.loop_count -= 1
- if node.orelse:
- ast.walk(node.orelse, self)
- def _Break(self, node):
- self.check_loop()
- def _Continue(self, node):
- self.check_loop()
- def check_loop(self):
- if self.loop_count < 1:
- self.error = True
- def _FunctionDef(self, node):
- pass
- def _ClassDef(self, node):
- pass
- @staticmethod
- def has_errors(code):
- if code.strip() == '':
- return False
- node = _parse_text(code)
- visitor = _UnmatchedBreakOrContinueFinder()
- ast.walk(node, visitor)
- return visitor.error
- def _get_function_kind(scope):
- return scope.pyobject.get_kind()
def _parse_text(body):
    """Parse `body` after re-indenting it to column zero."""
    dedented = sourceutils.fix_indentation(body, 0)
    return ast.parse(dedented)
- def _join_lines(code):
- lines = []
- for line in code.splitlines():
- if line.endswith('\\'):
- lines.append(line[:-1].strip())
- else:
- lines.append(line.strip())
- return ' '.join(lines)
|