diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index b8c7f4f7a871..e9882602d100 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -490,7 +490,7 @@ def json_config_schema(self, type): return dict( type='object', additionalProperties=False, - **self._transforms[type]['config_schema']) + **self._transforms[type].get('config_schema', {})) def description(self, type): return self._transforms[type].get('description') diff --git a/sdks/python/apache_beam/yaml/yaml_testing.py b/sdks/python/apache_beam/yaml/yaml_testing.py index e7fbc1d43b6f..ad31afa927e0 100644 --- a/sdks/python/apache_beam/yaml/yaml_testing.py +++ b/sdks/python/apache_beam/yaml/yaml_testing.py @@ -73,12 +73,15 @@ def __str__(self): def run_test(pipeline_spec, test_spec, options=None, fix_failures=False): if isinstance(pipeline_spec, str): - pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader) + pipeline_spec_dict = yaml.load( + pipeline_spec, Loader=yaml_utils.SafeLineLoader) + else: + pipeline_spec_dict = pipeline_spec - pipeline_spec = _preprocess_for_testing(pipeline_spec) + processed_pipeline_spec = _preprocess_for_testing(pipeline_spec_dict) transform_spec, recording_ids = inject_test_tranforms( - pipeline_spec, + processed_pipeline_spec, test_spec, fix_failures) @@ -96,12 +99,18 @@ def run_test(pipeline_spec, test_spec, options=None, fix_failures=False): options = beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle', **yaml_transform.SafeLineLoader.strip_metadata( - pipeline_spec.get('options', {}))) + pipeline_spec_dict.get('options', {}))) + + providers = yaml_provider.merge_providers( + yaml_provider.parse_providers( + '', pipeline_spec_dict.get('providers', [])), + { + 'AssertEqualAndRecord': yaml_provider.as_provider_list( + 'AssertEqualAndRecord', AssertEqualAndRecord) + }) with beam.Pipeline(options=options) as p: - _ = p | yaml_transform.YamlTransform( - transform_spec, - providers={'AssertEqualAndRecord': AssertEqualAndRecord}) + _ = p | yaml_transform.YamlTransform(transform_spec, providers=providers) if fix_failures: fixes = {} diff --git a/sdks/python/apache_beam/yaml/yaml_testing_test.py b/sdks/python/apache_beam/yaml/yaml_testing_test.py index 9fcdafd2ab34..9bb0e64b6db5 100644 --- a/sdks/python/apache_beam/yaml/yaml_testing_test.py +++ b/sdks/python/apache_beam/yaml/yaml_testing_test.py @@ -322,6 +322,40 @@ def test_join_transform_serialization(self): }] }) + def test_toplevel_providers(self): + yaml_testing.run_test( + ''' + pipeline: + type: chain + transforms: + - type: Create + config: + elements: [1, 2, 3] + - type: MyDoubler + providers: + - type: yaml + transforms: + MyDoubler: + body: + type: MapToFields + config: + language: python + fields: + doubled: element * 2 + ''', + { + 'expected_outputs': [{ + 'name': 'MyDoubler', + 'elements': [{ + 'doubled': 2 + }, { + 'doubled': 4 + }, { + 'doubled': 6 + }] + }] + }) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)