From 0a5b393b15727f8671be283cfcef4da6e8fefdd6 Mon Sep 17 00:00:00 2001 From: EhteshamSid Date: Sun, 5 Apr 2026 16:19:46 -0400 Subject: [PATCH] Use safe_load in yamlhelper --- awscli/bcdoc/textwriter.py | 221 +++++++++-------- .../cloudformation/yamlhelper.py | 11 +- awscli/customizations/eks/ordered_yaml.py | 25 +- awscli/customizations/emr/sshutils.py | 41 ++-- awscli/testutils.py | 232 +++++++++--------- tests/unit/customizations/emr/__init__.py | 11 +- 6 files changed, 279 insertions(+), 262 deletions(-) diff --git a/awscli/bcdoc/textwriter.py b/awscli/bcdoc/textwriter.py index 6fc171b9b475..87dd6e51c352 100644 --- a/awscli/bcdoc/textwriter.py +++ b/awscli/bcdoc/textwriter.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- """ - Custom docutils writer for plain text. - Based heavily on the Sphinx text writer. See copyright below. +Custom docutils writer for plain text. +Based heavily on the Sphinx text writer. See copyright below. - :copyright: Copyright 2007-2011 by the Sphinx team, see AUTHORS. - :license: BSD, see LICENSE for details. +:copyright: Copyright 2007-2011 by the Sphinx team, see AUTHORS. +:license: BSD, see LICENSE for details. 
""" + import os import re import textwrap @@ -19,10 +20,11 @@ class TextWrapper(textwrap.TextWrapper): """Custom subclass that uses a different word separator regex.""" wordsep_re = re.compile( - r'(\s+|' # any whitespace - r'(?<=\s)(?::[a-z-]+:)?`\S+|' # interpreted text start - r'[^\s\w]*\w+[a-zA-Z]-(?=\w+[a-zA-Z])|' # hyphenated words - r'(?<=[\w\!\"\'\&\.\,\?])-{2,}(?=\w))') # em-dash + r"(\s+|" # any whitespace + r"(?<=\s)(?::[a-z-]+:)?`\S+|" # interpreted text start + r"[^\s\w]*\w+[a-zA-Z]-(?=\w+[a-zA-Z])|" # hyphenated words + r"(?<=[\w\!\"\'\&\.\,\?])-{2,}(?=\w))" + ) # em-dash MAXWIDTH = 70 @@ -35,8 +37,8 @@ def my_wrap(text, width=MAXWIDTH, **kwargs): class TextWriter(writers.Writer): - supported = ('text',) - settings_spec = ('No options here.', '', ()) + supported = ("text",) + settings_spec = ("No options here.", "", ()) settings_defaults = {} output = None @@ -70,7 +72,9 @@ def new_state(self, indent=STDINDENT): self.states.append([]) self.stateindent.append(indent) - def end_state(self, wrap=True, end=[''], first=None): + # NOTE(review): keep the original ('',) default so a trailing blank + # line is still emitted; callers suppress it with an explicit end=None. + def end_state(self, wrap=True, end=('',), first=None): content = self.states.pop() maxindent = sum(self.stateindent) indent = self.stateindent.pop() @@ -81,12 +85,13 @@ def do_format(): if not toformat: return if wrap: - res = my_wrap(''.join(toformat), width=MAXWIDTH-maxindent) + res = my_wrap("".join(toformat), width=MAXWIDTH - maxindent) else: - res = ''.join(toformat).splitlines() + res = "".join(toformat).splitlines() if end: res += end result.append((indent, res)) + for itemindent, item in content: if itemindent == -1: toformat.append(item) @@ -107,9 +112,11 @@ def visit_document(self, node): def depart_document(self, node): self.end_state() - self.body = self.nl.join(line and (' '*indent + line) - for indent, lines in self.states[0] - for line in lines) + self.body = self.nl.join( + line and (" " * indent + line) + for indent, lines in self.states[0] + for line in lines + ) # XXX header/footer? 
def visit_highlightlang(self, node): @@ -133,10 +140,10 @@ def depart_topic(self, node): def visit_rubric(self, node): self.new_state(0) - self.add_text('-[ ') + self.add_text("-[ ") def depart_rubric(self, node): - self.add_text(' ]-') + self.add_text(" ]-") self.end_state() def visit_compound(self, node): @@ -153,7 +160,7 @@ def depart_glossary(self, node): def visit_title(self, node): if isinstance(node.parent, nodes.Admonition): - self.add_text(node.astext()+': ') + self.add_text(node.astext() + ": ") raise nodes.SkipNode self.new_state(0) @@ -161,10 +168,10 @@ def depart_title(self, node): if isinstance(node.parent, nodes.section): char = self._title_char else: - char = '^' - text = ''.join(x[1] for x in self.states.pop() if x[0] == -1) + char = "^" + text = "".join(x[1] for x in self.states.pop() if x[0] == -1) self.stateindent.pop() - self.states[-1].append((0, ['', text, '%s' % (char * len(text)), ''])) + self.states[-1].append((0, ["", text, "%s" % (char * len(text)), ""])) def visit_subtitle(self, node): pass @@ -173,7 +180,7 @@ def depart_subtitle(self, node): pass def visit_attribution(self, node): - self.add_text('-- ') + self.add_text("-- ") def depart_attribution(self, node): pass @@ -186,8 +193,8 @@ def depart_desc(self, node): def visit_desc_signature(self, node): self.new_state(0) - if node.parent['objtype'] in ('class', 'exception'): - self.add_text('%s ' % node.parent['objtype']) + if node.parent["objtype"] in ("class", "exception"): + self.add_text("%s " % node.parent["objtype"]) def depart_desc_signature(self, node): # XXX: wrap signatures in a way that makes sense @@ -212,31 +219,31 @@ def depart_desc_type(self, node): pass def visit_desc_returns(self, node): - self.add_text(' -> ') + self.add_text(" -> ") def depart_desc_returns(self, node): pass def visit_desc_parameterlist(self, node): - self.add_text('(') + self.add_text("(") self.first_param = 1 def depart_desc_parameterlist(self, node): - self.add_text(')') + self.add_text(")") def 
visit_desc_parameter(self, node): if not self.first_param: - self.add_text(', ') + self.add_text(", ") else: self.first_param = 0 self.add_text(node.astext()) raise nodes.SkipNode def visit_desc_optional(self, node): - self.add_text('[') + self.add_text("[") def depart_desc_optional(self, node): - self.add_text(']') + self.add_text("]") def visit_desc_annotation(self, node): pass @@ -273,14 +280,14 @@ def visit_productionlist(self, node): self.new_state() names = [] for production in node: - names.append(production['tokenname']) + names.append(production["tokenname"]) maxlen = max(len(name) for name in names) for production in node: - if production['tokenname']: - self.add_text(production['tokenname'].ljust(maxlen) + ' ::=') - lastname = production['tokenname'] + if production["tokenname"]: + self.add_text(production["tokenname"].ljust(maxlen) + " ::=") + lastname = production["tokenname"] else: - self.add_text('%s ' % (' '*len(lastname))) + self.add_text("%s " % (" " * len(lastname))) self.add_text(production.astext() + self.nl) self.end_state(wrap=False) raise nodes.SkipNode @@ -289,24 +296,24 @@ def visit_seealso(self, node): self.new_state() def depart_seealso(self, node): - self.end_state(first='') + self.end_state(first="") def visit_footnote(self, node): self._footnote = node.children[0].astext().strip() self.new_state(len(self._footnote) + 3) def depart_footnote(self, node): - self.end_state(first='[%s] ' % self._footnote) + self.end_state(first="[%s] " % self._footnote) def visit_citation(self, node): if len(node) and isinstance(node[0], nodes.label): self._citlabel = node[0].astext() else: - self._citlabel = '' + self._citlabel = "" self.new_state(len(self._citlabel) + 3) def depart_citation(self, node): - self.end_state(first='[%s] ' % self._citlabel) + self.end_state(first="[%s] " % self._citlabel) def visit_label(self, node): raise nodes.SkipNode @@ -329,13 +336,13 @@ def visit_option_group(self, node): self._firstoption = True def 
depart_option_group(self, node): - self.add_text(' ') + self.add_text(" ") def visit_option(self, node): if self._firstoption: self._firstoption = False else: - self.add_text(', ') + self.add_text(", ") def depart_option(self, node): pass @@ -347,7 +354,7 @@ def depart_option_string(self, node): pass def visit_option_argument(self, node): - self.add_text(node['delimiter']) + self.add_text(node["delimiter"]) def depart_option_argument(self, node): pass @@ -362,7 +369,7 @@ def visit_tabular_col_spec(self, node): raise nodes.SkipNode def visit_colspec(self, node): - self.table[0].append(node['colwidth']) + self.table[0].append(node["colwidth"]) raise nodes.SkipNode def visit_tgroup(self, node): @@ -378,7 +385,7 @@ def depart_thead(self, node): pass def visit_tbody(self, node): - self.table.append('sep') + self.table.append("sep") def depart_tbody(self, node): pass @@ -390,9 +397,10 @@ def depart_row(self, node): pass def visit_entry(self, node): - if 'morerows' in node or 'morecols' in node: - raise NotImplementedError('Column or row spanning cells are ' - 'not implemented.') + if "morerows" in node or "morecols" in node: + raise NotImplementedError( + "Column or row spanning cells are " "not implemented." 
+ ) self.new_state(0) def depart_entry(self, node): @@ -402,7 +410,7 @@ def depart_entry(self, node): def visit_table(self, node): if self.table: - raise NotImplementedError('Nested tables are not supported.') + raise NotImplementedError("Nested tables are not supported.") self.new_state(0) self.table = [[]] @@ -414,7 +422,7 @@ def depart_table(self, node): separator = 0 # don't allow paragraphs in table cells for now for line in lines: - if line == 'sep': + if line == "sep": separator = len(fmted_rows) else: cells = [] @@ -428,52 +436,51 @@ def depart_table(self, node): cells.append(par) fmted_rows.append(cells) - def writesep(char='-'): - out = ['+'] + def writesep(char="-"): + out = ["+"] for width in realwidths: - out.append(char * (width+2)) - out.append('+') - self.add_text(''.join(out) + self.nl) + out.append(char * (width + 2)) + out.append("+") + self.add_text("".join(out) + self.nl) def writerow(row): lines = zip(*row) for line in lines: - out = ['|'] + out = ["|"] for i, cell in enumerate(line): if cell: - out.append(' ' + cell.ljust(realwidths[i]+1)) + out.append(" " + cell.ljust(realwidths[i] + 1)) else: - out.append(' ' * (realwidths[i] + 2)) - out.append('|') - self.add_text(''.join(out) + self.nl) + out.append(" " * (realwidths[i] + 2)) + out.append("|") + self.add_text("".join(out) + self.nl) for i, row in enumerate(fmted_rows): if separator and i == separator: - writesep('=') + writesep("=") else: - writesep('-') + writesep("-") writerow(row) - writesep('-') + writesep("-") self.table = None self.end_state(wrap=False) def visit_acks(self, node): self.new_state(0) - self.add_text( - ', '.join(n.astext() for n in node.children[0].children) + '.') + self.add_text(", ".join(n.astext() for n in node.children[0].children) + ".") self.end_state() raise nodes.SkipNode def visit_image(self, node): - if 'alt' in node.attributes: - self.add_text(_('[image: %s]') % node['alt']) - self.add_text(_('[image]')) + if "alt" in node.attributes: + 
self.add_text(_("[image: %s]") % node["alt"]) + self.add_text(_("[image]")) raise nodes.SkipNode def visit_transition(self, node): indent = sum(self.stateindent) self.new_state(0) - self.add_text('=' * (MAXWIDTH - indent)) + self.add_text("=" * (MAXWIDTH - indent)) self.end_state() raise nodes.SkipNode @@ -509,15 +516,16 @@ def visit_list_item(self, node): def depart_list_item(self, node): if self.list_counter[-1] == -1: - self.end_state(first='* ', end=None) + self.end_state(first="* ", end=None) elif self.list_counter[-1] == -2: pass else: - self.end_state(first='%s. ' % self.list_counter[-1], end=None) + self.end_state(first="%s. " % self.list_counter[-1], end=None) def visit_definition_list_item(self, node): - self._li_has_classifier = len(node) >= 2 and \ - isinstance(node[1], nodes.classifier) + self._li_has_classifier = len(node) >= 2 and isinstance( + node[1], nodes.classifier + ) def depart_definition_list_item(self, node): pass @@ -530,11 +538,11 @@ def depart_term(self, node): self.end_state(end=None) def visit_termsep(self, node): - self.add_text(', ') + self.add_text(", ") raise nodes.SkipNode def visit_classifier(self, node): - self.add_text(' : ') + self.add_text(" : ") def depart_classifier(self, node): self.end_state(end=None) @@ -561,7 +569,7 @@ def visit_field_name(self, node): self.new_state(0) def depart_field_name(self, node): - self.add_text(':') + self.add_text(":") self.end_state(end=None) def visit_field_body(self, node): @@ -670,35 +678,35 @@ def depart_download_reference(self, node): pass def visit_emphasis(self, node): - self.add_text('*') + self.add_text("*") def depart_emphasis(self, node): - self.add_text('*') + self.add_text("*") def visit_literal_emphasis(self, node): - self.add_text('*') + self.add_text("*") def depart_literal_emphasis(self, node): - self.add_text('*') + self.add_text("*") def visit_strong(self, node): - self.add_text('**') + self.add_text("**") def depart_strong(self, node): - self.add_text('**') + 
self.add_text("**") def visit_abbreviation(self, node): - self.add_text('') + self.add_text("") def depart_abbreviation(self, node): - if node.hasattr('explanation'): - self.add_text(' (%s)' % node['explanation']) + if node.hasattr("explanation"): + self.add_text(" (%s)" % node["explanation"]) def visit_title_reference(self, node): - self.add_text('*') + self.add_text("*") def depart_title_reference(self, node): - self.add_text('*') + self.add_text("*") def visit_literal(self, node): self.add_text('"') @@ -707,23 +715,23 @@ def depart_literal(self, node): self.add_text('"') def visit_subscript(self, node): - self.add_text('_') + self.add_text("_") def depart_subscript(self, node): pass def visit_superscript(self, node): - self.add_text('^') + self.add_text("^") def depart_superscript(self, node): pass def visit_footnote_reference(self, node): - self.add_text('[%s]' % node.astext()) + self.add_text("[%s]" % node.astext()) raise nodes.SkipNode def visit_citation_reference(self, node): - self.add_text('[%s]' % node.astext()) + self.add_text("[%s]" % node.astext()) raise nodes.SkipNode def visit_Text(self, node): @@ -745,14 +753,14 @@ def depart_inline(self, node): pass def visit_problematic(self, node): - self.add_text('>>') + self.add_text(">>") def depart_problematic(self, node): - self.add_text('<<') + self.add_text("<<") def visit_system_message(self, node): self.new_state(0) - self.add_text('' % node.astext()) + self.add_text("" % node.astext()) self.end_state() raise nodes.SkipNode @@ -764,7 +772,7 @@ def visit_meta(self, node): raise nodes.SkipNode def visit_raw(self, node): - if 'text' in node.get('format', '').split(): + if "text" in node.get("format", "").split(): self.body.append(node.astext()) raise nodes.SkipNode @@ -773,27 +781,28 @@ def _visit_admonition(self, node): def _make_depart_admonition(name): def depart_admonition(self, node): - self.end_state(first=name.capitalize() + ': ') + self.end_state(first=name.capitalize() + ": ") + return 
depart_admonition visit_attention = _visit_admonition - depart_attention = _make_depart_admonition('attention') + depart_attention = _make_depart_admonition("attention") visit_caution = _visit_admonition - depart_caution = _make_depart_admonition('caution') + depart_caution = _make_depart_admonition("caution") visit_danger = _visit_admonition - depart_danger = _make_depart_admonition('danger') + depart_danger = _make_depart_admonition("danger") visit_error = _visit_admonition - depart_error = _make_depart_admonition('error') + depart_error = _make_depart_admonition("error") visit_hint = _visit_admonition - depart_hint = _make_depart_admonition('hint') + depart_hint = _make_depart_admonition("hint") visit_important = _visit_admonition - depart_important = _make_depart_admonition('important') + depart_important = _make_depart_admonition("important") visit_note = _visit_admonition - depart_note = _make_depart_admonition('note') + depart_note = _make_depart_admonition("note") visit_tip = _visit_admonition - depart_tip = _make_depart_admonition('tip') + depart_tip = _make_depart_admonition("tip") visit_warning = _visit_admonition - depart_warning = _make_depart_admonition('warning') + depart_warning = _make_depart_admonition("warning") def unknown_visit(self, node): - raise NotImplementedError('Unknown node: ' + node.__class__.__name__) + raise NotImplementedError("Unknown node: " + node.__class__.__name__) diff --git a/awscli/customizations/cloudformation/yamlhelper.py b/awscli/customizations/cloudformation/yamlhelper.py index 61603603e669..12c354592b0f 100644 --- a/awscli/customizations/cloudformation/yamlhelper.py +++ b/awscli/customizations/cloudformation/yamlhelper.py @@ -79,11 +79,11 @@ def _dict_constructor(loader, node): class SafeLoaderWrapper(yaml.SafeLoader): - """Isolated safe loader to allow for customizations without global changes. 
- """ + """Isolated safe loader to allow for customizations without global changes.""" pass + def yaml_parse(yamlstr): """Parse a yaml string""" try: @@ -93,10 +93,11 @@ def yaml_parse(yamlstr): return json.loads(yamlstr, object_pairs_hook=OrderedDict) except ValueError: loader = SafeLoaderWrapper - loader.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, - _dict_constructor) + loader.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _dict_constructor + ) loader.add_multi_constructor("!", intrinsics_multi_constructor) - return yaml.load(yamlstr, loader) + return yaml.load(yamlstr, Loader=loader) class FlattenAliasDumper(yaml.SafeDumper): diff --git a/awscli/customizations/eks/ordered_yaml.py b/awscli/customizations/eks/ordered_yaml.py index aacac1fcf16f..ce95de942608 100644 --- a/awscli/customizations/eks/ordered_yaml.py +++ b/awscli/customizations/eks/ordered_yaml.py @@ -16,35 +16,35 @@ class SafeOrderedLoader(yaml.SafeLoader): - """ Safely load a yaml file into an OrderedDict.""" + """Safely load a yaml file into an OrderedDict.""" class SafeOrderedDumper(yaml.SafeDumper): - """ Safely dump an OrderedDict as yaml.""" + """Safely dump an OrderedDict as yaml.""" def _ordered_constructor(loader, node): - loader.flatten_mapping(node) - return OrderedDict(loader.construct_pairs(node)) + loader.flatten_mapping(node) + return OrderedDict(loader.construct_pairs(node)) SafeOrderedLoader.add_constructor( - yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, - _ordered_constructor) + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _ordered_constructor +) def _ordered_representer(dumper, data): - return dumper.represent_mapping( - yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, - data.items()) + return dumper.represent_mapping( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, data.items() + ) SafeOrderedDumper.add_representer(OrderedDict, _ordered_representer) def ordered_yaml_load(stream): - """ Load an OrderedDict object from a yaml stream.""" - 
return yaml.load(stream, SafeOrderedLoader) + """Load an OrderedDict object from a yaml stream.""" + return yaml.load(stream, Loader=SafeOrderedLoader) def ordered_yaml_dump(to_dump, stream=None): @@ -58,5 +58,4 @@ def ordered_yaml_dump(to_dump, stream=None): If not given or if None, only return the value :type stream: file """ - return yaml.dump(to_dump, stream, - SafeOrderedDumper, default_flow_style=False) + return yaml.dump(to_dump, stream, SafeOrderedDumper, default_flow_style=False) diff --git a/awscli/customizations/emr/sshutils.py b/awscli/customizations/emr/sshutils.py index 443f64b472d0..7434c50a7ed9 100644 --- a/awscli/customizations/emr/sshutils.py +++ b/awscli/customizations/emr/sshutils.py @@ -31,8 +31,7 @@ def validate_and_find_master_dns(session, parsed_globals, cluster_id): Return the latest created master instance public dns name. Throw MasterDNSNotAvailableError or ClusterTerminatedError. """ - cluster_state = emrutils.get_cluster_state( - session, parsed_globals, cluster_id) + cluster_state = emrutils.get_cluster_state(session, parsed_globals, cluster_id) if cluster_state in constants.TERMINATED_STATES: raise exceptions.ClusterTerminatedError @@ -40,7 +39,7 @@ def validate_and_find_master_dns(session, parsed_globals, cluster_id): emr = emrutils.get_client(session, parsed_globals) try: - cluster_running_waiter = emr.get_waiter('cluster_running') + cluster_running_waiter = emr.get_waiter("cluster_running") if cluster_state in constants.STARTING_STATES: print("Waiting for the cluster to start.") cluster_running_waiter.wait(ClusterId=cluster_id) @@ -48,21 +47,25 @@ def validate_and_find_master_dns(session, parsed_globals, cluster_id): raise exceptions.MasterDNSNotAvailableError return emrutils.find_master_dns( - session=session, cluster_id=cluster_id, - parsed_globals=parsed_globals) + session=session, cluster_id=cluster_id, parsed_globals=parsed_globals + ) def validate_ssh_with_key_file(key_file): - if (emrutils.which('putty.exe') or
emrutils.which('ssh') or - emrutils.which('ssh.exe')) is None: + if ( + emrutils.which("putty.exe") + or emrutils.which("ssh") + or emrutils.which("ssh.exe") + ) is None: raise exceptions.SSHNotFoundError else: check_ssh_key_format(key_file) def validate_scp_with_key_file(key_file): - if (emrutils.which('pscp.exe') or emrutils.which('scp') or - emrutils.which('scp.exe')) is None: + if ( + emrutils.which("pscp.exe") or emrutils.which("scp") or emrutils.which("scp.exe") + ) is None: raise exceptions.SCPNotFoundError else: check_scp_key_format(key_file) @@ -70,9 +73,11 @@ def validate_scp_with_key_file(key_file): def check_scp_key_format(key_file): # If only pscp is present and the file format is incorrect - if (emrutils.which('pscp.exe') is not None and - (emrutils.which('scp.exe') or emrutils.which('scp')) is None): - if check_command_key_format(key_file, ['ppk']) is False: + if ( + emrutils.which("pscp.exe") is not None + and (emrutils.which("scp.exe") or emrutils.which("scp")) is None + ): + if check_command_key_format(key_file, ["ppk"]) is False: raise exceptions.WrongPuttyKeyError else: pass @@ -80,15 +85,19 @@ def check_scp_key_format(key_file): def check_ssh_key_format(key_file): # If only putty is present and the file format is incorrect - if (emrutils.which('putty.exe') is not None and - (emrutils.which('ssh.exe') or emrutils.which('ssh')) is None): - if check_command_key_format(key_file, ['ppk']) is False: + if ( + emrutils.which("putty.exe") is not None + and (emrutils.which("ssh.exe") or emrutils.which("ssh")) is None + ): + if check_command_key_format(key_file, ["ppk"]) is False: raise exceptions.WrongPuttyKeyError else: pass -def check_command_key_format(key_file, accepted_file_format=[]): +def check_command_key_format(key_file, accepted_file_format=None): + if accepted_file_format is None: + accepted_file_format = [] if any(key_file.endswith(i) for i in accepted_file_format): return True else: diff --git a/awscli/testutils.py b/awscli/testutils.py 
index 46d8d4143133..bb13ffc363f8 100644 --- a/awscli/testutils.py +++ b/awscli/testutils.py @@ -49,13 +49,13 @@ from awscli.utils import create_nested_client _LOADER = botocore.loaders.Loader() -INTEG_LOG = logging.getLogger('awscli.tests.integration') +INTEG_LOG = logging.getLogger("awscli.tests.integration") AWS_CMD = None with tempfile.TemporaryDirectory() as tmpdir: - with open(Path(tmpdir) / 'aws-cli-tmp-file', 'w') as f: + with open(Path(tmpdir) / "aws-cli-tmp-file", "w") as f: pass - CASE_INSENSITIVE = (Path(tmpdir) / 'AWS-CLI-TMP-FILE').exists() + CASE_INSENSITIVE = (Path(tmpdir) / "AWS-CLI-TMP-FILE").exists() def skip_if_windows(reason): @@ -70,9 +70,9 @@ def test_some_non_windows_stuff(self): """ def decorator(func): - return unittest.skipIf( - platform.system() not in ['Darwin', 'Linux'], reason - )(func) + return unittest.skipIf(platform.system() not in ["Darwin", "Linux"], reason)( + func + ) return decorator @@ -80,20 +80,20 @@ def decorator(func): def skip_if_case_sensitive(): def decorator(func): return unittest.skipIf( - not CASE_INSENSITIVE, - "This test requires a case-insensitive filesystem." + not CASE_INSENSITIVE, "This test requires a case-insensitive filesystem." 
)(func) + return decorator def create_clidriver(): driver = awscli.clidriver.create_clidriver() session = driver.session - data_path = session.get_config_variable('data_path').split(os.pathsep) + data_path = session.get_config_variable("data_path").split(os.pathsep) if not data_path: data_path = [] _LOADER.search_paths.extend(data_path) - session.register_component('data_loader', _LOADER) + session.register_component("data_loader", _LOADER) return driver @@ -104,14 +104,14 @@ def get_aws_cmd(): if AWS_CMD is None: # Try /bin/aws repo_root = os.path.dirname(os.path.abspath(awscli.__file__)) - aws_cmd = os.path.join(repo_root, 'bin', 'aws') + aws_cmd = os.path.join(repo_root, "bin", "aws") if not os.path.isfile(aws_cmd): - aws_cmd = _search_path_for_cmd('aws') + aws_cmd = _search_path_for_cmd("aws") if aws_cmd is None: raise ValueError( 'Could not find "aws" executable. Either ' - 'make sure it is on your PATH, or you can ' - 'explicitly set this value using ' + "make sure it is on your PATH, or you can " + "explicitly set this value using " '"set_aws_cmd()"' ) AWS_CMD = aws_cmd @@ -119,7 +119,7 @@ def get_aws_cmd(): def _search_path_for_cmd(cmd_name): - for path in os.environ.get('PATH', '').split(os.pathsep): + for path in os.environ.get("PATH", "").split(os.pathsep): full_cmd_path = os.path.join(path, cmd_name) if os.path.isfile(full_cmd_path): return full_cmd_path @@ -144,9 +144,9 @@ def temporary_file(mode): """ temporary_directory = tempfile.mkdtemp() - basename = 'tmpfile-%s' % str(random_chars(8)) + basename = "tmpfile-%s" % str(random_chars(8)) full_filename = os.path.join(temporary_directory, basename) - open(full_filename, 'w').close() + open(full_filename, "w").close() try: with open(full_filename, mode) as f: yield f @@ -160,19 +160,19 @@ def create_bucket(session, name=None, region=None): :returns: the name of the bucket created """ if not region: - region = 'us-west-2' - client = create_nested_client(session, 's3', region_name=region) + region = 
"us-west-2" + client = create_nested_client(session, "s3", region_name=region) if name: bucket_name = name else: bucket_name = random_bucket_name() - params = {'Bucket': bucket_name, 'ObjectOwnership': 'ObjectWriter'} - if region != 'us-east-1': - params['CreateBucketConfiguration'] = {'LocationConstraint': region} + params = {"Bucket": bucket_name, "ObjectOwnership": "ObjectWriter"} + if region != "us-east-1": + params["CreateBucketConfiguration"] = {"LocationConstraint": region} try: client.create_bucket(**params) except ClientError as e: - if e.response['Error'].get('Code') == 'BucketAlreadyOwnedByYou': + if e.response["Error"].get("Code") == "BucketAlreadyOwnedByYou": # This can happen in the retried request, when the first one # succeeded on S3 but somehow the response never comes back. # We still got a bucket ready for test anyway. @@ -188,27 +188,27 @@ def create_dir_bucket(session, name=None, location=None): :returns: the name of the bucket created """ if not location: - location = ('us-west-2', 'usw2-az1') + location = ("us-west-2", "usw2-az1") region, az = location - client = create_nested_client(session, 's3', region_name=region) + client = create_nested_client(session, "s3", region_name=region) if name: bucket_name = name else: bucket_name = f"{random_bucket_name()}--{az}--x-s3" params = { - 'Bucket': bucket_name, - 'CreateBucketConfiguration': { - 'Location': {'Type': 'AvailabilityZone', 'Name': az}, - 'Bucket': { - 'Type': 'Directory', - 'DataRedundancy': 'SingleAvailabilityZone', + "Bucket": bucket_name, + "CreateBucketConfiguration": { + "Location": {"Type": "AvailabilityZone", "Name": az}, + "Bucket": { + "Type": "Directory", + "DataRedundancy": "SingleAvailabilityZone", }, }, } try: client.create_bucket(**params) except ClientError as e: - if e.response['Error'].get('Code') == 'BucketAlreadyOwnedByYou': + if e.response["Error"].get("Code") == "BucketAlreadyOwnedByYou": # This can happen in the retried request, when the first one # succeeded on S3 
but somehow the response never comes back. # We still got a bucket ready for test anyway. @@ -224,10 +224,10 @@ def random_chars(num_chars): Useful for creating resources with random names. """ - return binascii.hexlify(os.urandom(int(num_chars / 2))).decode('ascii') + return binascii.hexlify(os.urandom(int(num_chars / 2))).decode("ascii") -def random_bucket_name(prefix='awscli-s3integ', num_random=15): +def random_bucket_name(prefix="awscli-s3integ", num_random=15): """Generate a random S3 bucket name. :param prefix: A prefix to use in the bucket name. Useful @@ -250,13 +250,13 @@ class BaseCLIDriverTest(unittest.TestCase): def setUp(self): self.environ = { - 'AWS_DATA_PATH': os.environ['AWS_DATA_PATH'], - 'AWS_DEFAULT_REGION': 'us-east-1', - 'AWS_ACCESS_KEY_ID': 'access_key', - 'AWS_SECRET_ACCESS_KEY': 'secret_key', - 'AWS_CONFIG_FILE': '', + "AWS_DATA_PATH": os.environ["AWS_DATA_PATH"], + "AWS_DEFAULT_REGION": "us-east-1", + "AWS_ACCESS_KEY_ID": "access_key", + "AWS_SECRET_ACCESS_KEY": "secret_key", + "AWS_CONFIG_FILE": "", } - self.environ_patch = mock.patch('os.environ', self.environ) + self.environ_patch = mock.patch("os.environ", self.environ) self.environ_patch.start() self.driver = create_clidriver() self.session = self.driver.session @@ -268,7 +268,7 @@ def tearDown(self): class BaseAWSHelpOutputTest(BaseCLIDriverTest): def setUp(self): super().setUp() - self.renderer_patch = mock.patch('awscli.help.get_renderer') + self.renderer_patch = mock.patch("awscli.help.get_renderer") self.renderer_mock = self.renderer_patch.start() self.renderer = CapturedRenderer() self.renderer_mock.return_value = self.renderer @@ -305,7 +305,7 @@ def assert_not_contains(self, contents): def assert_text_order(self, *args, **kwargs): # First we need to find where the SYNOPSIS section starts. 
- starting_from = kwargs.pop('starting_from') + starting_from = kwargs.pop("starting_from") args = list(args) contents = self.renderer.rendered_contents self.assertIn(starting_from, contents) @@ -315,23 +315,23 @@ def assert_text_order(self, *args, **kwargs): for i, index in enumerate(arg_indices[1:], 1): if index == -1: self.fail( - 'The string %r was not found in the contents: %s' + "The string %r was not found in the contents: %s" % (args[index], contents) ) if index < previous: self.fail( - 'The string %r came before %r, but was suppose to come ' - 'after it.\n%s' % (args[i], args[i - 1], contents) + "The string %r came before %r, but was suppose to come " + "after it.\n%s" % (args[i], args[i - 1], contents) ) previous = index class CapturedRenderer: def __init__(self): - self.rendered_contents = '' + self.rendered_contents = "" def render(self, contents): - self.rendered_contents = contents.decode('utf-8') + self.rendered_contents = contents.decode("utf-8") class CapturedOutput: @@ -344,18 +344,18 @@ def __init__(self, stdout, stderr): def capture_output(): stderr = StringIO() stdout = StringIO() - with mock.patch('sys.stderr', stderr): - with mock.patch('sys.stdout', stdout): + with mock.patch("sys.stderr", stderr): + with mock.patch("sys.stdout", stdout): yield CapturedOutput(stdout, stderr) @contextlib.contextmanager -def capture_input(input_bytes=b''): +def capture_input(input_bytes=b""): input_data = BytesIO(input_bytes) mock_object = mock.Mock() mock_object.buffer = input_data - with mock.patch('sys.stdin', mock_object): + with mock.patch("sys.stdin", mock_object): yield input_data @@ -371,22 +371,20 @@ def setUp(self): # os.environ so the patched os.environ has this data and # the CLI works. 
self.environ = { - 'AWS_DATA_PATH': os.environ['AWS_DATA_PATH'], - 'AWS_DEFAULT_REGION': 'us-east-1', - 'AWS_ACCESS_KEY_ID': 'access_key', - 'AWS_SECRET_ACCESS_KEY': 'secret_key', - 'AWS_CONFIG_FILE': '', - 'AWS_SHARED_CREDENTIALS_FILE': '', + "AWS_DATA_PATH": os.environ["AWS_DATA_PATH"], + "AWS_DEFAULT_REGION": "us-east-1", + "AWS_ACCESS_KEY_ID": "access_key", + "AWS_SECRET_ACCESS_KEY": "secret_key", + "AWS_CONFIG_FILE": "", + "AWS_SHARED_CREDENTIALS_FILE": "", } - if os.environ.get('ComSpec'): - self.environ['ComSpec'] = os.environ['ComSpec'] - self.environ_patch = mock.patch('os.environ', self.environ) + if os.environ.get("ComSpec"): + self.environ["ComSpec"] = os.environ["ComSpec"] + self.environ_patch = mock.patch("os.environ", self.environ) self.environ_patch.start() self.http_response = AWSResponse(None, 200, {}, None) self.parsed_response = {} - self.make_request_patch = mock.patch( - 'botocore.endpoint.Endpoint.make_request' - ) + self.make_request_patch = mock.patch("botocore.endpoint.Endpoint.make_request") self.make_request_is_patched = False self.operations_called = [] self.parsed_responses = None @@ -404,7 +402,7 @@ def before_call(self, params, **kwargs): def _store_params(self, params): self.last_request_dict = params - self.last_params = params['body'] + self.last_params = params["body"] def patch_make_request(self): # If you do not stop a previously started patch, @@ -463,10 +461,10 @@ def before_parameter_build(self, params, model, **kwargs): def run_cmd(self, cmd, expected_rc=0): logging.debug("Calling cmd: %s", cmd) self.patch_make_request() - event_emitter = self.driver.session.get_component('event_emitter') - event_emitter.register('before-call', self.before_call) + event_emitter = self.driver.session.get_component("event_emitter") + event_emitter.register("before-call", self.before_call) event_emitter.register_first( - 'before-parameter-build.*.*', self.before_parameter_build + "before-parameter-build.*.*", self.before_parameter_build ) if 
not isinstance(cmd, list): cmdlist = cmd.split() @@ -495,9 +493,7 @@ def run_cmd(self, cmd, expected_rc=0): class BaseAWSPreviewCommandParamsTest(BaseAWSCommandParamsTest): def setUp(self): - self.preview_patch = mock.patch( - 'awscli.customizations.preview.mark_as_preview' - ) + self.preview_patch = mock.patch("awscli.customizations.preview.mark_as_preview") self.preview_patch.start() super().setUp() @@ -509,16 +505,16 @@ def tearDown(self): class BaseCLIWireResponseTest(unittest.TestCase): def setUp(self): self.environ = { - 'AWS_DATA_PATH': os.environ['AWS_DATA_PATH'], - 'AWS_DEFAULT_REGION': 'us-east-1', - 'AWS_ACCESS_KEY_ID': 'access_key', - 'AWS_SECRET_ACCESS_KEY': 'secret_key', - 'AWS_CONFIG_FILE': '', + "AWS_DATA_PATH": os.environ["AWS_DATA_PATH"], + "AWS_DEFAULT_REGION": "us-east-1", + "AWS_ACCESS_KEY_ID": "access_key", + "AWS_SECRET_ACCESS_KEY": "secret_key", + "AWS_CONFIG_FILE": "", } - self.environ_patch = mock.patch('os.environ', self.environ) + self.environ_patch = mock.patch("os.environ", self.environ) self.environ_patch.start() # TODO: fix this patch when we have a better way to stub out responses - self.send_patch = mock.patch('botocore.endpoint.Endpoint._send') + self.send_patch = mock.patch("botocore.endpoint.Endpoint._send") self.send_is_patched = False self.driver = create_clidriver() @@ -528,7 +524,9 @@ def tearDown(self): self.send_patch.stop() self.send_is_patched = False - def patch_send(self, status_code=200, headers={}, content=b''): + def patch_send(self, status_code=200, headers=None, content=b""): + if headers is None: + headers = {} if self.send_is_patched: self.send_patch.stop() self.send_is_patched = False @@ -567,7 +565,7 @@ def remove_all(self): if os.path.exists(self.rootdir): shutil.rmtree(self.rootdir) - def create_file(self, filename, contents, mtime=None, mode='w'): + def create_file(self, filename, contents, mtime=None, mode="w"): """Creates a file in a tmpdir ``filename`` should be a relative path, e.g. 
"foo/bar/baz.txt" @@ -606,7 +604,7 @@ def append_file(self, filename, contents): full_path = os.path.join(self.rootdir, filename) if not os.path.isdir(os.path.dirname(full_path)): os.makedirs(os.path.dirname(full_path)) - with open(full_path, 'a') as f: + with open(full_path, "a") as f: f.write(contents) return full_path @@ -688,18 +686,18 @@ def aws( process. This is needed if you plan to stream data into stdin while collecting memory. """ - if platform.system() == 'Windows': + if platform.system() == "Windows": command = _escape_quotes(command) - if 'AWS_TEST_COMMAND' in os.environ: - aws_command = os.environ['AWS_TEST_COMMAND'] + if "AWS_TEST_COMMAND" in os.environ: + aws_command = os.environ["AWS_TEST_COMMAND"] else: - aws_command = 'python %s' % get_aws_cmd() - full_command = '%s %s' % (aws_command, command) + aws_command = "python %s" % get_aws_cmd() + full_command = "%s %s" % (aws_command, command) stdout_encoding = get_stdout_encoding() INTEG_LOG.debug("Running command: %s", full_command) env = os.environ.copy() - if 'AWS_DEFAULT_REGION' not in env: - env['AWS_DEFAULT_REGION'] = "us-east-1" + if "AWS_DEFAULT_REGION" not in env: + env["AWS_DEFAULT_REGION"] = "us-east-1" if env_vars is not None: env = env_vars if input_file is None: @@ -718,7 +716,7 @@ def aws( if not collect_memory: kwargs = {} if input_data: - kwargs = {'input': input_data} + kwargs = {"input": input_data} stdout, stderr = process.communicate(**kwargs) else: stdout, stderr, memory = _wait_and_collect_mem(process) @@ -731,17 +729,17 @@ def aws( def get_stdout_encoding(): - encoding = getattr(sys.__stdout__, 'encoding', None) + encoding = getattr(sys.__stdout__, "encoding", None) if encoding is None: - encoding = 'utf-8' + encoding = "utf-8" return encoding def _wait_and_collect_mem(process): # We only know how to collect memory on mac/linux. 
- if platform.system() == 'Darwin': + if platform.system() == "Darwin": get_memory = _get_memory_with_ps - elif platform.system() == 'Linux': + elif platform.system() == "Linux": get_memory = _get_memory_with_ps else: raise ValueError( @@ -763,7 +761,7 @@ def _wait_and_collect_mem(process): def _get_memory_with_ps(pid): # It's probably possible to do with proc_pidinfo and ctypes on a Mac, # but we'll do it the easy way with parsing ps output. - command_list = 'ps u -p'.split() + command_list = "ps u -p".split() command_list.append(str(pid)) p = Popen(command_list, stdout=PIPE) stdout = p.communicate()[0] @@ -785,18 +783,18 @@ class BaseS3CLICommand(unittest.TestCase): """ _PUT_HEAD_SHARED_EXTRAS = [ - 'SSECustomerAlgorithm', - 'SSECustomerKey', - 'SSECustomerKeyMD5', - 'RequestPayer', + "SSECustomerAlgorithm", + "SSECustomerKey", + "SSECustomerKeyMD5", + "RequestPayer", ] def setUp(self): self.files = FileCreator() self.session = botocore.session.get_session() self.regions = {} - self.region = 'us-west-2' - self.client = create_nested_client(self.session, 's3', region_name=self.region) + self.region = "us-west-2" + self.client = create_nested_client(self.session, "s3", region_name=self.region) self.extra_setup() def extra_setup(self): @@ -812,18 +810,18 @@ def extra_teardown(self): pass def override_parser(self, **kwargs): - factory = self.session.get_component('response_parser_factory') + factory = self.session.get_component("response_parser_factory") factory.set_parser_defaults(**kwargs) def create_client_for_bucket(self, bucket_name): region = self.regions.get(bucket_name, self.region) - client = create_nested_client(self.session, 's3', region_name=region) + client = create_nested_client(self.session, "s3", region_name=region) return client def assert_key_contents_equal(self, bucket, key, expected_contents): self.wait_until_key_exists(bucket, key) if isinstance(expected_contents, BytesIO): - expected_contents = expected_contents.getvalue().decode('utf-8') + 
expected_contents = expected_contents.getvalue().decode("utf-8") actual_contents = self.get_key_contents(bucket, key) # The contents can be huge so we try to give helpful error messages # without necessarily printing the actual contents. @@ -863,9 +861,9 @@ def create_dir_bucket(self, name=None, location=None): self.wait_bucket_exists(bucket_name) return bucket_name - def put_object(self, bucket_name, key_name, contents='', extra_args=None): + def put_object(self, bucket_name, key_name, contents="", extra_args=None): client = self.create_client_for_bucket(bucket_name) - call_args = {'Bucket': bucket_name, 'Key': key_name, 'Body': contents} + call_args = {"Bucket": bucket_name, "Key": key_name, "Body": contents} if extra_args is not None: call_args.update(extra_args) response = client.put_object(**call_args) @@ -909,11 +907,11 @@ def delete_bucket(self, bucket_name, attempts=5, delay=5): def remove_all_objects(self, bucket_name): client = self.create_client_for_bucket(bucket_name) - paginator = client.get_paginator('list_objects_v2') + paginator = client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket_name) key_names = [] for page in pages: - key_names += [obj['Key'] for obj in page.get('Contents', [])] + key_names += [obj["Key"] for obj in page.get("Contents", [])] for key_name in key_names: self.delete_key(bucket_name, key_name) @@ -925,17 +923,15 @@ def get_key_contents(self, bucket_name, key_name): self.wait_until_key_exists(bucket_name, key_name) client = self.create_client_for_bucket(bucket_name) response = client.get_object(Bucket=bucket_name, Key=key_name) - return response['Body'].read().decode('utf-8') + return response["Body"].read().decode("utf-8") def wait_bucket_exists(self, bucket_name, min_successes=3): client = self.create_client_for_bucket(bucket_name) - waiter = client.get_waiter('bucket_exists') + waiter = client.get_waiter("bucket_exists") consistency_waiter = ConsistencyWaiter( min_successes=min_successes, 
delay_initial_poll=True ) - consistency_waiter.wait( - lambda: waiter.wait(Bucket=bucket_name) is None - ) + consistency_waiter.wait(lambda: waiter.wait(Bucket=bucket_name) is None) def bucket_not_exists(self, bucket_name): client = self.create_client_for_bucket(bucket_name) @@ -943,7 +939,7 @@ def bucket_not_exists(self, bucket_name): client.head_bucket(Bucket=bucket_name) return True except ClientError as error: - if error.response.get('Code') == '404': + if error.response.get("Code") == "404": return False raise @@ -967,11 +963,11 @@ def key_not_exists(self, bucket_name, key_name, min_successes=3): def list_buckets(self): response = self.client.list_buckets() - return response['Buckets'] + return response["Buckets"] def content_type_for_key(self, bucket_name, key_name): parsed = self.head_object(bucket_name, key_name) - return parsed['ContentType'] + return parsed["ContentType"] def head_object(self, bucket_name, key_name): client = self.create_client_for_bucket(bucket_name) @@ -1002,10 +998,10 @@ def _wait_for_key( ): client = self.create_client_for_bucket(bucket_name) if exists: - waiter = client.get_waiter('object_exists') + waiter = client.get_waiter("object_exists") else: - waiter = client.get_waiter('object_not_exists') - params = {'Bucket': bucket_name, 'Key': key_name} + waiter = client.get_waiter("object_not_exists") + params = {"Bucket": bucket_name, "Key": key_name} if extra_params is not None: params.update(extra_params) for _ in range(min_successes): @@ -1108,7 +1104,7 @@ def wait(self, check, *args, **kwargs): def _fail_message(self, attempts, successes): format_args = (attempts, successes) - return 'Failed after %s attempts, only had %s successes' % format_args + return "Failed after %s attempts, only had %s successes" % format_args @contextlib.contextmanager diff --git a/tests/unit/customizations/emr/__init__.py b/tests/unit/customizations/emr/__init__.py index 5ee3ffe4c59d..480ff0f65e08 100644 --- a/tests/unit/customizations/emr/__init__.py +++ 
b/tests/unit/customizations/emr/__init__.py @@ -17,6 +17,7 @@ from awscli.testutils import BaseAWSCommandParamsTest from awscli.testutils import mock + class EMRBaseAWSCommandParamsTest(BaseAWSCommandParamsTest): def setUp(self): @@ -28,19 +29,21 @@ def setUp(self): # Do not write or update the config (~/.aws/config) file self.patcher_update_config = mock.patch( - 'awscli.customizations.emr.configutils.ConfigWriter.update_config') + "awscli.customizations.emr.configutils.ConfigWriter.update_config" + ) self.mock_update_config = self.patcher_update_config.start() def set_configs(self, configs): self.driver.session.get_scoped_config = self.get_scoped_config_mock - self.get_scoped_config_mock.return_value = {'emr': configs} + self.get_scoped_config_mock.return_value = {"emr": configs} def tearDown(self): super(EMRBaseAWSCommandParamsTest, self).tearDown() self.patcher_update_config.stop() - def assert_error_msg(self, cmd, - exception_class_name, error_msg_kwargs={}): + def assert_error_msg(self, cmd, exception_class_name, error_msg_kwargs=None): + if error_msg_kwargs is None: + error_msg_kwargs = {} exception_class = getattr(exceptions, exception_class_name) error_msg = "\n%s\n" % exception_class.fmt.format(**error_msg_kwargs) result = self.run_cmd(cmd, 255)