diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0c15c83 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,16 @@ +name: CI + +on: + push: + pull_request: + +jobs: + test: + runs-on: windows-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + - name: Run tests + run: go test -p 1 ./... diff --git a/manage.go b/manage.go index 68b8379..f398192 100644 --- a/manage.go +++ b/manage.go @@ -72,13 +72,13 @@ func ConnectWithOptions(serverName, domain, username, password string) (TaskServ if !taskService.isInitialized { err = taskService.initialize() if err != nil { - return TaskService{}, fmt.Errorf("error initializing ITaskService object: %v", err) + return TaskService{}, fmt.Errorf("error initializing ITaskService object: %w", err) } } _, err = oleutil.CallMethod(taskService.taskServiceObj, "Connect", serverName, username, domain, password) if err != nil { - return TaskService{}, fmt.Errorf("error connecting to Task Scheduler service: %v", getTaskSchedulerError(err)) + return TaskService{}, fmt.Errorf("error connecting to Task Scheduler service: %w", getTaskSchedulerError(err)) } if serverName == "" { @@ -103,7 +103,7 @@ func ConnectWithOptions(serverName, domain, username, password string) (TaskServ res, err := oleutil.CallMethod(taskService.taskServiceObj, "GetFolder", `\`) if err != nil { - return TaskService{}, fmt.Errorf("error getting the root folder: %v", getTaskSchedulerError(err)) + return TaskService{}, fmt.Errorf("error getting the root folder: %w", getTaskSchedulerError(err)) } taskService.rootFolderObj = res.ToIDispatch() taskService.isConnected = true @@ -133,9 +133,8 @@ func (t *TaskService) GetRunningTasks() (RunningTaskCollection, error) { res, err := oleutil.CallMethod(t.taskServiceObj, "GetRunningTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return nil, fmt.Errorf("error getting running tasks: %v", getTaskSchedulerError(err)) + return nil, fmt.Errorf("error getting running tasks: %w", getTaskSchedulerError(err)) } - defer res.Clear() runningTasksObj := res.ToIDispatch() defer runningTasksObj.Release() @@ -144,7 +143,7 @@ func (t *TaskService) GetRunningTasks() (RunningTaskCollection, error) { runningTask, err := parseRunningTask(task) if err != nil { - return fmt.Errorf("error parsing running task: %v", err) + return fmt.Errorf("error parsing running task: %w", err) } runningTasks = append(runningTasks, runningTask) @@ -167,7 +166,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { // get tasks from root folder res, err := oleutil.CallMethod(t.rootFolderObj, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return nil, fmt.Errorf("error getting tasks of root folder: %v", getTaskSchedulerError(err)) + return nil, fmt.Errorf("error getting tasks of root folder: %w", getTaskSchedulerError(err)) } rootTaskCollection := res.ToIDispatch() defer rootTaskCollection.Release() @@ -177,7 +176,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { registeredTask, path, err := parseRegisteredTask(task) if err != nil { - return fmt.Errorf("error parsing registered task %s: %v", path, err) + return fmt.Errorf("error parsing registered task %s: %w", path, err) } registeredTasks = append(registeredTasks, registeredTask) @@ -189,7 +188,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { res, err = oleutil.CallMethod(t.rootFolderObj, "GetFolders", 0) if err != nil { - return nil, fmt.Errorf("error getting task folders of root folder: %v", getTaskSchedulerError(err)) + return nil, fmt.Errorf("error getting task folders of root folder: %w", getTaskSchedulerError(err)) } taskFolderList := res.ToIDispatch() defer taskFolderList.Release() @@ -202,7 +201,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { res, err := oleutil.CallMethod(taskFolder, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return fmt.Errorf("error getting tasks of folder: %v", getTaskSchedulerError(err)) + return fmt.Errorf("error getting tasks of folder: %w", getTaskSchedulerError(err)) } taskCollection := res.ToIDispatch() defer taskCollection.Release() @@ -212,7 +211,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { registeredTask, path, err := parseRegisteredTask(task) if err != nil { - return fmt.Errorf("error parsing registered task %s: %v", path, err) + return fmt.Errorf("error parsing registered task %s: %w", path, err) } registeredTasks = append(registeredTasks, registeredTask) @@ -224,7 +223,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { res, err = oleutil.CallMethod(taskFolder, "GetFolders", 0) if err != nil { - return fmt.Errorf("error getting subfolders of folder: %v", getTaskSchedulerError(err)) + return fmt.Errorf("error getting subfolders of folder: %w", getTaskSchedulerError(err)) } taskFolderList := res.ToIDispatch() defer taskFolderList.Release() @@ -255,12 +254,12 @@ func (t *TaskService) GetRegisteredTask(path string) (RegisteredTask, error) { taskObj, err := oleutil.CallMethod(t.rootFolderObj, "GetTask", path) if err != nil { - return RegisteredTask{}, fmt.Errorf("error getting registered task %s: %v", path, getTaskSchedulerError(err)) + return RegisteredTask{}, fmt.Errorf("error getting registered task %s: %w", path, getTaskSchedulerError(err)) } task, _, err := parseRegisteredTask(taskObj.ToIDispatch()) if err != nil { - return RegisteredTask{}, fmt.Errorf("error parsing registered task %s: %v", path, err) + return RegisteredTask{}, fmt.Errorf("error parsing registered task %s: %w", path, err) } return task, nil @@ -286,7 +285,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { } else { topFolder, err := oleutil.CallMethod(t.taskServiceObj, "GetFolder", path) if err != nil { - return TaskFolder{}, fmt.Errorf("error getting folder %s: %v", path, getTaskSchedulerError(err)) + return TaskFolder{}, fmt.Errorf("error getting folder %s: %w", path, getTaskSchedulerError(err)) } topFolderObj = topFolder.ToIDispatch() defer topFolderObj.Release() @@ -295,7 +294,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { // get tasks from the top folder res, err := oleutil.CallMethod(topFolderObj, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return TaskFolder{}, fmt.Errorf("error getting tasks of folder %s: %v", path, getTaskSchedulerError(err)) + return TaskFolder{}, fmt.Errorf("error getting tasks of folder %s: %w", path, getTaskSchedulerError(err)) } topFolderTaskCollection := res.ToIDispatch() defer topFolderTaskCollection.Release() @@ -305,7 +304,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { registeredTask, path, err := parseRegisteredTask(task) if err != nil { - return fmt.Errorf("error parsing registered task %s: %v", path, err) + return fmt.Errorf("error parsing registered task %s: %w", path, err) } topFolder.RegisteredTasks = append(topFolder.RegisteredTasks, registeredTask) @@ -317,7 +316,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { res, err = oleutil.CallMethod(topFolderObj, "GetFolders", 0) if err != nil { - return TaskFolder{}, fmt.Errorf("error getting subfolders of folder %s: %v", path, getTaskSchedulerError(err)) + return TaskFolder{}, fmt.Errorf("error getting subfolders of folder %s: %w", path, getTaskSchedulerError(err)) } taskFolderList := res.ToIDispatch() defer taskFolderList.Release() @@ -334,7 +333,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { path := oleutil.MustGetProperty(taskFolder, "Path").ToString() res, err := oleutil.CallMethod(taskFolder, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return fmt.Errorf("error getting tasks of folder %s: %v", path, getTaskSchedulerError(err)) + return fmt.Errorf("error getting tasks of folder %s: %w", path, getTaskSchedulerError(err)) } taskCollection := res.ToIDispatch() defer taskCollection.Release() @@ -349,7 +348,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { registeredTask, path, err := parseRegisteredTask(task) if err != nil { - return fmt.Errorf("error parsing registered task %s: %v", path, err) + return fmt.Errorf("error parsing registered task %s: %w", path, err) } taskSubFolder.RegisteredTasks = append(taskSubFolder.RegisteredTasks, registeredTask) @@ -363,7 +362,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { res, err = oleutil.CallMethod(taskFolder, "GetFolders", 0) if err != nil { - return fmt.Errorf("error getting subfolders of folder %s: %v", path, getTaskSchedulerError(err)) + return fmt.Errorf("error getting subfolders of folder %s: %w", path, getTaskSchedulerError(err)) } taskFolderList := res.ToIDispatch() defer taskFolderList.Release() @@ -446,7 +445,7 @@ func (t *TaskService) CreateTaskEx(path string, newTaskDef Definition, username, if !t.taskFolderExist(folderPath) { _, err = oleutil.CallMethod(t.rootFolderObj, "CreateFolder", folderPath, "") if err != nil { - return RegisteredTask{}, false, fmt.Errorf("error creating folder %s: %v", path, getTaskSchedulerError(err)) + return RegisteredTask{}, false, fmt.Errorf("error creating folder %s: %w", path, getTaskSchedulerError(err)) } } else { if t.registeredTaskExist(path) { @@ -460,19 +459,19 @@ func (t *TaskService) CreateTaskEx(path string, newTaskDef Definition, username, } _, err = oleutil.CallMethod(t.rootFolderObj, "DeleteTask", path, 0) if err != nil { - return RegisteredTask{}, false, fmt.Errorf("error deleting registered task %s: %v", path, getTaskSchedulerError(err)) + return RegisteredTask{}, false, fmt.Errorf("error deleting registered task %s: %w", path, getTaskSchedulerError(err)) } } } newTaskObj, err := t.modifyTask(path, newTaskDef, username, password, logonType, TASK_CREATE) if err != nil { - return RegisteredTask{}, false, fmt.Errorf("error creating registered task %s: %v", path, err) + return RegisteredTask{}, false, fmt.Errorf("error creating registered task %s: %w", path, err) } newTask, _, err := parseRegisteredTask(newTaskObj) if err != nil { - return RegisteredTask{}, false, fmt.Errorf("error parsing registered task %s: %v", path, err) + return RegisteredTask{}, false, fmt.Errorf("error parsing registered task %s: %w", path, err) } return newTask, true, nil @@ -495,13 +494,13 @@ func (t *TaskService) UpdateTaskEx(path string, newTaskDef Definition, username, newTaskObj, err := t.modifyTask(path, newTaskDef, username, password, logonType, TASK_UPDATE) if err != nil { - return RegisteredTask{}, fmt.Errorf("error updating %s task: %v", path, err) + return RegisteredTask{}, fmt.Errorf("error updating %s task: %w", path, err) } // update the internal database of registered tasks newTask, _, err := parseRegisteredTask(newTaskObj) if err != nil { - return RegisteredTask{}, fmt.Errorf("error parsing registered task %s: %v", path, err) + return RegisteredTask{}, fmt.Errorf("error parsing registered task %s: %w", path, err) } return newTask, nil @@ -515,19 +514,19 @@ func (t *TaskService) modifyTask(path string, newTaskDef Definition, username, p res, err := oleutil.CallMethod(t.taskServiceObj, "NewTask", 0) if err != nil { - return nil, fmt.Errorf("error creating new task: %v", getTaskSchedulerError(err)) + return nil, fmt.Errorf("error creating new task: %w", getTaskSchedulerError(err)) } newTaskDefObj := res.ToIDispatch() defer newTaskDefObj.Release() err = fillDefinitionObj(newTaskDef, newTaskDefObj) if err != nil { - return nil, fmt.Errorf("error filling ITaskDefinition: %v", err) + return nil, fmt.Errorf("error filling ITaskDefinition: %w", err) } newTaskObj, err := oleutil.CallMethod(t.rootFolderObj, "RegisterTaskDefinition", path, newTaskDefObj, int(flags), username, password, int(logonType), "") if err != nil { - return nil, fmt.Errorf("error registering task: %v", getTaskSchedulerError(err)) + return nil, fmt.Errorf("error registering task: %w", getTaskSchedulerError(err)) } return newTaskObj.ToIDispatch(), nil @@ -545,14 +544,14 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e taskFolder, err := oleutil.CallMethod(t.taskServiceObj, "GetFolder", path) if err != nil { - return false, fmt.Errorf("error getting folder: %v", getTaskSchedulerError(err)) + return false, fmt.Errorf("error getting folder: %w", getTaskSchedulerError(err)) } taskFolderObj := taskFolder.ToIDispatch() defer taskFolderObj.Release() res, err := oleutil.CallMethod(taskFolderObj, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return false, fmt.Errorf("error getting tasks of folder: %v", getTaskSchedulerError(err)) + return false, fmt.Errorf("error getting tasks of folder: %w", getTaskSchedulerError(err)) } taskCollection := res.ToIDispatch() defer taskCollection.Release() @@ -562,7 +561,7 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e res, err = oleutil.CallMethod(taskFolderObj, "GetFolders", int(TASK_ENUM_HIDDEN)) if err != nil { - return false, fmt.Errorf("error getting the subfolders: %v", getTaskSchedulerError(err)) + return false, fmt.Errorf("error getting the subfolders: %w", getTaskSchedulerError(err)) } folderCollection := res.ToIDispatch() defer folderCollection.Release() @@ -594,7 +593,7 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e res, err := oleutil.CallMethod(folderObj, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return fmt.Errorf("error getting tasks of folder: %v", getTaskSchedulerError(err)) + return fmt.Errorf("error getting tasks of folder: %w", getTaskSchedulerError(err)) } tasks := res.ToIDispatch() defer tasks.Release() @@ -606,7 +605,7 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e res, err = oleutil.CallMethod(folderObj, "GetFolders", int(TASK_ENUM_HIDDEN)) if err != nil { - return fmt.Errorf("error getting subfolders: %v", getTaskSchedulerError(err)) + return fmt.Errorf("error getting subfolders: %w", getTaskSchedulerError(err)) } subFolders := res.ToIDispatch() defer subFolders.Release() @@ -619,7 +618,7 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e currentFolderPath := oleutil.MustGetProperty(folderObj, "Path").ToString() _, err = oleutil.CallMethod(t.rootFolderObj, "DeleteFolder", currentFolderPath, 0) if err != nil { - return fmt.Errorf("error deleting task folder %s: %v", path, getTaskSchedulerError(err)) + return fmt.Errorf("error deleting task folder %s: %w", path, getTaskSchedulerError(err)) } return nil @@ -635,7 +634,7 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e // delete parent folder _, err = oleutil.CallMethod(t.rootFolderObj, "DeleteFolder", path, 0) if err != nil { - return false, fmt.Errorf("error deleting task folder %s: %v", path, getTaskSchedulerError(err)) + return false, fmt.Errorf("error deleting task folder %s: %w", path, getTaskSchedulerError(err)) } return true, nil @@ -651,7 +650,7 @@ func (t *TaskService) DeleteTask(path string) error { _, err = oleutil.CallMethod(t.rootFolderObj, "DeleteTask", path, 0) if err != nil { - return fmt.Errorf("error deleting task %s: %v", path, getTaskSchedulerError(err)) + return fmt.Errorf("error deleting task %s: %w", path, getTaskSchedulerError(err)) } return nil diff --git a/manage_test.go b/manage_test.go index bf6b912..c8ec949 100644 --- a/manage_test.go +++ b/manage_test.go @@ -1,28 +1,24 @@ +//go:build windows // +build windows package taskmaster import ( + "errors" "strings" "testing" "time" + + "github.com/rickb777/date/period" ) func TestLocalConnect(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } - taskService.Disconnect() + setupTaskService(t) } func TestCreateTask(t *testing.T) { var err error - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } - defer taskService.Disconnect() + taskService := setupTaskService(t) // test ExecAction execTaskDef := taskService.NewTaskDefinition() @@ -30,11 +26,22 @@ func TestCreateTask(t *testing.T) { Path: "calc.exe", } execTaskDef.AddAction(popCalc) + assertCalcAction := func(task RegisteredTask) { + requireActionCount(t, task, 1) + action := requireActionAt[ExecAction](t, task, 0) + if action.Path != popCalc.Path { + t.Fatalf("expected exec action path %s, got %s", popCalc.Path, action.Path) + } + } - _, _, err = taskService.CreateTask("\\Taskmaster\\ExecAction", execTaskDef, true) + _, _, err = taskService.CreateTask(testTaskPath("ExecAction"), execTaskDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("ExecAction"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 0) + }) // test ComHandlerAction comHandlerDef := taskService.NewTaskDefinition() @@ -42,33 +49,59 @@ func TestCreateTask(t *testing.T) { ClassID: "{F0001111-0000-0000-0000-0000FEEDACDC}", }) - _, _, err = taskService.CreateTask("\\Taskmaster\\ComHandlerAction", comHandlerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("ComHandlerAction"), comHandlerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("ComHandlerAction"), func(task RegisteredTask) { + requireActionCount(t, task, 1) + action := requireActionAt[ComHandlerAction](t, task, 0) + if action.ClassID != "{F0001111-0000-0000-0000-0000FEEDACDC}" { + t.Fatalf("unexpected class ID %s", action.ClassID) + } + requireTriggerCount(t, task, 0) + }) // test BootTrigger bootTriggerDef := taskService.NewTaskDefinition() bootTriggerDef.AddAction(popCalc) bootTriggerDef.AddTrigger(BootTrigger{}) - _, _, err = taskService.CreateTask("\\Taskmaster\\BootTrigger", bootTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("BootTrigger"), bootTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("BootTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + requireTriggerAt[BootTrigger](t, task, 0) + }) // test DailyTrigger dailyTriggerDef := taskService.NewTaskDefinition() dailyTriggerDef.AddAction(popCalc) + dailyRandomDelay := period.NewHMS(0, 15, 0) dailyTriggerDef.AddTrigger(DailyTrigger{ DayInterval: EveryDay, + RandomDelay: dailyRandomDelay, TaskTrigger: TaskTrigger{ StartBoundary: time.Now(), }, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\DailyTrigger", dailyTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("DailyTrigger"), dailyTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("DailyTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[DailyTrigger](t, task, 0) + if trigger.DayInterval != EveryDay { + t.Fatalf("expected DayInterval %v, got %v", EveryDay, trigger.DayInterval) + } + if trigger.RandomDelay.String() != dailyRandomDelay.String() { + t.Fatalf("expected random delay %s, got %s", dailyRandomDelay, trigger.RandomDelay) + } + }) // test EventTrigger eventTriggerDef := taskService.NewTaskDefinition() @@ -77,28 +110,46 @@ func TestCreateTask(t *testing.T) { eventTriggerDef.AddTrigger(EventTrigger{ Subscription: subscription, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\EventTrigger", eventTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("EventTrigger"), eventTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("EventTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[EventTrigger](t, task, 0) + if trigger.Subscription != subscription { + t.Fatalf("expected subscription %s, got %s", subscription, trigger.Subscription) + } + }) // test IdleTrigger idleTriggerDef := taskService.NewTaskDefinition() idleTriggerDef.AddAction(popCalc) idleTriggerDef.AddTrigger(IdleTrigger{}) - _, _, err = taskService.CreateTask("\\Taskmaster\\IdleTrigger", idleTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("IdleTrigger"), idleTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("IdleTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + requireTriggerAt[IdleTrigger](t, task, 0) + }) // test LogonTrigger logonTriggerDef := taskService.NewTaskDefinition() logonTriggerDef.AddAction(popCalc) logonTriggerDef.AddTrigger(LogonTrigger{}) - _, _, err = taskService.CreateTask("\\Taskmaster\\LogonTrigger", logonTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("LogonTrigger"), logonTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("LogonTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + requireTriggerAt[LogonTrigger](t, task, 0) + }) // test MonthlyDOWTrigger monthlyDOWTriggerDef := taskService.NewTaskDefinition() @@ -111,10 +162,18 @@ func TestCreateTask(t *testing.T) { StartBoundary: time.Now(), }, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\MonthlyDOWTrigger", monthlyDOWTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("MonthlyDOWTrigger"), monthlyDOWTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("MonthlyDOWTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[MonthlyDOWTrigger](t, task, 0) + if trigger.DaysOfWeek != Monday|Friday || trigger.MonthsOfYear != January|February || trigger.WeeksOfMonth != First { + t.Fatal("monthly DOW trigger values did not round-trip") + } + }) // test MonthlyTrigger monthlyTriggerDef := taskService.NewTaskDefinition() @@ -126,19 +185,32 @@ func TestCreateTask(t *testing.T) { StartBoundary: time.Now(), }, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\MonthlyTrigger", monthlyTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("MonthlyTrigger"), monthlyTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("MonthlyTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[MonthlyTrigger](t, task, 0) + if trigger.DaysOfMonth != 3 || trigger.MonthsOfYear != February|March { + t.Fatal("monthly trigger values did not round-trip") + } + }) // test RegistrationTrigger registrationTriggerDef := taskService.NewTaskDefinition() registrationTriggerDef.AddAction(popCalc) registrationTriggerDef.AddTrigger(RegistrationTrigger{}) - _, _, err = taskService.CreateTask("\\Taskmaster\\RegistrationTrigger", registrationTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("RegistrationTrigger"), registrationTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("RegistrationTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + requireTriggerAt[RegistrationTrigger](t, task, 0) + }) // test SessionStateChangeTrigger sessionStateChangeTriggerDef := taskService.NewTaskDefinition() @@ -146,23 +218,55 @@ func TestCreateTask(t *testing.T) { sessionStateChangeTriggerDef.AddTrigger(SessionStateChangeTrigger{ StateChange: TASK_SESSION_LOCK, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\SessionStateChangeTrigger", sessionStateChangeTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("SessionStateChangeTrigger"), sessionStateChangeTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("SessionStateChangeTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[SessionStateChangeTrigger](t, task, 0) + if trigger.StateChange != TASK_SESSION_LOCK { + t.Fatalf("expected session state change %d, got %d", TASK_SESSION_LOCK, trigger.StateChange) + } + }) // test TimeTrigger timeTriggerDef := taskService.NewTaskDefinition() timeTriggerDef.AddAction(popCalc) + repetitionInterval := period.NewHMS(0, 30, 0) + repetitionDuration := period.NewHMS(2, 0, 0) timeTriggerDef.AddTrigger(TimeTrigger{ TaskTrigger: TaskTrigger{ StartBoundary: time.Now(), + RepetitionPattern: RepetitionPattern{ + RepetitionInterval: repetitionInterval, + RepetitionDuration: repetitionDuration, + StopAtDurationEnd: true, + }, }, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\TimeTrigger", timeTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("TimeTrigger"), timeTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("TimeTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[TimeTrigger](t, task, 0) + if trigger.TaskTrigger.StartBoundary.IsZero() { + t.Fatal("expected time trigger to have a start boundary") + } + if trigger.TaskTrigger.RepetitionInterval.String() != repetitionInterval.String() { + t.Fatalf("expected repetition interval %s, got %s", repetitionInterval, trigger.TaskTrigger.RepetitionInterval) + } + if trigger.TaskTrigger.RepetitionDuration.String() != repetitionDuration.String() { + t.Fatalf("expected repetition duration %s, got %s", repetitionDuration, trigger.TaskTrigger.RepetitionDuration) + } + if !trigger.TaskTrigger.StopAtDurationEnd { + t.Fatal("expected StopAtDurationEnd to be true") + } + }) // test WeeklyTrigger weeklyTriggerDef := taskService.NewTaskDefinition() @@ -174,13 +278,21 @@ func TestCreateTask(t *testing.T) { StartBoundary: time.Now(), }, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\WeeklyTrigger", weeklyTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("WeeklyTrigger"), weeklyTriggerDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("WeeklyTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[WeeklyTrigger](t, task, 0) + if trigger.DaysOfWeek != Tuesday|Thursday || trigger.WeekInterval != EveryOtherWeek { + t.Fatal("weekly trigger values did not round-trip") + } + }) // test trying to create task where a task at the same path already exists and the 'overwrite' is set to false - _, taskCreated, err := taskService.CreateTask("\\Taskmaster\\TimeTrigger", timeTriggerDef, false) + _, taskCreated, err := taskService.CreateTask(testTaskPath("TimeTrigger"), timeTriggerDef, false) if err != nil { t.Fatal(err) } @@ -190,20 +302,16 @@ func TestCreateTask(t *testing.T) { } func TestUpdateTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() testTask.Definition.RegistrationInfo.Author = "Big Chungus" - _, err = taskService.UpdateTask("\\Taskmaster\\TestTask", testTask.Definition) + _, err := taskService.UpdateTask(testTaskPath("TestTask"), testTask.Definition) if err != nil { t.Fatal(err) } - testTask, err = taskService.GetRegisteredTask("\\Taskmaster\\TestTask") + testTask, err = taskService.GetRegisteredTask(testTaskPath("TestTask")) if err != nil { t.Fatal(err) } @@ -213,47 +321,96 @@ func TestUpdateTask(t *testing.T) { } func TestGetRegisteredTasks(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } - defer taskService.Disconnect() + taskService := setupTaskService(t) + createTestTask(taskService) rtc, err := taskService.GetRegisteredTasks() if err != nil { t.Fatal(err) } - rtc.Release() + + var found bool + for _, task := range rtc { + if task.Path == testTaskPath("TestTask") { + found = true + break + } + } + if !found { + t.Fatalf("expected to find %s in registered tasks", testTaskPath("TestTask")) + } } func TestGetTaskFolders(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) + taskService := setupTaskService(t) + + for _, leaf := range []struct { + folder []string + task string + }{ + {folder: []string{"Folders", "Alpha"}, task: "TaskOne"}, + {folder: []string{"Folders", "Beta"}, task: "TaskOne"}, + } { + def := taskService.NewTaskDefinition() + def.AddAction(ExecAction{Path: "calc.exe"}) + + pathParts := append([]string{}, leaf.folder...) + pathParts = append(pathParts, leaf.task) + + if _, _, err := taskService.CreateTask(testTaskPath(pathParts...), def, true); err != nil { + t.Fatalf("failed to seed task %v: %v", pathParts, err) + } } - defer taskService.Disconnect() tf, err := taskService.GetTaskFolders() if err != nil { t.Fatal(err) } - tf.Release() + defer tf.Release() + + var foundTestRoot bool + for _, folder := range tf.SubFolders { + if folder.Path != testTaskRoot { + continue + } + + foundTestRoot = true + queue := append([]*TaskFolder{}, folder.SubFolders...) + leafTasks := map[string]int{} + for len(queue) > 0 { + current := queue[0] + queue = queue[1:] + + if len(current.SubFolders) == 0 { + leafTasks[current.Path] = len(current.RegisteredTasks) + continue + } + + queue = append(queue, current.SubFolders...) + } + + if leafTasks[testTaskPath("Folders", "Alpha")] != 1 || leafTasks[testTaskPath("Folders", "Beta")] != 1 { + t.Fatalf("missing expected leaves or wrong task counts: %v", leafTasks) + } + + break + } + + if !foundTestRoot { + t.Fatalf("did not find %s in folder tree", testTaskRoot) + } } func TestDeleteTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) createTestTask(taskService) - defer taskService.Disconnect() - err = taskService.DeleteTask("\\Taskmaster\\TestTask") + err := taskService.DeleteTask(testTaskPath("TestTask")) if err != nil { t.Fatal(err) } - deletedTask, err := taskService.GetRegisteredTask("\\Taskmaster\\TestTask") + deletedTask, err := taskService.GetRegisteredTask(testTaskPath("TestTask")) if err == nil { t.Fatal("task shouldn't still exist") } @@ -261,15 +418,11 @@ func TestDeleteTask(t *testing.T) { } func TestDeleteFolder(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) createTestTask(taskService) - defer taskService.Disconnect() var folderDeleted bool - folderDeleted, err = taskService.DeleteFolder("\\Taskmaster", false) + folderDeleted, err := taskService.DeleteFolder(testTaskRoot, false) if err != nil { t.Fatal(err) } @@ -277,7 +430,7 @@ func TestDeleteFolder(t *testing.T) { t.Error("folder shouldn't have been deleted") } - folderDeleted, err = taskService.DeleteFolder("\\Taskmaster", true) + folderDeleted, err = taskService.DeleteFolder(testTaskRoot, true) if err != nil { t.Fatal(err) } @@ -289,7 +442,7 @@ func TestDeleteFolder(t *testing.T) { if err != nil { t.Fatal(err) } - taskmasterFolder, err := taskService.GetTaskFolder("\\Taskmaster") + taskmasterFolder, err := taskService.GetTaskFolder(testTaskRoot) if err == nil { t.Fatal("folder shouldn't exist") } @@ -297,8 +450,95 @@ func TestDeleteFolder(t *testing.T) { t.Error("folder struct should be defaultly constructed") } for _, task := range tasks { - if strings.Split(task.Path, "\\")[1] == "Taskmaster" { + if strings.Split(task.Path, "\\")[1] == testTaskFolderName { t.Error("task should've been deleted") } } } + +func TestConnectWithOptionsInvalidTarget(t *testing.T) { + _, err := ConnectWithOptions("invalid-taskmaster-host", "", "", "") + if err == nil { + t.Fatal("expected connection failure") + } + if !errors.Is(err, ErrConnectionFailure) { + t.Fatalf("expected ErrConnectionFailure, got %v", err) + } +} + +func TestPrincipalSettingsRoundTrip(t *testing.T) { + taskService := setupTaskService(t) + + connectedDomain := taskService.GetConnectedDomain() + connectedUser := taskService.GetConnectedUser() + interactiveUserID := connectedUser + if connectedDomain != "" { + interactiveUserID = connectedDomain + `\` + connectedUser + } + + testPrincipals := []struct { + name string + principal Principal + }{ + { + name: "Interactive", + principal: Principal{ + UserID: interactiveUserID, + LogonType: TASK_LOGON_INTERACTIVE_TOKEN, + RunLevel: TASK_RUNLEVEL_HIGHEST, + }, + }, + { + name: "System", + principal: Principal{ + UserID: "SYSTEM", + LogonType: TASK_LOGON_SERVICE_ACCOUNT, + RunLevel: TASK_RUNLEVEL_HIGHEST, + }, + }, + } + + for _, tt := range testPrincipals { + def := taskService.NewTaskDefinition() + def.Actions = nil + def.AddAction(ExecAction{Path: "calc.exe"}) + def.Principal = tt.principal + def.Settings.MultipleInstances = TASK_INSTANCES_QUEUE + def.Settings.StopIfGoingOnBatteries = false + + path := testTaskPath("Principal", tt.name) + if _, _, err := taskService.CreateTask(path, def, true); err != nil { + if strings.Contains(err.Error(), "Access is denied") { + if tt.name == "System" { + t.Logf("skipping system principal test due to insufficient privileges: %v", err) + continue + } + t.Skipf("skipping principal test for %s: %v", tt.name, err) + } + t.Fatalf("failed to create task for %s: %v", tt.name, err) + } + + withRegisteredTask(t, taskService, path, func(task RegisteredTask) { + got := task.Definition.Principal + if tt.name == "Interactive" { + if !strings.EqualFold(got.UserID, interactiveUserID) && !strings.EqualFold(got.UserID, connectedUser) { + t.Fatalf("principal %s: expected UserID %s or %s, got %s", tt.name, interactiveUserID, connectedUser, got.UserID) + } + } else if got.UserID != tt.principal.UserID { + t.Fatalf("principal %s: expected UserID %s, got %s", tt.name, tt.principal.UserID, got.UserID) + } + if got.LogonType != tt.principal.LogonType { + t.Fatalf("principal %s: expected LogonType %d, got %d", tt.name, tt.principal.LogonType, got.LogonType) + } + if got.RunLevel != tt.principal.RunLevel { + t.Fatalf("principal %s: expected RunLevel %d, got %d", tt.name, tt.principal.RunLevel, got.RunLevel) + } + if task.Definition.Settings.MultipleInstances != TASK_INSTANCES_QUEUE { + t.Fatalf("principal %s: expected MultipleInstances %d, got %d", tt.name, TASK_INSTANCES_QUEUE, task.Definition.Settings.MultipleInstances) + } + if task.Definition.Settings.StopIfGoingOnBatteries { + t.Fatalf("principal %s: expected StopIfGoingOnBatteries false", tt.name) + } + }) + } +} diff --git a/parse.go b/parse.go index e236ee8..4bbec6b 100644 --- a/parse.go +++ b/parse.go @@ -6,6 +6,7 @@ package taskmaster import ( "errors" "fmt" + "math" "time" ole "github.com/go-ole/go-ole" @@ -90,13 +91,13 @@ func parseRegisteredTask(task *ole.IDispatch) (RegisteredTask, string, error) { if err != nil { return RegisteredTask{}, "", err } - nextRunTime := nextRunTimeVar.Value().(time.Time) + nextRunTime := variantTimeOrZero(nextRunTimeVar) lastRunTimeVar, err := oleutil.GetProperty(task, "LastRunTime") if err != nil { return RegisteredTask{}, "", err } - lastRunTime := lastRunTimeVar.Value().(time.Time) + lastRunTime := variantTimeOrZero(lastRunTimeVar) lastTaskResultVar, err := oleutil.GetProperty(task, "LastTaskResult") if err != nil { @@ -614,3 +615,26 @@ func parseTaskTrigger(trigger *ole.IDispatch) (Trigger, error) { return nil, errors.New("unsupported ITrigger type") } } + +var oleAutomationEpoch = time.Date(1899, time.December, 30, 0, 0, 0, 0, time.UTC) + +func variantTimeOrZero(v *ole.VARIANT) time.Time { + if v == nil || v.VT != ole.VT_DATE { + return time.Time{} + } + + return oleDateToTime(math.Float64frombits(uint64(v.Val))) +} + +func oleDateToTime(value float64) time.Time { + if value == 0 || math.IsNaN(value) || math.IsInf(value, 0) { + return time.Time{} + } + + const day = 24 * time.Hour + days, frac := math.Modf(value) + dayDuration := time.Duration(int64(days)) * day + fracDuration := time.Duration(frac * float64(day)) + + return oleAutomationEpoch.Add(dayDuration + fracDuration) +} diff --git a/parse_test.go b/parse_test.go new file mode 100644 index 0000000..a52b78c --- /dev/null +++ b/parse_test.go @@ -0,0 +1,29 @@ +//go:build windows +// +build windows + +package taskmaster + +import ( + "math" + "testing" + "time" + + ole "github.com/go-ole/go-ole" +) + +func TestVariantTimeOrZero(t *testing.T) { + if got := variantTimeOrZero(nil); !got.IsZero() { + t.Fatalf("expected zero time for nil variant, got %v", got) + } + + if got := variantTimeOrZero(&ole.VARIANT{VT: ole.VT_I4, Val: 10}); !got.IsZero() { + t.Fatalf("expected zero time for non-date variant, got %v", got) + } + + vtDate := &ole.VARIANT{VT: ole.VT_DATE, Val: int64(math.Float64bits(2.5))} + got := variantTimeOrZero(vtDate) + expected := time.Date(1900, time.January, 1, 12, 0, 0, 0, time.UTC) + if !got.Equal(expected) { + t.Fatalf("expected %v, got %v", expected, got) + } +} diff --git a/tasks_test.go b/tasks_test.go index bbb3d03..b53448a 100644 --- a/tasks_test.go +++ b/tasks_test.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package taskmaster @@ -13,12 +14,8 @@ func TestRelease(t *testing.T) { } func TestRunRegisteredTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() runningTask, err := testTask.Run("3") if err != nil { @@ -28,12 +25,8 @@ func TestRunRegisteredTask(t *testing.T) { } func TestRefreshRunningTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() runningTask, err := testTask.Run("3") if err != nil { @@ -48,12 +41,8 @@ func TestRefreshRunningTask(t *testing.T) { } func TestStopRunningTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() runningTask, err := testTask.Run("9001") if err != nil { @@ -67,14 +56,11 @@ func TestStopRunningTask(t *testing.T) { } func TestGetInstancesRegisteredTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() runningTasks := make(RunningTaskCollection, 5, 5) + var err error // create a few running tasks so that there will be multiple instances // of the registered task running @@ -100,15 +86,12 @@ func TestGetInstancesRegisteredTask(t *testing.T) { } func TestStopRegisteredTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() + var err error for i := 0; i < 5; i++ { - _, err := testTask.Run("3") + _, err = testTask.Run("3") if err != nil { t.Fatal(err) } @@ -119,3 +102,40 @@ func TestStopRegisteredTask(t *testing.T) { t.Fatalf("error stopping tasks: %v", err) } } + +func TestGetRunningTasksServiceWide(t *testing.T) { + taskService := setupTaskService(t) + testTask := createTestTask(taskService) + + runningInstances := make([]RunningTask, 0, 3) + for i := 0; i < 3; i++ { + instance, err := testTask.Run("5") + if err != nil { + t.Fatalf("failed to run task instance %d: %v", i, err) + } + runningInstances = append(runningInstances, instance) + time.Sleep(100 * time.Millisecond) + } + + serviceRunningTasks, err := taskService.GetRunningTasks() + if err != nil { + t.Fatalf("failed to get running tasks: %v", err) + } + defer serviceRunningTasks.Release() + + var seen int + for _, runningTask := range serviceRunningTasks { + if runningTask.Path == testTask.Path { + seen++ + } + } + + if seen != len(runningInstances) { + t.Fatalf("expected %d running entries for %s, got %d", len(runningInstances), testTask.Path, seen) + } + + for _, runningTask := range runningInstances { + runningTask.Release() + } + _ = testTask.Stop() +} diff --git a/testing_utils.go b/testing_utils.go index 3f0961d..42f9d49 100644 --- a/testing_utils.go +++ b/testing_utils.go @@ -1,8 +1,60 @@ +//go:build windows // +build windows package taskmaster -func createTestTask(taskSvc TaskService) RegisteredTask { +import ( + "strings" + "testing" +) + +const ( + testTaskFolderName = "TaskmasterTests" + testTaskRoot = `\` + testTaskFolderName +) + +func setupTaskService(t *testing.T) *TaskService { + t.Helper() + + taskService, err := Connect() + if err != nil { + t.Fatalf("failed to connect to Task Scheduler: %v", err) + } + + resetTestFolder(t, &taskService) + + t.Cleanup(func() { + resetTestFolder(t, &taskService) + taskService.Disconnect() + }) + + return &taskService +} + +func resetTestFolder(t *testing.T, taskService *TaskService) { + t.Helper() + + if taskService.taskFolderExist(testTaskRoot) { + if _, err := taskService.DeleteFolder(testTaskRoot, true); err != nil { + t.Fatalf("failed to delete %s: %v", testTaskRoot, err) + } + } +} + +func testTaskPath(parts ...string) string { + if len(parts) == 0 { + return testTaskRoot + } + + cleaned := make([]string, 0, len(parts)) + for _, part := range parts { + cleaned = append(cleaned, strings.Trim(part, "\\")) + } + + return testTaskRoot + `\` + strings.Join(cleaned, `\`) +} + +func createTestTask(taskSvc *TaskService) RegisteredTask { newTaskDef := taskSvc.NewTaskDefinition() newTaskDef.AddAction(ExecAction{ Path: "cmd.exe", @@ -10,10 +62,70 @@ func createTestTask(taskSvc TaskService) RegisteredTask { }) newTaskDef.Settings.MultipleInstances = TASK_INSTANCES_PARALLEL - task, _, err := taskSvc.CreateTask("\\Taskmaster\\TestTask", newTaskDef, true) + task, _, err := taskSvc.CreateTask(testTaskPath("TestTask"), newTaskDef, true) if err != nil { panic(err) } return task } + +func withRegisteredTask(t *testing.T, taskSvc *TaskService, path string, fn func(RegisteredTask)) { + t.Helper() + + task, err := taskSvc.GetRegisteredTask(path) + if err != nil { + t.Fatalf("failed to get registered task %s: %v", path, err) + } + defer task.Release() + + fn(task) +} + +func requireActionCount(t *testing.T, task RegisteredTask, expected int) { + t.Helper() + + if len(task.Definition.Actions) != expected { + t.Fatalf("expected %d actions, got %d", expected, len(task.Definition.Actions)) + } +} + +func requireTriggerCount(t *testing.T, task RegisteredTask, expected int) { + t.Helper() + + if len(task.Definition.Triggers) != expected { + t.Fatalf("expected %d triggers, got %d", expected, len(task.Definition.Triggers)) + } +} + +func requireActionAt[T Action](t *testing.T, task RegisteredTask, idx int) T { + t.Helper() + + if idx >= len(task.Definition.Actions) { + t.Fatalf("expected action at index %d, only %d actions available", idx, len(task.Definition.Actions)) + } + + action, ok := task.Definition.Actions[idx].(T) + if !ok { + var zero T + t.Fatalf("expected action %T at index %d, got %T", zero, idx, task.Definition.Actions[idx]) + } + + return action +} + +func requireTriggerAt[T Trigger](t *testing.T, task RegisteredTask, idx int) T { + t.Helper() + + if idx >= len(task.Definition.Triggers) { + t.Fatalf("expected trigger at index %d, only %d triggers available", idx, len(task.Definition.Triggers)) + } + + trigger, ok := task.Definition.Triggers[idx].(T) + if !ok { + var zero T + t.Fatalf("expected trigger %T at index %d, got %T", zero, idx, task.Definition.Triggers[idx]) + } + + return trigger +}