refactor!: fix task decorator and option adding

This commit is contained in:
Daylin Morgan 2024-10-23 15:34:09 -05:00
parent 0bc9138da3
commit d26ecfa7d5
Signed by: daylin
GPG key ID: 950D13E9719334AD

View file

@ -1,21 +1,26 @@
from __future__ import annotations
import argparse import argparse
import inspect import inspect
import os import os
import shlex import shlex
import shutil import shutil
import subprocess import signal
import sys import sys
from argparse import ( from argparse import (
Action,
ArgumentParser, ArgumentParser,
RawDescriptionHelpFormatter, RawDescriptionHelpFormatter,
_SubParsersAction,
) )
from functools import wraps from functools import wraps
from inspect import Parameter from inspect import Parameter
from pathlib import Path from pathlib import Path
from subprocess import PIPE, CompletedProcess, Popen from subprocess import PIPE, Popen
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import TYPE_CHECKING
if TYPE_CHECKING:
from argparse import Action, _SubParsersAction
from subprocess import CompletedProcess
from typing import Any, Callable, Dict, List, Optional, Tuple
__version__ = "0.1.0" __version__ = "0.1.0"
@ -48,7 +53,7 @@ def _id_from_func(f: Callable[..., Any]):
class Task: class Task:
def __init__( def __init__(
self, func=Callable[..., Any], name: Optional[str] = None, show: bool = False self, func: Callable[..., Any], name: Optional[str] = None, show: bool = False
) -> None: ) -> None:
self.show = show self.show = show
self.id = _id_from_func(func) self.id = _id_from_func(func)
@ -64,12 +69,14 @@ class Task:
for name, param in self.signature.parameters.items(): for name, param in self.signature.parameters.items():
self.params[name] = {"Parameter": param} self.params[name] = {"Parameter": param}
def _update_option(self, name: str, help: str, **kwargs) -> None: def _update_option(self, name: str, help: str, short: str, **kwargs) -> None:
self.params[name] = { self.params[name] = {
**self.params.get(name, {}), **self.params.get(name, {}),
"help": help, "help": help,
"kwargs": kwargs, "kwargs": kwargs,
} }
if short != "":
self.params[name]["short"] = short
def _mark(self) -> None: def _mark(self) -> None:
self.show = True self.show = True
@ -116,19 +123,18 @@ class Context:
self._tasks[id_]._mark() self._tasks[id_]._mark()
return id_ return id_
def _update_option(self, func: Callable[..., Any], name: str, help: str, **kwargs): def _update_option(
if (id_ := _id_from_func(func)) not in self._tasks: self, func: Callable[..., Any], name: str, help: str, short: str, **kwargs
raise ValueError ):
self._tasks[id_]._update_option(name, help, **kwargs) id_ = self._add_task(func)
self._tasks[id_]._update_option(name, help, short, **kwargs)
def _add_target(self, func: Callable[..., Any], target: str) -> None: def _add_target(self, func: Callable[..., Any], target: str) -> None:
self._add_task(func) id_ = self._add_task(func)
id_ = _id_from_func(func)
self._tasks[id_].targets.append(target) self._tasks[id_].targets.append(target)
def _add_need(self, func: Callable[..., Any], need: str) -> None: def _add_need(self, func: Callable[..., Any], need: str) -> None:
self._add_task(func) id_ = self._add_task(func)
id_ = _id_from_func(func)
self._tasks[id_].needs.append(need) self._tasks[id_].needs.append(need)
def _generate_graph(self) -> None: def _generate_graph(self) -> None:
@ -151,10 +157,6 @@ class Context:
ctx = Context() ctx = Context()
def setenv(key: str, value: str) -> None:
ctx._env.update({key: value})
class SwyddSubResult: class SwyddSubResult:
def __init__( def __init__(
self, self,
@ -205,9 +207,6 @@ class SwyddProc:
elif isinstance(proc, SwyddProc): elif isinstance(proc, SwyddProc):
return SwyddPipe(proc) return SwyddPipe(proc)
# def __or__(self, proc: "str | SwyddProc") -> "SwyddPipe | SwyddProc":
# return self.pipe(proc)
def then(self, proc: "str | SwyddProc | SwyddSeq") -> "SwyddSeq": def then(self, proc: "str | SwyddProc | SwyddSeq") -> "SwyddSeq":
if self._cmd: if self._cmd:
return SwyddSeq(self, proc) return SwyddSeq(self, proc)
@ -220,9 +219,6 @@ class SwyddProc:
else: else:
return SwyddSeq(SwyddProc(proc)) return SwyddSeq(SwyddProc(proc))
# def __and__(self, proc: "str | SwyddProc | SwyddSeq") -> "SwyddSeq":
# return self.then(proc)
def _build_kwargs(self) -> Dict[str, Any]: def _build_kwargs(self) -> Dict[str, Any]:
sub_kwargs: Dict[str, Any] = dict(env={**os.environ, **ctx._env}) sub_kwargs: Dict[str, Any] = dict(env={**os.environ, **ctx._env})
@ -238,9 +234,19 @@ class SwyddProc:
sys.stdout.write(f"swydd exec | {self._cmd}\n") sys.stdout.write(f"swydd exec | {self._cmd}\n")
self.output = self.output or output self.output = self.output or output
return SwyddSubResult.from_completed_process(
subprocess.run(self.cmd, **self._build_kwargs()) p = Popen(self.cmd, **self._build_kwargs())
)
try:
out, err = p.communicate()
except KeyboardInterrupt:
sys.stderr.write("forwarding CTRL+C\n")
sys.stderr.flush()
p.send_signal(signal.SIGINT)
p.wait()
out, err = p.communicate()
return SwyddSubResult.from_popen(p, out, err)
def check(self) -> bool: def check(self) -> bool:
return self.execute().code == 0 return self.execute().code == 0
@ -283,15 +289,21 @@ class SwyddPipe:
if p.stdout: if p.stdout:
p.stdout.close() p.stdout.close()
try:
out, err = procs[-1].communicate() out, err = procs[-1].communicate()
except KeyboardInterrupt:
sys.stderr.write("forwarding CTRL+C\n")
sys.stderr.flush()
# ALL of them?
procs[-1].send_signal(signal.SIGINT)
procs[-1].wait()
out, err = procs[-1].communicate()
return SwyddSubResult.from_popen(procs[-1], out, err) return SwyddSubResult.from_popen(procs[-1], out, err)
def pipe(self, proc: "str | SwyddProc | SwyddPipe") -> "SwyddPipe": def pipe(self, proc: "str | SwyddProc | SwyddPipe") -> "SwyddPipe":
return SwyddPipe(self, proc) return SwyddPipe(self, proc)
# def __or__(self, proc: "str | SwyddProc | SwyddPipe") -> "SwyddPipe":
# return self.pipe(proc)
class SwyddSeq: class SwyddSeq:
def __init__(self, *procs: "str | SwyddProc | SwyddSeq") -> None: def __init__(self, *procs: "str | SwyddProc | SwyddSeq") -> None:
@ -311,9 +323,6 @@ class SwyddSeq:
def then(self, proc: "str | SwyddProc | SwyddSeq") -> "SwyddSeq": def then(self, proc: "str | SwyddProc | SwyddSeq") -> "SwyddSeq":
return SwyddSeq(*self._procs, proc) return SwyddSeq(*self._procs, proc)
# def __and__(self, proc: "str | SwyddProc | SwyddSeq") -> "SwyddSeq":
# return self.then(proc)
def execute(self, output: bool = False) -> "SwyddSubResult": def execute(self, output: bool = False) -> "SwyddSubResult":
results = [] results = []
for proc in self._procs: for proc in self._procs:
@ -357,12 +366,6 @@ class SwyddGet:
output += result.stderr.strip() output += result.stderr.strip()
return output return output
# def __lt__(self, proc: str | SwyddProc | SwyddPipe | SwyddSeq) -> str:
# return self.__call__(proc)
#
# def __lshift__(self, proc: str | SwyddProc | SwyddPipe | SwyddSeq) -> str:
# return self.__call__(proc, stdout=False, stderr=True)
def _get_caller_path() -> Path: def _get_caller_path() -> Path:
# NOTE: jupyter will hate this code I'm sure # NOTE: jupyter will hate this code I'm sure
@ -386,8 +389,6 @@ class SwyddSub:
raise ValueError(f"unspported type: {type(exec)}") raise ValueError(f"unspported type: {type(exec)}")
# TODO: change alias to not confuse with pathlib.Path?
# asset / file ... partial to asset
class SwyddPath: class SwyddPath:
_root = None _root = None
_path = None _path = None
@ -476,15 +477,6 @@ class SwyddPath:
return self._append_text(src.read().strip()) return self._append_text(src.read().strip())
def task(func: Callable[..., Any]) -> Callable[..., None]:
ctx._add_task(func, show=True)
def wrap(*args: Any, **kwargs: Any) -> None:
return func(*args, **kwargs)
return wrap
def _inspect_wrapper(place, func): def _inspect_wrapper(place, func):
if wrapped := getattr(func, "__wrapped__", None): if wrapped := getattr(func, "__wrapped__", None):
print(place, "wrapped->", id(wrapped)) print(place, "wrapped->", id(wrapped))
@ -495,13 +487,12 @@ def _inspect_wrapper(place, func):
) )
def task2( def task(
hidden: bool = False, arg=None,
) -> Callable[[Callable[..., Any]], Callable[..., Callable[..., None]]]: ):
def wrapper(func: Callable[..., Any]) -> Callable[..., Callable[..., None]]: def wrapper(func: Callable[..., Any]) -> Callable[..., Callable[..., None]]:
ctx._add_task(func, show=True) ctx._add_task(func, show=not func.__name__.startswith("_"))
# _inspect_wrapper("task", func)
_inspect_wrapper("task", func)
@wraps(func) @wraps(func)
def inner(*args: Any, **kwargs: Any) -> Callable[..., None]: def inner(*args: Any, **kwargs: Any) -> Callable[..., None]:
@ -509,6 +500,9 @@ def task2(
return inner return inner
if callable(arg):
return wrapper(arg)
else:
return wrapper return wrapper
@ -516,8 +510,6 @@ def targets(
*args: str, *args: str,
) -> Callable[[Callable[..., Any]], Callable[..., Callable[..., None]]]: ) -> Callable[[Callable[..., Any]], Callable[..., Callable[..., None]]]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Callable[..., None]]: def wrapper(func: Callable[..., Any]) -> Callable[..., Callable[..., None]]:
_inspect_wrapper("targets", func)
ctx._add_task(func)
for arg in args: for arg in args:
ctx._add_target(func, arg) ctx._add_target(func, arg)
ctx.targets[arg] = _id_from_func(func) ctx.targets[arg] = _id_from_func(func)
@ -550,11 +542,11 @@ def needs(
def option( def option(
name: str, name: str,
help: str, help: str,
short: str = "",
**help_kwargs: str, **help_kwargs: str,
) -> Callable[[Callable[..., Any]], Callable[..., Callable[..., None]]]: ) -> Callable[[Callable[..., Any]], Callable[..., Callable[..., None]]]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Callable[..., None]]: def wrapper(func: Callable[..., Any]) -> Callable[..., Callable[..., None]]:
ctx._add_task(func) ctx._update_option(func, name.replace("-", "_"), help, short, **help_kwargs)
ctx._update_option(func, name.replace("-", "_"), help, **help_kwargs)
@wraps(func) @wraps(func)
def inner(*args: Any, **kwargs: Any) -> Callable[..., None]: def inner(*args: Any, **kwargs: Any) -> Callable[..., None]:
@ -606,13 +598,17 @@ def _generate_task_subparser(
subparsers: _SubParsersAction, subparsers: _SubParsersAction,
task: Task, task: Task,
target: Optional[str] = None, target: Optional[str] = None,
doc: str = "",
) -> Optional[ArgumentParser]: ) -> Optional[ArgumentParser]:
# TODO: don't return an option
if not task.show and not target: if not task.show and not target:
return return
prog = os.path.basename(sys.argv[0]) prog = os.path.basename(sys.argv[0])
name = task.name if not target else target name = task.name if not target else target
doc = task.func.__doc__.splitlines()[0] if task.func.__doc__ else "" if doc == "" and task.func.__doc__:
doc = task.func.__doc__.splitlines()[0]
# doc = task.func.__doc__.splitlines()[0] if task.func.__doc__ else ""
subparser = subparsers.add_parser( subparser = subparsers.add_parser(
name.replace("_", "-"), name.replace("_", "-"),
help=doc, help=doc,
@ -623,8 +619,10 @@ def _generate_task_subparser(
) )
for name, info in task.params.items(): for name, info in task.params.items():
param = info.get("Parameter") # must check signature for args? param = info.get("Parameter") # must check signature for args?
args = []
args = (f"--{name.replace('_','-')}",) if "short" in info:
args.append("-" + info["short"])
args.append(f"--{name.replace('_','-')}")
kwargs = {"help": info.get("help", "")} kwargs = {"help": info.get("help", "")}
if param.annotation is bool: if param.annotation is bool:
@ -640,26 +638,55 @@ def _generate_task_subparser(
kwargs.update(info.get("kwargs", {})) kwargs.update(info.get("kwargs", {}))
subparser.add_argument(*args, **kwargs) subparser.add_argument(*args, **kwargs)
# TODO: properly build out a dag from tasks "needs"
# for now add a simple check for existense
# for need in task.needs:
# asset(need)._check()
def executor(*args, **kwargs):
for need in task.needs:
asset(need)._check()
f = ( f = (
target_generator(target, ctx._graph.nodes[target])(task.func) target_generator(target, ctx._graph.nodes[target])(task.func)
if target if target
else task.func else task.func
) )
subparser.set_defaults(func=f) return f(*args, **kwargs)
subparser.set_defaults(func=executor)
return subparser return subparser
def _target_status(target: str) -> str:
if not (target_path := Path(target)).is_file():
return "missing target"
needs = ctx._graph.nodes[target]
target_stat = target_path.stat()
needs_stats = []
for need in needs:
if not (p := Path(need)).is_file():
return "missing inputs!"
needs_stats.append(p)
if any((stat.st_mtime > target_stat.st_mtime for stat in needs_stats)):
return "out of date"
return " "
def _add_targets( def _add_targets(
shared: ArgumentParser, subparsers: _SubParsersAction, ctx: Context shared: ArgumentParser, subparsers: _SubParsersAction, ctx: Context
) -> None: ) -> None:
for target, id_ in ctx.targets.items(): for target, id_ in ctx.targets.items():
subp = _generate_task_subparser( subp = _generate_task_subparser(
shared, subparsers, ctx._tasks[id_], str(target) shared, subparsers, ctx._tasks[id_], str(target), doc=_target_status(target)
) )
if subp: if subp:
subp.add_argument("--dag", help="show target dag", action="store_true") subp.add_argument("--dag", help="show target dag", action="store_true")
subp.add_argument("--force", help="force execution", action="store_true") subp.add_argument(
"-f", "--force", help="force execution", action="store_true"
)
def _task_repr(func: Callable) -> str: def _task_repr(func: Callable) -> str:
@ -707,7 +734,7 @@ def cli(default: str | None = None) -> None:
if len(sys.argv) == 1: if len(sys.argv) == 1:
if default: if default:
sys.argv.append(default) sys.argv.extend(shlex.split(default))
else: else:
parser.print_help(sys.stderr) parser.print_help(sys.stderr)
sys.exit(1) sys.exit(1)
@ -730,6 +757,8 @@ def cli(default: str | None = None) -> None:
if f := args.pop("func", None): if f := args.pop("func", None):
if ctx.dry: if ctx.dry:
sys.stderr.write("dry run >>>\n" f" args: {args}\n") sys.stderr.write("dry run >>>\n" f" args: {args}\n")
if ctx._env:
sys.stderr.write(f" env: {ctx._env}\n")
sys.stderr.write(_task_repr(f)) sys.stderr.write(_task_repr(f))
elif ctx.dag: elif ctx.dag:
sys.stderr.write( sys.stderr.write(
@ -746,7 +775,7 @@ def cli(default: str | None = None) -> None:
seq, seq,
sub, sub,
get, get,
path, asset,
) = ( ) = (
SwyddProc(), SwyddProc(),
SwyddPipe(), SwyddPipe(),
@ -756,8 +785,6 @@ def cli(default: str | None = None) -> None:
SwyddPath(), SwyddPath(),
) )
asset = SwyddPath()
def geterr(*args, **kwargs) -> str: def geterr(*args, **kwargs) -> str:
get_kwargs = dict(stderr=True, stdout=False) get_kwargs = dict(stderr=True, stdout=False)
@ -765,6 +792,24 @@ def geterr(*args, **kwargs) -> str:
return get(*args, **get_kwargs) return get(*args, **get_kwargs)
def setenv(key: str, value: str) -> None:
ctx._env.update({key: value})
__all__ = [
"proc",
"pipe",
"seq",
"sub",
"get",
"asset",
"ctx",
"geterr",
"setenv",
"cli",
"task",
]
if __name__ == "__main__": if __name__ == "__main__":
sys.stderr.write("this module should not be invoked directly\n") sys.stderr.write("this module should not be invoked directly\n")
sys.exit(1) sys.exit(1)