diff options
-rw-r--r-- | .gitignore | 138 | ||||
-rw-r--r-- | action.py | 58 | ||||
-rw-r--r-- | actions.yml | 13 | ||||
-rw-r--r-- | command.py | 15 | ||||
-rw-r--r-- | endpoint.py | 105 | ||||
-rw-r--r-- | endpoints.yml | 15 | ||||
-rwxr-xr-x | main.py | 76 | ||||
-rw-r--r-- | misc.py | 6 | ||||
-rw-r--r-- | state.py | 91 | ||||
-rw-r--r-- | transport.py | 80 | ||||
-rw-r--r-- | trigger.py | 124 | ||||
-rw-r--r-- | triggers.yml | 10 |
12 files changed, 731 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a81c8ee --- /dev/null +++ b/.gitignore @@ -0,0 +1,138 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ diff --git a/action.py b/action.py new file mode 100644 index 0000000..8bb9bf2 --- /dev/null +++ b/action.py @@ -0,0 +1,58 @@ +from typing import Dict +import logging + +import endpoint +import trigger + +class Action: + # TODO: Cooldown, wait fot state change, repeat, etc? + def __init__(self, name: str, config: dict, endpoints: Dict[str, endpoint.Endpoint], triggers: Dict[str, trigger.Trigger]): + self._name = name + self._trigger_cfg = config['trigger'] + self._then_cfg = config['then'] + + self._endpoints = endpoints + self._triggers = triggers + + self._configured_trigger_keys = [] + + self._setup_triggers() + + def _setup_triggers(self): + for trg_list_item in self._trigger_cfg: + if len(trg_list_item.keys()) != 1: + logging.error(f'Action "{self._name}" encountered error while adding trigger "{trg_list_item}"') + raise Exception + + trg_key = list(trg_list_item.keys())[0] + trg_config = trg_list_item[trg_key] + + if not trg_key in self._triggers: + logging.error(f'Action "{self._name}": Trigger "{trg_key}" is not configured.') + raise Exception + + self._configured_trigger_keys.append(trg_key) + self._triggers[trg_key].addInstance(self._name, **trg_config) + logging.debug(f'Action "{self._name}" was registered with "{trg_key}"') + + + def execute(self): + if not all([self._triggers[b].evaluate(self._name) for b in self._configured_trigger_keys]): + logging.debug(f'Action "{self._name}" will not execute. Conditions not met.') + return + + logging.info(f'Executing Action "{self._name}". Conditions are met.') + + for then_item in self._then_cfg: + if len(then_item.keys()) != 1: + logging.error(f'Action "{self._name}" encountered error while executing command "{then_item}"') + raise Exception + + cmd_key = list(then_item.keys())[0] + cmd_config = then_item[cmd_key] + + logging.info(f'Executing command "{cmd_key}"') + endpoint, command = cmd_key.split('.', 1) + self._endpoints[endpoint].executeCommand(command, **cmd_config) + + diff --git a/actions.yml b/actions.yml new file mode 100644 index 0000000..c91759c --- /dev/null +++ b/actions.yml @@ -0,0 +1,13 @@ +--- +send-hello: + trigger: + - conditional: + interval: 30 + when: + - host1.user.jonas > 0 + - True + then: + - host1.notify: + msg: Hello + - host1.notify: + msg: World! diff --git a/command.py b/command.py new file mode 100644 index 0000000..db0d261 --- /dev/null +++ b/command.py @@ -0,0 +1,15 @@ +import transport + +class Command: + def __init__(self, transport: transport.Transport): + raise NotImplemented + + def execute(self, **kwargs): + raise NotImplemented + +class NotifyCommand(Command): + def __init__(self, transport: transport.SshTransport): + self._transport = transport + + def execute(self, msg: str, **kwargs): + self._transport.execHandleStderror(f'notify-send "{msg}"') diff --git a/endpoint.py b/endpoint.py new file mode 100644 index 0000000..7458326 --- /dev/null +++ b/endpoint.py @@ -0,0 +1,105 @@ +import logging +import transport + +def import_class(cl): + d = cl.rfind(".") + classname = cl[d+1:len(cl)] + m = __import__(cl[0:d], globals(), locals(), [classname]) + return getattr(m, classname) + + # Master object +class Endpoint: + def __init__(self, name, config): + transports = {} + commands = {} + states = {} + + # sweet mother of jesus, you are ugly + for tp_key in config['transports']: + tp_cfg = config['transports'][tp_key] + logging.debug(f'loading transport "{tp_key}"') + + # TODO Handle failure + tp_class = import_class(tp_cfg['class']) + del tp_cfg['class'] + + transports[tp_key] = tp_class(**tp_cfg) + + for cmd_key in config['commands']: + cmd_cfg = config['commands'][cmd_key] + logging.debug(f'loading command "{cmd_key}"') + + # TODO Handle failure + cmd_class = import_class(cmd_cfg['class']) + del cmd_cfg['class'] + + if cmd_cfg['transport'] not in transports: + # TODO should we be lenient with errors? + logging.error(f'transport "{cmd_cfg["transport"]}" for command "{cmd_key}" was not found.') + continue + + tp = transports[cmd_cfg['transport']] + del cmd_cfg['transport'] + + commands[cmd_key] = cmd_class(tp, **cmd_cfg) + + # you look familiar + for stt_key in config['states']: + stt_cfg = config['states'][stt_key] + logging.debug(f'loading state "{stt_key}"') + + # TODO Handle failure + stt_class = import_class(stt_cfg['class']) + del stt_cfg['class'] + + if stt_cfg['transport'] not in transports: + # TODO should we be lenient with errors? + logging.error(f'transport "{stt_cfg["transport"]}" for command "{stt_key}" was not found.') + continue + + tp = transports[stt_cfg['transport']] + del stt_cfg['transport'] + + states[stt_key] = stt_class(tp, **stt_cfg) + + # TODO How does the init step look like? Do it here? + # transports prbly need to be connected here + + self._name = name + self._transports = transports + self._commands = commands + self._states = states + + def connectTransport(self): + for k in self._transports: + if self._transports[k].CONNECTION == transport.HOLD: + self._transports[k].connect() + elif self._transports[k].CONNECTION == transport.THROWAWAY: + self._transports[k].check() + else: + logging.error(f'"{self._transports[k].CONNECTION}" is an unknown connection type in transport "{k}"') + + # forces a recollect of all states. should not be needed, states should + # handle that themselves via TTL + # we shouldn't need it + #def collectState(self): + # # TODO we need a interface here + # for k in self._states: + # self._states[k].collect() + + # Format: <state>.<key> + def getState(self, state_key: str): + state, key = state_key.split('.', 1) + + if state not in self._states: + logging.error(f'State "{state}" was not found for "{self._name}"') + return None + + return self._states[state].get(key) + + + def executeCommand(self, cmd: str, **kwargs): + if cmd not in self._commands: + raise Exception(f'Command "{cmd}" is not defined for "{self._name}"') + + self._commands[cmd].execute(**kwargs) diff --git a/endpoints.yml b/endpoints.yml new file mode 100644 index 0000000..fad2160 --- /dev/null +++ b/endpoints.yml @@ -0,0 +1,15 @@ +host1: + transports: + ssh: + class: transport.SshTransport + hostname: 'localhost' + username: 'jonas' + commands: + notify: + class: command.NotifyCommand + transport: ssh + states: + user: + class: state.UserSessionState + transport: ssh + ttl: 30 @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 + +import yaml +import json +import logging +import time + +import transport +import state +import command +import endpoint +import trigger +import misc +import action + +logging.basicConfig(level=logging.DEBUG) +logging.getLogger('paramiko').setLevel(logging.WARNING) + +# Use a TypeDict here +with open('endpoints.yml', 'r') as f: + endpoint_config = yaml.safe_load(f) + +with open('triggers.yml', 'r') as f: + trigger_config = yaml.safe_load(f) + +with open('actions.yml', 'r') as f: + action_config = yaml.safe_load(f) + +endpoints = {} +for ep_key in endpoint_config: + endpoints[ep_key] = endpoint.Endpoint(ep_key, endpoint_config[ep_key]) + +triggers = {} +for trg_key in trigger_config: + cls = misc.import_class(trigger_config[trg_key]['class']) + del trigger_config[trg_key]['class'] + + if cls.NEEDS_CONTEXT: + triggers[trg_key] = cls(endpoints, **trigger_config[trg_key]) + else: + triggers[trg_key] = cls(**trigger_config[trg_key]) + +actions = {} +for act_key in action_config: + actions[act_key] = action.Action(act_key, action_config[act_key], endpoints, triggers) + + +# TODO should we do that in Endpoint.__init__()? +for k in endpoints: + endpoints[k].connectTransport() + +for act_key in action_config: + actions[act_key].execute() + +for act_key in action_config: + actions[act_key].execute() + + +#print(endpoints['host1'].getState('user.jonas')) +#print(endpoints['host1'].getState('user.jonas')) +# +#time.sleep(31) +#print(endpoints['host1'].getState('user.jonas')) + +#endpoints['host1'].executeCommand('notify', msg='moinsen') + +#tr = transport.SshTransport('localhost', username='jonas') +#tr.connect() +# +#noti = command.NotifyCommand(tr) +#noti.execute('OwO') +# +#sta = state.UserState(tr) +#sta.collect() +# +#tr.disconnect() @@ -0,0 +1,6 @@ + +def import_class(cl): + d = cl.rfind(".") + classname = cl[d+1:len(cl)] + m = __import__(cl[0:d], globals(), locals(), [classname]) + return getattr(m, classname) diff --git a/state.py b/state.py new file mode 100644 index 0000000..b17db99 --- /dev/null +++ b/state.py @@ -0,0 +1,91 @@ +import time +import logging + +import transport + +''' +Implementations of State: + +MUST implement: + _collect(self) + +CAN implement: + _get(self, key: str) + +SHOULDNT implement: + get(self, key) + collect(self) + +Data is stored in self._data as a dictionary. +By default, _get(key) retrieves the returns self._data[key]. +This behaviour can be overridden by implementing a own _get(). + +If using the default _get(), _collect() has to store data in +the self._data dictionary. If an own _get() is implemented, +this does not need to be the case. +''' +class State: + def __init__(self, transport: transport.Transport, ttl: int = 30): + self._transport = transport + self._ttl = ttl + + self._data = {} + self._last_collected = 0 + + def _collect(self): + raise NotImplemented + + def _get(self, key: str): + if key not in self._data: + logging.error(f'Data key {key} was not found.') + return None + + return self._data[key] + + def _shouldCollect(self): + return time.time() - self._last_collected > self._ttl + + def get(self, key: str): + if self._shouldCollect(): + logging.debug(f'Cached value for "{key}" is too old. refreshing.') + self.collect() + else: + logging.debug(f'Using cached value for "{key}".') + + + return self._get(key) + + # Force datacollection. not really needed + def collect(self): + self._collect() + self._last_collected = time.time() + +class UserSessionState(State): + def __init__(self, transport: transport.SshTransport, ttl: int = 30): + super().__init__(transport, ttl) + + # this is not needed. it's here to shut up pylint + self._transport = transport + + def _get(self, key: str): + if key not in self._data: + return 0 + + return self._data[key] + + def _collect(self): + data = self._transport.execHandleStderror('who').decode('utf-8') + # TODO error handling + lines = data.split('\n') + + self._data = {} + + for l in lines: + name, _ = l.split(' ', 1) + + logging.debug(f'Found user session {name}') + + if name not in self._data: + self._data[name] = 0 + + self._data[name] += 1 diff --git a/transport.py b/transport.py new file mode 100644 index 0000000..91c5029 --- /dev/null +++ b/transport.py @@ -0,0 +1,80 @@ +import paramiko + +HOLD = 1 +THROWAWAY = 2 + +# Abstract classes to implement +class Transport: + NAME = 'BASE' + CONNECTION = HOLD + #CONNECTION = THROWAWAY + + def __init__(self): + self._connected = False + raise NotImplemented + + # Connects to the transport, if CONNECTION == HOLD + def connect(self): + raise NotImplemented + + # disconnects to the transport, if CONNECTION == HOLD + def disconnect(self): + raise NotImplemented + + # validate that the transport works, if CONNECTION == THROWAWAY + def check(self): + raise NotImplemented + + def isConnected(self) -> bool: + return self._connected + +class SshTransport(Transport): + NAME='SSH' + CONNECTION=HOLD + + def __init__(self, hostname: str, port=22, username='root', password = None, id_file = None): + self._hostname = hostname + self._port = port + self._username = username + self._password = password + self._id_file = id_file + + self._connected = False + self._client = None + + def connect(self): + self._client = paramiko.SSHClient() + + # TODO known hosts + self._client.set_missing_host_key_policy(paramiko.client.AutoAddPolicy) + self._client.connect(self._hostname, port=self._port, username=self._username, password=self._password, key_filename=None, allow_agent=True) + + self._connected = True + + # return(str: stdout, str: stderr, int: retcode) + def exec(self, command: str): + if not self._connected: + raise Exception('Not connected') + + output = self._client.exec_command(command, timeout=5) + + retcode = output[1].channel.recv_exit_status() + return (output[1].read().strip(), output[2].read().strip(), retcode) + + def execHandleStderror(self, command: str): + out = self.exec(command) + + if out[2] != 0: + raise Exception(f'Command returned error {out[2]}: {out[1]}') + + return out[0] + + def readFile(self, path: str): + return self.execHandleStderror(f'cat "{path}"') + + def disconnect(self): + if self._connected: + self._client.close() + + self._connected = False + self._client = None diff --git a/trigger.py b/trigger.py new file mode 100644 index 0000000..7d2fdfd --- /dev/null +++ b/trigger.py @@ -0,0 +1,124 @@ +from typing import Dict +from pyparsing import alphanums, alphas, printables, pyparsing_common, pyparsing_common, Word, infix_notation, CaselessKeyword, opAssoc, ParserElement +import logging +import time + +import endpoint +import misc + +''' +Implementations of Trigger: + +MUST implement: + _evaluate(self, action: str) -> bool + evaluates the instace for action given by 'action'. + Provided configuration is stored in self._instances[action]['args'] + +CAN implement: + _addInstance(self, action: str) + Called afer 'action' was added. + +SHOULDNT implement: + evaluate(self, action: str) -> bool + Only calls _evaluate(), if no check was performed in configured interval, + otherwise returns cached result + addInstance(self, action:str, interval=30, **kwargs) +''' +class Trigger: + NEEDS_CONTEXT = False + + @staticmethod + def create(classname: str, **kwargs): + return misc.import_class(classname)(**kwargs) + + def __init__(self): + self._instances = {} + + def _addInstance(self, action: str): + pass + + def addInstance(self, action: str, interval: int=30, **kwargs): + self._instances[action] = {'lastupdate':0,'interval':interval,'last':False,'args':kwargs} + self._addInstance(action) + logging.debug(f'Trigger: Action "{action}" registered.') + + def _evaluate(self, action: str) -> bool: + raise NotImplemented + + def _shouldReevaluate(self, action: str) -> bool: + return time.time() - self._instances[action]['lastupdate'] > self._instances[action]['interval'] + + def evaluate(self, action: str) -> bool: + if action not in self._instances: + logging.error(f'Trigger: Action "{action}" was not found. Evaluating to False.') + return False + + if self._shouldReevaluate(action): + logging.debug(f'Re-evaluating trigger condition for action "{action}"') + result = self._evaluate(action) + + self._instances[action]['last'] = result + self._instances[action]['lastupdate'] = time.time() + return result + + return self._instances[action]['last'] + +''' +```yaml +conditional: + class: trigger.Conditional +--- +- conditional: + interval: 30 + when: + - host1.user.bob > 0 +``` +''' +class ConditionalTrigger(Trigger): + NEEDS_CONTEXT = True + + def __init__(self, endpoints: Dict[str, endpoint.Endpoint]): + super().__init__() + + self._endpoints = endpoints + self._setup_parser() + + def _setup_parser(self): + ParserElement.enable_packrat() + + boolean = CaselessKeyword('True').setParseAction(lambda x: True) | CaselessKeyword('False').setParseAction(lambda x: False) + integer = pyparsing_common.integer + variable = Word(alphanums + '.').setParseAction(self._parseVariable) + operand = boolean | integer | variable + + self._parser = infix_notation( + operand, + [ + ('not', 1, opAssoc.RIGHT, lambda a: not a[0][1]), + ('and', 2, opAssoc.LEFT, lambda a: a[0][0] and a[0][2]), + ('or', 2, opAssoc.LEFT, lambda a: a[0][0] or a[0][2]), + ('==', 2, opAssoc.LEFT, lambda a: a[0][0] == a[0][2]), + ('>', 2, opAssoc.LEFT, lambda a: a[0][0] > a[0][2]), + ('>=', 2, opAssoc.LEFT, lambda a: a[0][0] >= a[0][2]), + ('<', 2, opAssoc.LEFT, lambda a: a[0][0] < a[0][2]), + ('<=', 2, opAssoc.LEFT, lambda a: a[0][0] <= a[0][2]), + ('+', 2, opAssoc.LEFT, lambda a: a[0][0] + a[0][2]), + ('-', 2, opAssoc.LEFT, lambda a: a[0][0] - a[0][2]), + ('*', 2, opAssoc.LEFT, lambda a: a[0][0] * a[0][2]), + ('/', 2, opAssoc.LEFT, lambda a: a[0][0] / a[0][2]), + ] + ) + + def _parseVariable(self, var): + logging.debug(f'Looking up variable "{var[0]}"') + endpoint, key = var[0].split('.',1) + + if not endpoint in self._endpoints: + logging.error(f'Parser: Endpoint "{endpoint}" not found') + return None + + return self._endpoints[endpoint].getState(key) + + def _evaluate(self, action: str) -> bool: + return all(self._parser.parse_string(str(s)) for s in self._instances[action]['args']['when']) + diff --git a/triggers.yml b/triggers.yml new file mode 100644 index 0000000..9595f7e --- /dev/null +++ b/triggers.yml @@ -0,0 +1,10 @@ +conditional: + class: trigger.ConditionalTrigger + +#mqtt: +# class: trigger.Mqtt +# server: asdf +# user: OwO +# +#timer: +# class: trigger.Timer |