refactor!: fix task decorator and option adding

This commit is contained in:
Daylin Morgan 2024-10-23 15:34:09 -05:00
parent 0bc9138da3
commit d26ecfa7d5
Signed by: daylin
GPG key ID: 950D13E9719334AD

View file

@ -1,21 +1,26 @@
from __future__ import annotations
import argparse
import inspect
import os
import shlex
import shutil
import subprocess
import signal
import sys
from argparse import (
Action,
ArgumentParser,
RawDescriptionHelpFormatter,
_SubParsersAction,
)
from functools import wraps
from inspect import Parameter
from pathlib import Path
from subprocess import PIPE, CompletedProcess, Popen
from typing import Any, Callable, Dict, List, Optional, Tuple
from subprocess import PIPE, Popen
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from argparse import Action, _SubParsersAction
from subprocess import CompletedProcess
from typing import Any, Callable, Dict, List, Optional, Tuple
__version__ = "0.1.0"
@ -48,7 +53,7 @@ def _id_from_func(f: Callable[..., Any]):
class Task:
def __init__(
self, func=Callable[..., Any], name: Optional[str] = None, show: bool = False
self, func: Callable[..., Any], name: Optional[str] = None, show: bool = False
) -> None:
self.show = show
self.id = _id_from_func(func)
@ -64,12 +69,14 @@ class Task:
for name, param in self.signature.parameters.items():
self.params[name] = {"Parameter": param}
def _update_option(self, name: str, help: str, **kwargs) -> None:
def _update_option(self, name: str, help: str, short: str, **kwargs) -> None:
self.params[name] = {
**self.params.get(name, {}),
"help": help,
"kwargs": kwargs,
}
if short != "":
self.params[name]["short"] = short
def _mark(self) -> None:
self.show = True
@ -116,19 +123,18 @@ class Context:
self._tasks[id_]._mark()
return id_
def _update_option(self, func: Callable[..., Any], name: str, help: str, **kwargs):
if (id_ := _id_from_func(func)) not in self._tasks:
raise ValueError
self._tasks[id_]._update_option(name, help, **kwargs)
def _update_option(
self, func: Callable[..., Any], name: str, help: str, short: str, **kwargs
):
id_ = self._add_task(func)
self._tasks[id_]._update_option(name, help, short, **kwargs)
def _add_target(self, func: Callable[..., Any], target: str) -> None:
self._add_task(func)
id_ = _id_from_func(func)
id_ = self._add_task(func)
self._tasks[id_].targets.append(target)
def _add_need(self, func: Callable[..., Any], need: str) -> None:
self._add_task(func)
id_ = _id_from_func(func)
id_ = self._add_task(func)
self._tasks[id_].needs.append(need)
def _generate_graph(self) -> None:
@ -151,10 +157,6 @@ class Context:
ctx = Context()
def setenv(key: str, value: str) -> None:
ctx._env.update({key: value})
class SwyddSubResult:
def __init__(
self,
@ -205,9 +207,6 @@ class SwyddProc:
elif isinstance(proc, SwyddProc):
return SwyddPipe(proc)
# def __or__(self, proc: "str | SwyddProc") -> "SwyddPipe | SwyddProc":
# return self.pipe(proc)
def then(self, proc: "str | SwyddProc | SwyddSeq") -> "SwyddSeq":
if self._cmd:
return SwyddSeq(self, proc)
@ -220,9 +219,6 @@ class SwyddProc:
else:
return SwyddSeq(SwyddProc(proc))
# def __and__(self, proc: "str | SwyddProc | SwyddSeq") -> "SwyddSeq":
# return self.then(proc)
def _build_kwargs(self) -> Dict[str, Any]:
sub_kwargs: Dict[str, Any] = dict(env={**os.environ, **ctx._env})
@ -238,9 +234,19 @@ class SwyddProc:
sys.stdout.write(f"swydd exec | {self._cmd}\n")
self.output = self.output or output
return SwyddSubResult.from_completed_process(
subprocess.run(self.cmd, **self._build_kwargs())
)
p = Popen(self.cmd, **self._build_kwargs())
try:
out, err = p.communicate()
except KeyboardInterrupt:
sys.stderr.write("forwarding CTRL+C\n")
sys.stderr.flush()
p.send_signal(signal.SIGINT)
p.wait()
out, err = p.communicate()
return SwyddSubResult.from_popen(p, out, err)
def check(self) -> bool:
return self.execute().code == 0
@ -283,15 +289,21 @@ class SwyddPipe:
if p.stdout:
p.stdout.close()
out, err = procs[-1].communicate()
try:
out, err = procs[-1].communicate()
except KeyboardInterrupt:
sys.stderr.write("forwarding CTRL+C\n")
sys.stderr.flush()
# ALL of them?
procs[-1].send_signal(signal.SIGINT)
procs[-1].wait()
out, err = procs[-1].communicate()
return SwyddSubResult.from_popen(procs[-1], out, err)
def pipe(self, proc: "str | SwyddProc | SwyddPipe") -> "SwyddPipe":
return SwyddPipe(self, proc)
# def __or__(self, proc: "str | SwyddProc | SwyddPipe") -> "SwyddPipe":
# return self.pipe(proc)
class SwyddSeq:
def __init__(self, *procs: "str | SwyddProc | SwyddSeq") -> None:
@ -311,9 +323,6 @@ class SwyddSeq:
def then(self, proc: "str | SwyddProc | SwyddSeq") -> "SwyddSeq":
return SwyddSeq(*self._procs, proc)
# def __and__(self, proc: "str | SwyddProc | SwyddSeq") -> "SwyddSeq":
# return self.then(proc)
def execute(self, output: bool = False) -> "SwyddSubResult":
results = []
for proc in self._procs:
@ -357,12 +366,6 @@ class SwyddGet:
output += result.stderr.strip()
return output
# def __lt__(self, proc: str | SwyddProc | SwyddPipe | SwyddSeq) -> str:
# return self.__call__(proc)
#
# def __lshift__(self, proc: str | SwyddProc | SwyddPipe | SwyddSeq) -> str:
# return self.__call__(proc, stdout=False, stderr=True)
def _get_caller_path() -> Path:
# NOTE: jupyter will hate this code I'm sure
@ -386,8 +389,6 @@ class SwyddSub:
raise ValueError(f"unspported type: {type(exec)}")
# TODO: change alias to not confuse with pathlib.Path?
# asset / file ... partial to asset
class SwyddPath:
_root = None
_path = None
@ -476,15 +477,6 @@ class SwyddPath:
return self._append_text(src.read().strip())
def task(func: Callable[..., Any]) -> Callable[..., None]:
ctx._add_task(func, show=True)
def wrap(*args: Any, **kwargs: Any) -> None:
return func(*args, **kwargs)
return wrap
def _inspect_wrapper(place, func):
if wrapped := getattr(func, "__wrapped__", None):
print(place, "wrapped->", id(wrapped))
@ -495,13 +487,12 @@ def _inspect_wrapper(place, func):
)
def task2(
hidden: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Callable[..., None]]]:
def task(
arg=None,
):
def wrapper(func: Callable[..., Any]) -> Callable[..., Callable[..., None]]:
ctx._add_task(func, show=True)
_inspect_wrapper("task", func)
ctx._add_task(func, show=not func.__name__.startswith("_"))
# _inspect_wrapper("task", func)
@wraps(func)
def inner(*args: Any, **kwargs: Any) -> Callable[..., None]:
@ -509,15 +500,16 @@ def task2(
return inner
return wrapper
if callable(arg):
return wrapper(arg)
else:
return wrapper
def targets(
*args: str,
) -> Callable[[Callable[..., Any]], Callable[..., Callable[..., None]]]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Callable[..., None]]:
_inspect_wrapper("targets", func)
ctx._add_task(func)
for arg in args:
ctx._add_target(func, arg)
ctx.targets[arg] = _id_from_func(func)
@ -550,11 +542,11 @@ def needs(
def option(
name: str,
help: str,
short: str = "",
**help_kwargs: str,
) -> Callable[[Callable[..., Any]], Callable[..., Callable[..., None]]]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Callable[..., None]]:
ctx._add_task(func)
ctx._update_option(func, name.replace("-", "_"), help, **help_kwargs)
ctx._update_option(func, name.replace("-", "_"), help, short, **help_kwargs)
@wraps(func)
def inner(*args: Any, **kwargs: Any) -> Callable[..., None]:
@ -606,13 +598,17 @@ def _generate_task_subparser(
subparsers: _SubParsersAction,
task: Task,
target: Optional[str] = None,
doc: str = "",
) -> Optional[ArgumentParser]:
# TODO: don't return an option
if not task.show and not target:
return
prog = os.path.basename(sys.argv[0])
name = task.name if not target else target
doc = task.func.__doc__.splitlines()[0] if task.func.__doc__ else ""
if doc == "" and task.func.__doc__:
doc = task.func.__doc__.splitlines()[0]
# doc = task.func.__doc__.splitlines()[0] if task.func.__doc__ else ""
subparser = subparsers.add_parser(
name.replace("_", "-"),
help=doc,
@ -623,8 +619,10 @@ def _generate_task_subparser(
)
for name, info in task.params.items():
param = info.get("Parameter") # must check signature for args?
args = (f"--{name.replace('_','-')}",)
args = []
if "short" in info:
args.append("-" + info["short"])
args.append(f"--{name.replace('_','-')}")
kwargs = {"help": info.get("help", "")}
if param.annotation is bool:
@ -640,26 +638,55 @@ def _generate_task_subparser(
kwargs.update(info.get("kwargs", {}))
subparser.add_argument(*args, **kwargs)
f = (
target_generator(target, ctx._graph.nodes[target])(task.func)
if target
else task.func
)
subparser.set_defaults(func=f)
# TODO: properly build out a dag from tasks "needs"
# for now add a simple check for existense
# for need in task.needs:
# asset(need)._check()
def executor(*args, **kwargs):
for need in task.needs:
asset(need)._check()
f = (
target_generator(target, ctx._graph.nodes[target])(task.func)
if target
else task.func
)
return f(*args, **kwargs)
subparser.set_defaults(func=executor)
return subparser
def _target_status(target: str) -> str:
if not (target_path := Path(target)).is_file():
return "missing target"
needs = ctx._graph.nodes[target]
target_stat = target_path.stat()
needs_stats = []
for need in needs:
if not (p := Path(need)).is_file():
return "missing inputs!"
needs_stats.append(p)
if any((stat.st_mtime > target_stat.st_mtime for stat in needs_stats)):
return "out of date"
return " "
def _add_targets(
shared: ArgumentParser, subparsers: _SubParsersAction, ctx: Context
) -> None:
for target, id_ in ctx.targets.items():
subp = _generate_task_subparser(
shared, subparsers, ctx._tasks[id_], str(target)
shared, subparsers, ctx._tasks[id_], str(target), doc=_target_status(target)
)
if subp:
subp.add_argument("--dag", help="show target dag", action="store_true")
subp.add_argument("--force", help="force execution", action="store_true")
subp.add_argument(
"-f", "--force", help="force execution", action="store_true"
)
def _task_repr(func: Callable) -> str:
@ -707,7 +734,7 @@ def cli(default: str | None = None) -> None:
if len(sys.argv) == 1:
if default:
sys.argv.append(default)
sys.argv.extend(shlex.split(default))
else:
parser.print_help(sys.stderr)
sys.exit(1)
@ -730,6 +757,8 @@ def cli(default: str | None = None) -> None:
if f := args.pop("func", None):
if ctx.dry:
sys.stderr.write("dry run >>>\n" f" args: {args}\n")
if ctx._env:
sys.stderr.write(f" env: {ctx._env}\n")
sys.stderr.write(_task_repr(f))
elif ctx.dag:
sys.stderr.write(
@ -746,7 +775,7 @@ def cli(default: str | None = None) -> None:
seq,
sub,
get,
path,
asset,
) = (
SwyddProc(),
SwyddPipe(),
@ -756,8 +785,6 @@ def cli(default: str | None = None) -> None:
SwyddPath(),
)
asset = SwyddPath()
def geterr(*args, **kwargs) -> str:
get_kwargs = dict(stderr=True, stdout=False)
@ -765,6 +792,24 @@ def geterr(*args, **kwargs) -> str:
return get(*args, **get_kwargs)
def setenv(key: str, value: str) -> None:
ctx._env.update({key: value})
__all__ = [
"proc",
"pipe",
"seq",
"sub",
"get",
"asset",
"ctx",
"geterr",
"setenv",
"cli",
"task",
]
if __name__ == "__main__":
sys.stderr.write("this module should not be invoked directly\n")
sys.exit(1)