aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jonas Gunz <himself@jonasgunz.de> 2023-02-07 01:26:41 +0100
committerGravatar Jonas Gunz <himself@jonasgunz.de> 2023-02-07 01:26:41 +0100
commitc2109e5561299b2a120d1a669d58f6147ca40fb1 (patch)
tree1d8193f8765498bfbd209a4bfb228559cd994294
downloadautomato-c2109e5561299b2a120d1a669d58f6147ca40fb1.tar.gz
first commit
-rw-r--r--.gitignore138
-rw-r--r--action.py58
-rw-r--r--actions.yml13
-rw-r--r--command.py15
-rw-r--r--endpoint.py105
-rw-r--r--endpoints.yml15
-rwxr-xr-xmain.py76
-rw-r--r--misc.py6
-rw-r--r--state.py91
-rw-r--r--transport.py80
-rw-r--r--trigger.py124
-rw-r--r--triggers.yml10
12 files changed, 731 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a81c8ee
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,138 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
diff --git a/action.py b/action.py
new file mode 100644
index 0000000..8bb9bf2
--- /dev/null
+++ b/action.py
@@ -0,0 +1,58 @@
+from typing import Dict
+import logging
+
+import endpoint
+import trigger
+
+class Action:
+ # TODO: Cooldown, wait fot state change, repeat, etc?
+ def __init__(self, name: str, config: dict, endpoints: Dict[str, endpoint.Endpoint], triggers: Dict[str, trigger.Trigger]):
+ self._name = name
+ self._trigger_cfg = config['trigger']
+ self._then_cfg = config['then']
+
+ self._endpoints = endpoints
+ self._triggers = triggers
+
+ self._configured_trigger_keys = []
+
+ self._setup_triggers()
+
+ def _setup_triggers(self):
+ for trg_list_item in self._trigger_cfg:
+ if len(trg_list_item.keys()) != 1:
+ logging.error(f'Action "{self._name}" encountered error while adding trigger "{trg_list_item}"')
+ raise Exception
+
+ trg_key = list(trg_list_item.keys())[0]
+ trg_config = trg_list_item[trg_key]
+
+ if not trg_key in self._triggers:
+ logging.error(f'Action "{self._name}": Trigger "{trg_key}" is not configured.')
+ raise Exception
+
+ self._configured_trigger_keys.append(trg_key)
+ self._triggers[trg_key].addInstance(self._name, **trg_config)
+ logging.debug(f'Action "{self._name}" was registered with "{trg_key}"')
+
+
+ def execute(self):
+ if not all([self._triggers[b].evaluate(self._name) for b in self._configured_trigger_keys]):
+ logging.debug(f'Action "{self._name}" will not execute. Conditions not met.')
+ return
+
+ logging.info(f'Executing Action "{self._name}". Conditions are met.')
+
+ for then_item in self._then_cfg:
+ if len(then_item.keys()) != 1:
+ logging.error(f'Action "{self._name}" encountered error while executing command "{then_item}"')
+ raise Exception
+
+ cmd_key = list(then_item.keys())[0]
+ cmd_config = then_item[cmd_key]
+
+ logging.info(f'Executing command "{cmd_key}"')
+ endpoint, command = cmd_key.split('.', 1)
+ self._endpoints[endpoint].executeCommand(command, **cmd_config)
+
+
diff --git a/actions.yml b/actions.yml
new file mode 100644
index 0000000..c91759c
--- /dev/null
+++ b/actions.yml
@@ -0,0 +1,13 @@
+---
+send-hello:
+ trigger:
+ - conditional:
+ interval: 30
+ when:
+ - host1.user.jonas > 0
+ - True
+ then:
+ - host1.notify:
+ msg: Hello
+ - host1.notify:
+ msg: World!
diff --git a/command.py b/command.py
new file mode 100644
index 0000000..db0d261
--- /dev/null
+++ b/command.py
@@ -0,0 +1,15 @@
+import transport
+
+class Command:
+ def __init__(self, transport: transport.Transport):
+ raise NotImplemented
+
+ def execute(self, **kwargs):
+ raise NotImplemented
+
+class NotifyCommand(Command):
+ def __init__(self, transport: transport.SshTransport):
+ self._transport = transport
+
+ def execute(self, msg: str, **kwargs):
+ self._transport.execHandleStderror(f'notify-send "{msg}"')
diff --git a/endpoint.py b/endpoint.py
new file mode 100644
index 0000000..7458326
--- /dev/null
+++ b/endpoint.py
@@ -0,0 +1,105 @@
+import logging
+import transport
+
+def import_class(cl):
+ d = cl.rfind(".")
+ classname = cl[d+1:len(cl)]
+ m = __import__(cl[0:d], globals(), locals(), [classname])
+ return getattr(m, classname)
+
+ # Master object
+class Endpoint:
+ def __init__(self, name, config):
+ transports = {}
+ commands = {}
+ states = {}
+
+ # sweet mother of jesus, you are ugly
+ for tp_key in config['transports']:
+ tp_cfg = config['transports'][tp_key]
+ logging.debug(f'loading transport "{tp_key}"')
+
+ # TODO Handle failure
+ tp_class = import_class(tp_cfg['class'])
+ del tp_cfg['class']
+
+ transports[tp_key] = tp_class(**tp_cfg)
+
+ for cmd_key in config['commands']:
+ cmd_cfg = config['commands'][cmd_key]
+ logging.debug(f'loading command "{cmd_key}"')
+
+ # TODO Handle failure
+ cmd_class = import_class(cmd_cfg['class'])
+ del cmd_cfg['class']
+
+ if cmd_cfg['transport'] not in transports:
+ # TODO should we be lenient with errors?
+ logging.error(f'transport "{cmd_cfg["transport"]}" for command "{cmd_key}" was not found.')
+ continue
+
+ tp = transports[cmd_cfg['transport']]
+ del cmd_cfg['transport']
+
+ commands[cmd_key] = cmd_class(tp, **cmd_cfg)
+
+ # you look familiar
+ for stt_key in config['states']:
+ stt_cfg = config['states'][stt_key]
+ logging.debug(f'loading state "{stt_key}"')
+
+ # TODO Handle failure
+ stt_class = import_class(stt_cfg['class'])
+ del stt_cfg['class']
+
+ if stt_cfg['transport'] not in transports:
+ # TODO should we be lenient with errors?
+ logging.error(f'transport "{stt_cfg["transport"]}" for command "{stt_key}" was not found.')
+ continue
+
+ tp = transports[stt_cfg['transport']]
+ del stt_cfg['transport']
+
+ states[stt_key] = stt_class(tp, **stt_cfg)
+
+ # TODO How does the init step look like? Do it here?
+ # transports prbly need to be connected here
+
+ self._name = name
+ self._transports = transports
+ self._commands = commands
+ self._states = states
+
+ def connectTransport(self):
+ for k in self._transports:
+ if self._transports[k].CONNECTION == transport.HOLD:
+ self._transports[k].connect()
+ elif self._transports[k].CONNECTION == transport.THROWAWAY:
+ self._transports[k].check()
+ else:
+ logging.error(f'"{self._transports[k].CONNECTION}" is an unknown connection type in transport "{k}"')
+
+ # forces a recollect of all states. should not be needed, states should
+ # handle that themselves via TTL
+ # we shouldn't need it
+ #def collectState(self):
+ # # TODO we need a interface here
+ # for k in self._states:
+ # self._states[k].collect()
+
+ # Format: <state>.<key>
+ def getState(self, state_key: str):
+ state, key = state_key.split('.', 1)
+
+ if state not in self._states:
+ logging.error(f'State "{state}" was not found for "{self._name}"')
+ return None
+
+ return self._states[state].get(key)
+
+
+ def executeCommand(self, cmd: str, **kwargs):
+ if cmd not in self._commands:
+ raise Exception(f'Command "{cmd}" is not defined for "{self._name}"')
+
+ self._commands[cmd].execute(**kwargs)
diff --git a/endpoints.yml b/endpoints.yml
new file mode 100644
index 0000000..fad2160
--- /dev/null
+++ b/endpoints.yml
@@ -0,0 +1,15 @@
+host1:
+ transports:
+ ssh:
+ class: transport.SshTransport
+ hostname: 'localhost'
+ username: 'jonas'
+ commands:
+ notify:
+ class: command.NotifyCommand
+ transport: ssh
+ states:
+ user:
+ class: state.UserSessionState
+ transport: ssh
+ ttl: 30
diff --git a/main.py b/main.py
new file mode 100755
index 0000000..b2ffe42
--- /dev/null
+++ b/main.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+
+import yaml
+import json
+import logging
+import time
+
+import transport
+import state
+import command
+import endpoint
+import trigger
+import misc
+import action
+
+logging.basicConfig(level=logging.DEBUG)
+logging.getLogger('paramiko').setLevel(logging.WARNING)
+
+# Use a TypeDict here
+with open('endpoints.yml', 'r') as f:
+ endpoint_config = yaml.safe_load(f)
+
+with open('triggers.yml', 'r') as f:
+ trigger_config = yaml.safe_load(f)
+
+with open('actions.yml', 'r') as f:
+ action_config = yaml.safe_load(f)
+
+endpoints = {}
+for ep_key in endpoint_config:
+ endpoints[ep_key] = endpoint.Endpoint(ep_key, endpoint_config[ep_key])
+
+triggers = {}
+for trg_key in trigger_config:
+ cls = misc.import_class(trigger_config[trg_key]['class'])
+ del trigger_config[trg_key]['class']
+
+ if cls.NEEDS_CONTEXT:
+ triggers[trg_key] = cls(endpoints, **trigger_config[trg_key])
+ else:
+ triggers[trg_key] = cls(**trigger_config[trg_key])
+
+actions = {}
+for act_key in action_config:
+ actions[act_key] = action.Action(act_key, action_config[act_key], endpoints, triggers)
+
+
+# TODO should we do that in Endpoint.__init__()?
+for k in endpoints:
+ endpoints[k].connectTransport()
+
+for act_key in action_config:
+ actions[act_key].execute()
+
+for act_key in action_config:
+ actions[act_key].execute()
+
+
+#print(endpoints['host1'].getState('user.jonas'))
+#print(endpoints['host1'].getState('user.jonas'))
+#
+#time.sleep(31)
+#print(endpoints['host1'].getState('user.jonas'))
+
+#endpoints['host1'].executeCommand('notify', msg='moinsen')
+
+#tr = transport.SshTransport('localhost', username='jonas')
+#tr.connect()
+#
+#noti = command.NotifyCommand(tr)
+#noti.execute('OwO')
+#
+#sta = state.UserState(tr)
+#sta.collect()
+#
+#tr.disconnect()
diff --git a/misc.py b/misc.py
new file mode 100644
index 0000000..99fad74
--- /dev/null
+++ b/misc.py
@@ -0,0 +1,6 @@
+
+def import_class(cl):
+ d = cl.rfind(".")
+ classname = cl[d+1:len(cl)]
+ m = __import__(cl[0:d], globals(), locals(), [classname])
+ return getattr(m, classname)
diff --git a/state.py b/state.py
new file mode 100644
index 0000000..b17db99
--- /dev/null
+++ b/state.py
@@ -0,0 +1,91 @@
+import time
+import logging
+
+import transport
+
+'''
+Implementations of State:
+
+MUST implement:
+ _collect(self)
+
+CAN implement:
+ _get(self, key: str)
+
+SHOULDNT implement:
+ get(self, key)
+ collect(self)
+
+Data is stored in self._data as a dictionary.
+By default, _get(key) retrieves the returns self._data[key].
+This behaviour can be overridden by implementing a own _get().
+
+If using the default _get(), _collect() has to store data in
+the self._data dictionary. If an own _get() is implemented,
+this does not need to be the case.
+'''
+class State:
+ def __init__(self, transport: transport.Transport, ttl: int = 30):
+ self._transport = transport
+ self._ttl = ttl
+
+ self._data = {}
+ self._last_collected = 0
+
+ def _collect(self):
+ raise NotImplemented
+
+ def _get(self, key: str):
+ if key not in self._data:
+ logging.error(f'Data key {key} was not found.')
+ return None
+
+ return self._data[key]
+
+ def _shouldCollect(self):
+ return time.time() - self._last_collected > self._ttl
+
+ def get(self, key: str):
+ if self._shouldCollect():
+ logging.debug(f'Cached value for "{key}" is too old. refreshing.')
+ self.collect()
+ else:
+ logging.debug(f'Using cached value for "{key}".')
+
+
+ return self._get(key)
+
+ # Force datacollection. not really needed
+ def collect(self):
+ self._collect()
+ self._last_collected = time.time()
+
+class UserSessionState(State):
+ def __init__(self, transport: transport.SshTransport, ttl: int = 30):
+ super().__init__(transport, ttl)
+
+ # this is not needed. it's here to shut up pylint
+ self._transport = transport
+
+ def _get(self, key: str):
+ if key not in self._data:
+ return 0
+
+ return self._data[key]
+
+ def _collect(self):
+ data = self._transport.execHandleStderror('who').decode('utf-8')
+ # TODO error handling
+ lines = data.split('\n')
+
+ self._data = {}
+
+ for l in lines:
+ name, _ = l.split(' ', 1)
+
+ logging.debug(f'Found user session {name}')
+
+ if name not in self._data:
+ self._data[name] = 0
+
+ self._data[name] += 1
diff --git a/transport.py b/transport.py
new file mode 100644
index 0000000..91c5029
--- /dev/null
+++ b/transport.py
@@ -0,0 +1,80 @@
+import paramiko
+
+HOLD = 1
+THROWAWAY = 2
+
+# Abstract classes to implement
+class Transport:
+ NAME = 'BASE'
+ CONNECTION = HOLD
+ #CONNECTION = THROWAWAY
+
+ def __init__(self):
+ self._connected = False
+ raise NotImplemented
+
+ # Connects to the transport, if CONNECTION == HOLD
+ def connect(self):
+ raise NotImplemented
+
+ # disconnects to the transport, if CONNECTION == HOLD
+ def disconnect(self):
+ raise NotImplemented
+
+ # validate that the transport works, if CONNECTION == THROWAWAY
+ def check(self):
+ raise NotImplemented
+
+ def isConnected(self) -> bool:
+ return self._connected
+
+class SshTransport(Transport):
+ NAME='SSH'
+ CONNECTION=HOLD
+
+ def __init__(self, hostname: str, port=22, username='root', password = None, id_file = None):
+ self._hostname = hostname
+ self._port = port
+ self._username = username
+ self._password = password
+ self._id_file = id_file
+
+ self._connected = False
+ self._client = None
+
+ def connect(self):
+ self._client = paramiko.SSHClient()
+
+ # TODO known hosts
+ self._client.set_missing_host_key_policy(paramiko.client.AutoAddPolicy)
+ self._client.connect(self._hostname, port=self._port, username=self._username, password=self._password, key_filename=None, allow_agent=True)
+
+ self._connected = True
+
+ # return(str: stdout, str: stderr, int: retcode)
+ def exec(self, command: str):
+ if not self._connected:
+ raise Exception('Not connected')
+
+ output = self._client.exec_command(command, timeout=5)
+
+ retcode = output[1].channel.recv_exit_status()
+ return (output[1].read().strip(), output[2].read().strip(), retcode)
+
+ def execHandleStderror(self, command: str):
+ out = self.exec(command)
+
+ if out[2] != 0:
+ raise Exception(f'Command returned error {out[2]}: {out[1]}')
+
+ return out[0]
+
+ def readFile(self, path: str):
+ return self.execHandleStderror(f'cat "{path}"')
+
+ def disconnect(self):
+ if self._connected:
+ self._client.close()
+
+ self._connected = False
+ self._client = None
diff --git a/trigger.py b/trigger.py
new file mode 100644
index 0000000..7d2fdfd
--- /dev/null
+++ b/trigger.py
@@ -0,0 +1,124 @@
+from typing import Dict
+from pyparsing import alphanums, alphas, printables, pyparsing_common, pyparsing_common, Word, infix_notation, CaselessKeyword, opAssoc, ParserElement
+import logging
+import time
+
+import endpoint
+import misc
+
+'''
+Implementations of Trigger:
+
+MUST implement:
+ _evaluate(self, action: str) -> bool
+ evaluates the instace for action given by 'action'.
+ Provided configuration is stored in self._instances[action]['args']
+
+CAN implement:
+ _addInstance(self, action: str)
+ Called afer 'action' was added.
+
+SHOULDNT implement:
+ evaluate(self, action: str) -> bool
+ Only calls _evaluate(), if no check was performed in configured interval,
+ otherwise returns cached result
+ addInstance(self, action:str, interval=30, **kwargs)
+'''
+class Trigger:
+ NEEDS_CONTEXT = False
+
+ @staticmethod
+ def create(classname: str, **kwargs):
+ return misc.import_class(classname)(**kwargs)
+
+ def __init__(self):
+ self._instances = {}
+
+ def _addInstance(self, action: str):
+ pass
+
+ def addInstance(self, action: str, interval: int=30, **kwargs):
+ self._instances[action] = {'lastupdate':0,'interval':interval,'last':False,'args':kwargs}
+ self._addInstance(action)
+ logging.debug(f'Trigger: Action "{action}" registered.')
+
+ def _evaluate(self, action: str) -> bool:
+ raise NotImplemented
+
+ def _shouldReevaluate(self, action: str) -> bool:
+ return time.time() - self._instances[action]['lastupdate'] > self._instances[action]['interval']
+
+ def evaluate(self, action: str) -> bool:
+ if action not in self._instances:
+ logging.error(f'Trigger: Action "{action}" was not found. Evaluating to False.')
+ return False
+
+ if self._shouldReevaluate(action):
+ logging.debug(f'Re-evaluating trigger condition for action "{action}"')
+ result = self._evaluate(action)
+
+ self._instances[action]['last'] = result
+ self._instances[action]['lastupdate'] = time.time()
+ return result
+
+ return self._instances[action]['last']
+
+'''
+```yaml
+conditional:
+ class: trigger.Conditional
+---
+- conditional:
+ interval: 30
+ when:
+ - host1.user.bob > 0
+```
+'''
+class ConditionalTrigger(Trigger):
+ NEEDS_CONTEXT = True
+
+ def __init__(self, endpoints: Dict[str, endpoint.Endpoint]):
+ super().__init__()
+
+ self._endpoints = endpoints
+ self._setup_parser()
+
+ def _setup_parser(self):
+ ParserElement.enable_packrat()
+
+ boolean = CaselessKeyword('True').setParseAction(lambda x: True) | CaselessKeyword('False').setParseAction(lambda x: False)
+ integer = pyparsing_common.integer
+ variable = Word(alphanums + '.').setParseAction(self._parseVariable)
+ operand = boolean | integer | variable
+
+ self._parser = infix_notation(
+ operand,
+ [
+ ('not', 1, opAssoc.RIGHT, lambda a: not a[0][1]),
+ ('and', 2, opAssoc.LEFT, lambda a: a[0][0] and a[0][2]),
+ ('or', 2, opAssoc.LEFT, lambda a: a[0][0] or a[0][2]),
+ ('==', 2, opAssoc.LEFT, lambda a: a[0][0] == a[0][2]),
+ ('>', 2, opAssoc.LEFT, lambda a: a[0][0] > a[0][2]),
+ ('>=', 2, opAssoc.LEFT, lambda a: a[0][0] >= a[0][2]),
+ ('<', 2, opAssoc.LEFT, lambda a: a[0][0] < a[0][2]),
+ ('<=', 2, opAssoc.LEFT, lambda a: a[0][0] <= a[0][2]),
+ ('+', 2, opAssoc.LEFT, lambda a: a[0][0] + a[0][2]),
+ ('-', 2, opAssoc.LEFT, lambda a: a[0][0] - a[0][2]),
+ ('*', 2, opAssoc.LEFT, lambda a: a[0][0] * a[0][2]),
+ ('/', 2, opAssoc.LEFT, lambda a: a[0][0] / a[0][2]),
+ ]
+ )
+
+ def _parseVariable(self, var):
+ logging.debug(f'Looking up variable "{var[0]}"')
+ endpoint, key = var[0].split('.',1)
+
+ if not endpoint in self._endpoints:
+ logging.error(f'Parser: Endpoint "{endpoint}" not found')
+ return None
+
+ return self._endpoints[endpoint].getState(key)
+
+ def _evaluate(self, action: str) -> bool:
+ return all(self._parser.parse_string(str(s)) for s in self._instances[action]['args']['when'])
+
diff --git a/triggers.yml b/triggers.yml
new file mode 100644
index 0000000..9595f7e
--- /dev/null
+++ b/triggers.yml
@@ -0,0 +1,10 @@
+conditional:
+ class: trigger.ConditionalTrigger
+
+#mqtt:
+# class: trigger.Mqtt
+# server: asdf
+# user: OwO
+#
+#timer:
+# class: trigger.Timer