From 6eb5ee144d5e730616f13d3d14129f0205282d90 Mon Sep 17 00:00:00 2001 From: Jonas Gunz Date: Tue, 7 Nov 2023 00:02:48 +0100 Subject: aggregator refactor --- aggregator/dwd_icon.py | 29 +++++++++++++++-------------- run.py | 28 ++++++++++++++++++---------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/aggregator/dwd_icon.py b/aggregator/dwd_icon.py index 60c8433..b8a4470 100755 --- a/aggregator/dwd_icon.py +++ b/aggregator/dwd_icon.py @@ -42,15 +42,16 @@ def unpack_bz2(dest): if res.returncode != 0: print(f'There was an error unpacking {dest}:', res.stderr) -def download_dwd_gribs(config, date, run, target): - model = config['model'] - model_long = config['model_long'] - - misc.create_output_dir(config['output']) +def download_dwd_gribs( + date, run, target, output, model, steps, model_long, + pressure_level_parameters, parameter_caps_in_filename, + single_level_parameters +): + misc.create_output_dir(output) to_download = [] - for step in config['steps']: + for step in steps: step_str = f'{step:03d}' for parameter in config['pressure_level_parameters']: @@ -62,12 +63,12 @@ def download_dwd_gribs(config, date, run, target): to_download.append((URL, os.path.join(config['output'], filename))) - for parameter in config['single_level_parameters']: - parameter2 = parameter.upper() if config['parameter_caps_in_filename'] else parameter + for parameter in single_level_parameters: + parameter2 = parameter.upper() if parameter_caps_in_filename else parameter filename = f'{model_long}_regular-lat-lon_single-level_{date}{run}_{step_str}_{parameter2}.grib2.bz2' URL = f'{BASE}/{model}/grib/{run}/{parameter}/{filename}' - to_download.append((URL, os.path.join(config['output'], filename))) + to_download.append((URL, os.path.join(output, filename))) for _ in ThreadPool(cpu_count()).imap_unordered(download_url, to_download): @@ -88,18 +89,18 @@ def download_dwd_gribs(config, date, run, target): if res.returncode != 0: print('rm failed with: ', res.stderr) -def load_data(name, config): +def load_data(name, output, **kwargs): run, date = get_current_run() - target = os.path.join(config['output'], f'{name}_{date}_{run}.grib2') + target = os.path.join(output, f'{name}_{date}_{run}.grib2') if not os.path.exists(target): - download_dwd_gribs(config, date, run, target) + download_dwd_gribs(date, run, target, output, **kwargs) else: print(f'{target} alreasy exists. Using the cached version.') return xr.load_dataset(target, engine='cfgrib') -config = { +debug_config = { 'output':'dwd_icon-eu', 'model':'icon-eu', 'model_long':'icon-eu_europe', @@ -122,5 +123,5 @@ config = { } if __name__ == '__main__': - load_data('test_icon_eu', config) + load_data('test_icon_eu', **debug_config) diff --git a/run.py b/run.py index d0a004a..2220db3 100755 --- a/run.py +++ b/run.py @@ -9,6 +9,20 @@ from matplotlib.colors import LinearSegmentedColormap from metpy.units import units +def create_aggregators(cfg): + ret = {} + for aggregator in cfg: + aggconf = cfg[aggregator] + classname = aggconf['module'] + del aggconf['module'] + + ret[aggregator] = {} + ret[aggregator]['module'] = __import__(classname, fromlist=[None]) + ret[aggregator]['config'] = aggconf + ret[aggregator]['config']['name'] = aggregator + + return ret + # Define custom gpm and gpdm units. The default gpm in metpy is aliased to meter. # We need the correct definition units.define('_gpm = 9.80665 * J/kg') @@ -43,6 +57,8 @@ conf = None with open(FILE, 'r') as f: conf = yaml.safe_load(f) +aggregators = create_aggregators(conf['aggregator']) + index = [] for plotter in conf['plotter']: @@ -50,17 +66,9 @@ for plotter in conf['plotter']: del plotter['module'] if 'aggregator' in plotter: - aggname = plotter['aggregator'] + agg = aggregators[plotter['aggregator']] + plotter['data'] = agg['module'].load_data(**agg['config']) del plotter['aggregator'] - aggconf = conf['aggregator'][aggname] - classname = aggconf['module'] - # We should prbly not delete it like in the plotter, since it is not a deepcopy - # and may be used again. - - agg = __import__(classname, fromlist=[None]) - - # TODO: figure out a way to use **aggconf instead. - plotter['data'] = agg.load_data(aggname, aggconf) mod = __import__(modname, fromlist=[None]) index.extend(mod.run(**plotter)) -- cgit v1.2.3