From eabd53670788880531d53c5eb32087d3cf85aafe Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 10:48:29 +0100 Subject: [PATCH 01/11] improve and automate tests --- .github/workflows/ci.yml | 16 +++++++ manage_test.go | 100 ++++++++++++++++----------------------- parse.go | 28 ++++++++++- tasks_test.go | 35 ++++---------- testing_utils.go | 56 +++++++++++++++++++++- 5 files changed, 147 insertions(+), 88 deletions(-) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0c15c83 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,16 @@ +name: CI + +on: + push: + pull_request: + +jobs: + test: + runs-on: windows-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + - name: Run tests + run: go test -p 1 ./... diff --git a/manage_test.go b/manage_test.go index bf6b912..25f1dd3 100644 --- a/manage_test.go +++ b/manage_test.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package taskmaster @@ -9,20 +10,12 @@ import ( ) func TestLocalConnect(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } - taskService.Disconnect() + setupTaskService(t) } func TestCreateTask(t *testing.T) { var err error - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } - defer taskService.Disconnect() + taskService := setupTaskService(t) // test ExecAction execTaskDef := taskService.NewTaskDefinition() @@ -31,7 +24,7 @@ func TestCreateTask(t *testing.T) { } execTaskDef.AddAction(popCalc) - _, _, err = taskService.CreateTask("\\Taskmaster\\ExecAction", execTaskDef, true) + _, _, err = taskService.CreateTask(testTaskPath("ExecAction"), execTaskDef, true) if err != nil { t.Fatal(err) } @@ -42,7 +35,7 @@ func TestCreateTask(t *testing.T) { ClassID: "{F0001111-0000-0000-0000-0000FEEDACDC}", }) - _, _, err = taskService.CreateTask("\\Taskmaster\\ComHandlerAction", comHandlerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("ComHandlerAction"), comHandlerDef, true) if err != nil { t.Fatal(err) } @@ -51,7 +44,7 @@ func TestCreateTask(t *testing.T) { bootTriggerDef := taskService.NewTaskDefinition() bootTriggerDef.AddAction(popCalc) bootTriggerDef.AddTrigger(BootTrigger{}) - _, _, err = taskService.CreateTask("\\Taskmaster\\BootTrigger", bootTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("BootTrigger"), bootTriggerDef, true) if err != nil { t.Fatal(err) } @@ -65,7 +58,7 @@ func TestCreateTask(t *testing.T) { StartBoundary: time.Now(), }, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\DailyTrigger", dailyTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("DailyTrigger"), dailyTriggerDef, true) if err != nil { t.Fatal(err) } @@ -77,7 +70,7 @@ func TestCreateTask(t *testing.T) { eventTriggerDef.AddTrigger(EventTrigger{ Subscription: subscription, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\EventTrigger", eventTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("EventTrigger"), eventTriggerDef, true) if err != nil { t.Fatal(err) } @@ -86,7 +79,7 @@ func TestCreateTask(t *testing.T) { idleTriggerDef := taskService.NewTaskDefinition() idleTriggerDef.AddAction(popCalc) idleTriggerDef.AddTrigger(IdleTrigger{}) - _, _, err = taskService.CreateTask("\\Taskmaster\\IdleTrigger", idleTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("IdleTrigger"), idleTriggerDef, true) if err != nil { t.Fatal(err) } @@ -95,7 +88,7 @@ func TestCreateTask(t *testing.T) { logonTriggerDef := taskService.NewTaskDefinition() logonTriggerDef.AddAction(popCalc) logonTriggerDef.AddTrigger(LogonTrigger{}) - _, _, err = taskService.CreateTask("\\Taskmaster\\LogonTrigger", logonTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("LogonTrigger"), logonTriggerDef, true) if err != nil { t.Fatal(err) } @@ -111,7 +104,7 @@ func TestCreateTask(t *testing.T) { StartBoundary: time.Now(), }, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\MonthlyDOWTrigger", monthlyDOWTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("MonthlyDOWTrigger"), monthlyDOWTriggerDef, true) if err != nil { t.Fatal(err) } @@ -126,7 +119,7 @@ func TestCreateTask(t *testing.T) { StartBoundary: time.Now(), }, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\MonthlyTrigger", monthlyTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("MonthlyTrigger"), monthlyTriggerDef, true) if err != nil { t.Fatal(err) } @@ -135,7 +128,7 @@ func TestCreateTask(t *testing.T) { registrationTriggerDef := taskService.NewTaskDefinition() registrationTriggerDef.AddAction(popCalc) registrationTriggerDef.AddTrigger(RegistrationTrigger{}) - _, _, err = taskService.CreateTask("\\Taskmaster\\RegistrationTrigger", registrationTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("RegistrationTrigger"), registrationTriggerDef, true) if err != nil { t.Fatal(err) } @@ -146,7 +139,7 @@ func TestCreateTask(t *testing.T) { sessionStateChangeTriggerDef.AddTrigger(SessionStateChangeTrigger{ StateChange: TASK_SESSION_LOCK, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\SessionStateChangeTrigger", sessionStateChangeTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("SessionStateChangeTrigger"), sessionStateChangeTriggerDef, true) if err != nil { t.Fatal(err) } @@ -159,7 +152,7 @@ func TestCreateTask(t *testing.T) { StartBoundary: time.Now(), }, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\TimeTrigger", timeTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("TimeTrigger"), timeTriggerDef, true) if err != nil { t.Fatal(err) } @@ -174,13 +167,13 @@ func TestCreateTask(t *testing.T) { StartBoundary: time.Now(), }, }) - _, _, err = taskService.CreateTask("\\Taskmaster\\WeeklyTrigger", weeklyTriggerDef, true) + _, _, err = taskService.CreateTask(testTaskPath("WeeklyTrigger"), weeklyTriggerDef, true) if err != nil { t.Fatal(err) } // test trying to create task where a task at the same path already exists and the 'overwrite' is set to false - _, taskCreated, err := taskService.CreateTask("\\Taskmaster\\TimeTrigger", timeTriggerDef, false) + _, taskCreated, err := taskService.CreateTask(testTaskPath("TimeTrigger"), timeTriggerDef, false) if err != nil { t.Fatal(err) } @@ -190,20 +183,16 @@ func TestCreateTask(t *testing.T) { } func TestUpdateTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() testTask.Definition.RegistrationInfo.Author = "Big Chungus" - _, err = taskService.UpdateTask("\\Taskmaster\\TestTask", testTask.Definition) + _, err := taskService.UpdateTask(testTaskPath("TestTask"), testTask.Definition) if err != nil { t.Fatal(err) } - testTask, err = taskService.GetRegisteredTask("\\Taskmaster\\TestTask") + testTask, err = taskService.GetRegisteredTask(testTaskPath("TestTask")) if err != nil { t.Fatal(err) } @@ -213,25 +202,28 @@ func TestUpdateTask(t *testing.T) { } func TestGetRegisteredTasks(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } - defer taskService.Disconnect() + taskService := setupTaskService(t) + createTestTask(taskService) rtc, err := taskService.GetRegisteredTasks() if err != nil { t.Fatal(err) } - rtc.Release() + + var found bool + for _, task := range rtc { + if task.Path == testTaskPath("TestTask") { + found = true + break + } + } + if !found { + t.Fatalf("expected to find %s in registered tasks", testTaskPath("TestTask")) + } } func TestGetTaskFolders(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } - defer taskService.Disconnect() + taskService := setupTaskService(t) tf, err := taskService.GetTaskFolders() if err != nil { @@ -241,19 +233,15 @@ func TestGetTaskFolders(t *testing.T) { } func TestDeleteTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) createTestTask(taskService) - defer taskService.Disconnect() - err = taskService.DeleteTask("\\Taskmaster\\TestTask") + err := taskService.DeleteTask(testTaskPath("TestTask")) if err != nil { t.Fatal(err) } - deletedTask, err := taskService.GetRegisteredTask("\\Taskmaster\\TestTask") + deletedTask, err := taskService.GetRegisteredTask(testTaskPath("TestTask")) if err == nil { t.Fatal("task shouldn't still exist") } @@ -261,15 +249,11 @@ func TestDeleteTask(t *testing.T) { } func TestDeleteFolder(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) createTestTask(taskService) - defer taskService.Disconnect() var folderDeleted bool - folderDeleted, err = taskService.DeleteFolder("\\Taskmaster", false) + folderDeleted, err := taskService.DeleteFolder(testTaskRoot, false) if err != nil { t.Fatal(err) } @@ -277,7 +261,7 @@ func TestDeleteFolder(t *testing.T) { t.Error("folder shouldn't have been deleted") } - folderDeleted, err = taskService.DeleteFolder("\\Taskmaster", true) + folderDeleted, err = taskService.DeleteFolder(testTaskRoot, true) if err != nil { t.Fatal(err) } @@ -289,7 +273,7 @@ func TestDeleteFolder(t *testing.T) { if err != nil { t.Fatal(err) } - taskmasterFolder, err := taskService.GetTaskFolder("\\Taskmaster") + taskmasterFolder, err := taskService.GetTaskFolder(testTaskRoot) if err == nil { t.Fatal("folder shouldn't exist") } @@ -297,7 +281,7 @@ func TestDeleteFolder(t *testing.T) { t.Error("folder struct should be defaultly constructed") } for _, task := range tasks { - if strings.Split(task.Path, "\\")[1] == "Taskmaster" { + if strings.Split(task.Path, "\\")[1] == testTaskFolderName { t.Error("task should've been deleted") } } diff --git a/parse.go b/parse.go index e236ee8..4bbec6b 100644 --- a/parse.go +++ b/parse.go @@ -6,6 +6,7 @@ package taskmaster import ( "errors" "fmt" + "math" "time" ole "github.com/go-ole/go-ole" @@ -90,13 +91,13 @@ func parseRegisteredTask(task *ole.IDispatch) (RegisteredTask, string, error) { if err != nil { return RegisteredTask{}, "", err } - nextRunTime := nextRunTimeVar.Value().(time.Time) + nextRunTime := variantTimeOrZero(nextRunTimeVar) lastRunTimeVar, err := oleutil.GetProperty(task, "LastRunTime") if err != nil { return RegisteredTask{}, "", err } - lastRunTime := lastRunTimeVar.Value().(time.Time) + lastRunTime := variantTimeOrZero(lastRunTimeVar) lastTaskResultVar, err := oleutil.GetProperty(task, "LastTaskResult") if err != nil { @@ -614,3 +615,26 @@ func parseTaskTrigger(trigger *ole.IDispatch) (Trigger, error) { return nil, errors.New("unsupported ITrigger type") } } + +var oleAutomationEpoch = time.Date(1899, time.December, 30, 0, 0, 0, 0, time.UTC) + +func variantTimeOrZero(v *ole.VARIANT) time.Time { + if v == nil || v.VT != ole.VT_DATE { + return time.Time{} + } + + return oleDateToTime(math.Float64frombits(uint64(v.Val))) +} + +func oleDateToTime(value float64) time.Time { + if value == 0 || math.IsNaN(value) || math.IsInf(value, 0) { + return time.Time{} + } + + const day = 24 * time.Hour + days, frac := math.Modf(value) + dayDuration := time.Duration(int64(days)) * day + fracDuration := time.Duration(frac * float64(day)) + + return oleAutomationEpoch.Add(dayDuration + fracDuration) +} diff --git a/tasks_test.go b/tasks_test.go index bbb3d03..a9918f8 100644 --- a/tasks_test.go +++ b/tasks_test.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package taskmaster @@ -13,12 +14,8 @@ func TestRelease(t *testing.T) { } func TestRunRegisteredTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() runningTask, err := testTask.Run("3") if err != nil { @@ -28,12 +25,8 @@ func TestRunRegisteredTask(t *testing.T) { } func TestRefreshRunningTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() runningTask, err := testTask.Run("3") if err != nil { @@ -48,12 +41,8 @@ func TestRefreshRunningTask(t *testing.T) { } func TestStopRunningTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() runningTask, err := testTask.Run("9001") if err != nil { @@ -67,14 +56,11 @@ func TestStopRunningTask(t *testing.T) { } func TestGetInstancesRegisteredTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() runningTasks := make(RunningTaskCollection, 5, 5) + var err error // create a few running tasks so that there will be multiple instances // of the registered task running @@ -100,15 +86,12 @@ func TestGetInstancesRegisteredTask(t *testing.T) { } func TestStopRegisteredTask(t *testing.T) { - taskService, err := Connect() - if err != nil { - t.Fatal(err) - } + taskService := setupTaskService(t) testTask := createTestTask(taskService) - defer taskService.Disconnect() + var err error for i := 0; i < 5; i++ { - _, err := testTask.Run("3") + _, err = testTask.Run("3") if err != nil { t.Fatal(err) } diff --git a/testing_utils.go b/testing_utils.go index 3f0961d..951bdc1 100644 --- a/testing_utils.go +++ b/testing_utils.go @@ -1,8 +1,60 @@ +//go:build windows // +build windows package taskmaster -func createTestTask(taskSvc TaskService) RegisteredTask { +import ( + "strings" + "testing" +) + +const ( + testTaskFolderName = "TaskmasterTests" + testTaskRoot = `\` + testTaskFolderName +) + +func setupTaskService(t *testing.T) *TaskService { + t.Helper() + + taskService, err := Connect() + if err != nil { + t.Fatalf("failed to connect to Task Scheduler: %v", err) + } + + resetTestFolder(t, &taskService) + + t.Cleanup(func() { + resetTestFolder(t, &taskService) + taskService.Disconnect() + }) + + return &taskService +} + +func resetTestFolder(t *testing.T, taskService *TaskService) { + t.Helper() + + if taskService.taskFolderExist(testTaskRoot) { + if _, err := taskService.DeleteFolder(testTaskRoot, true); err != nil { + t.Fatalf("failed to delete %s: %v", testTaskRoot, err) + } + } +} + +func testTaskPath(parts ...string) string { + if len(parts) == 0 { + return testTaskRoot + } + + cleaned := make([]string, 0, len(parts)) + for _, part := range parts { + cleaned = append(cleaned, strings.Trim(part, "\\")) + } + + return testTaskRoot + `\` + strings.Join(cleaned, `\`) +} + +func createTestTask(taskSvc *TaskService) RegisteredTask { newTaskDef := taskSvc.NewTaskDefinition() newTaskDef.AddAction(ExecAction{ Path: "cmd.exe", @@ -10,7 +62,7 @@ func createTestTask(taskSvc TaskService) RegisteredTask { }) newTaskDef.Settings.MultipleInstances = TASK_INSTANCES_PARALLEL - task, _, err := taskSvc.CreateTask("\\Taskmaster\\TestTask", newTaskDef, true) + task, _, err := taskSvc.CreateTask(testTaskPath("TestTask"), newTaskDef, true) if err != nil { panic(err) } From 73ba122009f81464bd0c21860daeffd9f8be6f7a Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 11:08:34 +0100 Subject: [PATCH 02/11] actually inspect created tasks --- manage_test.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++++ testing_utils.go | 60 ++++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+) diff --git a/manage_test.go b/manage_test.go index 25f1dd3..7f453c7 100644 --- a/manage_test.go +++ b/manage_test.go @@ -23,11 +23,22 @@ func TestCreateTask(t *testing.T) { Path: "calc.exe", } execTaskDef.AddAction(popCalc) + assertCalcAction := func(task RegisteredTask) { + requireActionCount(t, task, 1) + action := requireActionAt[ExecAction](t, task, 0) + if action.Path != popCalc.Path { + t.Fatalf("expected exec action path %s, got %s", popCalc.Path, action.Path) + } + requireTriggerCount(t, task, 0) + } _, _, err = taskService.CreateTask(testTaskPath("ExecAction"), execTaskDef, true) if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("ExecAction"), func(task RegisteredTask) { + assertCalcAction(task) + }) // test ComHandlerAction comHandlerDef := taskService.NewTaskDefinition() @@ -39,6 +50,14 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("ComHandlerAction"), func(task RegisteredTask) { + requireActionCount(t, task, 1) + action := requireActionAt[ComHandlerAction](t, task, 0) + if action.ClassID != "{F0001111-0000-0000-0000-0000FEEDACDC}" { + t.Fatalf("unexpected class ID %s", action.ClassID) + } + requireTriggerCount(t, task, 0) + }) // test BootTrigger bootTriggerDef := taskService.NewTaskDefinition() @@ -48,6 +67,11 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("BootTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + requireTriggerAt[BootTrigger](t, task, 0) + }) // test DailyTrigger dailyTriggerDef := taskService.NewTaskDefinition() @@ -62,6 +86,14 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("DailyTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[DailyTrigger](t, task, 0) + if trigger.DayInterval != EveryDay { + t.Fatalf("expected DayInterval %v, got %v", EveryDay, trigger.DayInterval) + } + }) // test EventTrigger eventTriggerDef := taskService.NewTaskDefinition() @@ -74,6 +106,14 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("EventTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[EventTrigger](t, task, 0) + if trigger.Subscription != subscription { + t.Fatalf("expected subscription %s, got %s", subscription, trigger.Subscription) + } + }) // test IdleTrigger idleTriggerDef := taskService.NewTaskDefinition() @@ -83,6 +123,11 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("IdleTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + requireTriggerAt[IdleTrigger](t, task, 0) + }) // test LogonTrigger logonTriggerDef := taskService.NewTaskDefinition() @@ -92,6 +137,11 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("LogonTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + requireTriggerAt[LogonTrigger](t, task, 0) + }) // test MonthlyDOWTrigger monthlyDOWTriggerDef := taskService.NewTaskDefinition() @@ -108,6 +158,14 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("MonthlyDOWTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[MonthlyDOWTrigger](t, task, 0) + if trigger.DaysOfWeek != Monday|Friday || trigger.MonthsOfYear != January|February || trigger.WeeksOfMonth != First { + t.Fatal("monthly DOW trigger values did not round-trip") + } + }) // test MonthlyTrigger monthlyTriggerDef := taskService.NewTaskDefinition() @@ -123,6 +181,14 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("MonthlyTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[MonthlyTrigger](t, task, 0) + if trigger.DaysOfMonth != 3 || trigger.MonthsOfYear != February|March { + t.Fatal("monthly trigger values did not round-trip") + } + }) // test RegistrationTrigger registrationTriggerDef := taskService.NewTaskDefinition() @@ -132,6 +198,11 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("RegistrationTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + requireTriggerAt[RegistrationTrigger](t, task, 0) + }) // test SessionStateChangeTrigger sessionStateChangeTriggerDef := taskService.NewTaskDefinition() @@ -143,6 +214,14 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("SessionStateChangeTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[SessionStateChangeTrigger](t, task, 0) + if trigger.StateChange != TASK_SESSION_LOCK { + t.Fatalf("expected session state change %d, got %d", TASK_SESSION_LOCK, trigger.StateChange) + } + }) // test TimeTrigger timeTriggerDef := taskService.NewTaskDefinition() @@ -156,6 +235,14 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("TimeTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[TimeTrigger](t, task, 0) + if trigger.TaskTrigger.StartBoundary.IsZero() { + t.Fatal("expected time trigger to have a start boundary") + } + }) // test WeeklyTrigger weeklyTriggerDef := taskService.NewTaskDefinition() @@ -171,6 +258,14 @@ func TestCreateTask(t *testing.T) { if err != nil { t.Fatal(err) } + withRegisteredTask(t, taskService, testTaskPath("WeeklyTrigger"), func(task RegisteredTask) { + assertCalcAction(task) + requireTriggerCount(t, task, 1) + trigger := requireTriggerAt[WeeklyTrigger](t, task, 0) + if trigger.DaysOfWeek != Tuesday|Thursday || trigger.WeekInterval != EveryOtherWeek { + t.Fatal("weekly trigger values did not round-trip") + } + }) // test trying to create task where a task at the same path already exists and the 'overwrite' is set to false _, taskCreated, err := taskService.CreateTask(testTaskPath("TimeTrigger"), timeTriggerDef, false) diff --git a/testing_utils.go b/testing_utils.go index 951bdc1..42f9d49 100644 --- a/testing_utils.go +++ b/testing_utils.go @@ -69,3 +69,63 @@ func createTestTask(taskSvc *TaskService) RegisteredTask { return task } + +func withRegisteredTask(t *testing.T, taskSvc *TaskService, path string, fn func(RegisteredTask)) { + t.Helper() + + task, err := taskSvc.GetRegisteredTask(path) + if err != nil { + t.Fatalf("failed to get registered task %s: %v", path, err) + } + defer task.Release() + + fn(task) +} + +func requireActionCount(t *testing.T, task RegisteredTask, expected int) { + t.Helper() + + if len(task.Definition.Actions) != expected { + t.Fatalf("expected %d actions, got %d", expected, len(task.Definition.Actions)) + } +} + +func requireTriggerCount(t *testing.T, task RegisteredTask, expected int) { + t.Helper() + + if len(task.Definition.Triggers) != expected { + t.Fatalf("expected %d triggers, got %d", expected, len(task.Definition.Triggers)) + } +} + +func requireActionAt[T Action](t *testing.T, task RegisteredTask, idx int) T { + t.Helper() + + if idx >= len(task.Definition.Actions) { + t.Fatalf("expected action at index %d, only %d actions available", idx, len(task.Definition.Actions)) + } + + action, ok := task.Definition.Actions[idx].(T) + if !ok { + var zero T + t.Fatalf("expected action %T at index %d, got %T", zero, idx, task.Definition.Actions[idx]) + } + + return action +} + +func requireTriggerAt[T Trigger](t *testing.T, task RegisteredTask, idx int) T { + t.Helper() + + if idx >= len(task.Definition.Triggers) { + t.Fatalf("expected trigger at index %d, only %d triggers available", idx, len(task.Definition.Triggers)) + } + + trigger, ok := task.Definition.Triggers[idx].(T) + if !ok { + var zero T + t.Fatalf("expected trigger %T at index %d, got %T", zero, idx, task.Definition.Triggers[idx]) + } + + return trigger +} From 14f128de282efb9257f7ed329377c44522e47f1e Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 11:19:47 +0100 Subject: [PATCH 03/11] test GetRunningTasks --- manage.go | 1 - tasks_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/manage.go b/manage.go index 68b8379..ea0a9f0 100644 --- a/manage.go +++ b/manage.go @@ -135,7 +135,6 @@ func (t *TaskService) GetRunningTasks() (RunningTaskCollection, error) { if err != nil { return nil, fmt.Errorf("error getting running tasks: %v", getTaskSchedulerError(err)) } - defer res.Clear() runningTasksObj := res.ToIDispatch() defer runningTasksObj.Release() diff --git a/tasks_test.go b/tasks_test.go index a9918f8..b53448a 100644 --- a/tasks_test.go +++ b/tasks_test.go @@ -102,3 +102,40 @@ func TestStopRegisteredTask(t *testing.T) { t.Fatalf("error stopping tasks: %v", err) } } + +func TestGetRunningTasksServiceWide(t *testing.T) { + taskService := setupTaskService(t) + testTask := createTestTask(taskService) + + runningInstances := make([]RunningTask, 0, 3) + for i := 0; i < 3; i++ { + instance, err := testTask.Run("5") + if err != nil { + t.Fatalf("failed to run task instance %d: %v", i, err) + } + runningInstances = append(runningInstances, instance) + time.Sleep(100 * time.Millisecond) + } + + serviceRunningTasks, err := taskService.GetRunningTasks() + if err != nil { + t.Fatalf("failed to get running tasks: %v", err) + } + defer serviceRunningTasks.Release() + + var seen int + for _, runningTask := range serviceRunningTasks { + if runningTask.Path == testTask.Path { + seen++ + } + } + + if seen != len(runningInstances) { + t.Fatalf("expected %d running entries for %s, got %d", len(runningInstances), testTask.Path, seen) + } + + for _, runningTask := range runningInstances { + runningTask.Release() + } + _ = testTask.Stop() +} From 05390016a27e3ac9d22b2ae92215dd575cbbf0fd Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 11:26:08 +0100 Subject: [PATCH 04/11] fix test --- manage_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manage_test.go b/manage_test.go index 7f453c7..9375838 100644 --- a/manage_test.go +++ b/manage_test.go @@ -29,7 +29,6 @@ func TestCreateTask(t *testing.T) { if action.Path != popCalc.Path { t.Fatalf("expected exec action path %s, got %s", popCalc.Path, action.Path) } - requireTriggerCount(t, task, 0) } _, _, err = taskService.CreateTask(testTaskPath("ExecAction"), execTaskDef, true) @@ -38,6 +37,7 @@ func TestCreateTask(t *testing.T) { } withRegisteredTask(t, taskService, testTaskPath("ExecAction"), func(task RegisteredTask) { assertCalcAction(task) + requireTriggerCount(t, task, 0) }) // test ComHandlerAction From 561ccfd82381a9876de6921396234fb527de32c5 Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 11:41:03 +0100 Subject: [PATCH 05/11] test GetTaskFolders --- manage_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/manage_test.go b/manage_test.go index 9375838..2f9245c 100644 --- a/manage_test.go +++ b/manage_test.go @@ -320,11 +320,61 @@ func TestGetRegisteredTasks(t *testing.T) { func TestGetTaskFolders(t *testing.T) { taskService := setupTaskService(t) + for _, leaf := range []struct { + folder []string + task string + }{ + {folder: []string{"Folders", "Alpha"}, task: "TaskOne"}, + {folder: []string{"Folders", "Beta"}, task: "TaskOne"}, + } { + def := taskService.NewTaskDefinition() + def.AddAction(ExecAction{Path: "calc.exe"}) + + pathParts := append([]string{}, leaf.folder...) + pathParts = append(pathParts, leaf.task) + + if _, _, err := taskService.CreateTask(testTaskPath(pathParts...), def, true); err != nil { + t.Fatalf("failed to seed task %v: %v", pathParts, err) + } + } + tf, err := taskService.GetTaskFolders() if err != nil { t.Fatal(err) } - tf.Release() + defer tf.Release() + + var foundTestRoot bool + for _, folder := range tf.SubFolders { + if folder.Path != testTaskRoot { + continue + } + + foundTestRoot = true + queue := append([]*TaskFolder{}, folder.SubFolders...) + leafTasks := map[string]int{} + for len(queue) > 0 { + current := queue[0] + queue = queue[1:] + + if len(current.SubFolders) == 0 { + leafTasks[current.Path] = len(current.RegisteredTasks) + continue + } + + queue = append(queue, current.SubFolders...) + } + + if leafTasks[testTaskPath("Folders", "Alpha")] != 1 || leafTasks[testTaskPath("Folders", "Beta")] != 1 { + t.Fatalf("missing expected leaves or wrong task counts: %v", leafTasks) + } + + break + } + + if !foundTestRoot { + t.Fatalf("did not find %s in folder tree", testTaskRoot) + } } func TestDeleteTask(t *testing.T) { From 072c85253820dbba0e2aa374fd5b31ddcea9a92f Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 11:43:52 +0100 Subject: [PATCH 06/11] test VariantTimeOrZero --- parse_test.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 parse_test.go diff --git a/parse_test.go b/parse_test.go new file mode 100644 index 0000000..a52b78c --- /dev/null +++ b/parse_test.go @@ -0,0 +1,29 @@ +//go:build windows +// +build windows + +package taskmaster + +import ( + "math" + "testing" + "time" + + ole "github.com/go-ole/go-ole" +) + +func TestVariantTimeOrZero(t *testing.T) { + if got := variantTimeOrZero(nil); !got.IsZero() { + t.Fatalf("expected zero time for nil variant, got %v", got) + } + + if got := variantTimeOrZero(&ole.VARIANT{VT: ole.VT_I4, Val: 10}); !got.IsZero() { + t.Fatalf("expected zero time for non-date variant, got %v", got) + } + + vtDate := &ole.VARIANT{VT: ole.VT_DATE, Val: int64(math.Float64bits(2.5))} + got := variantTimeOrZero(vtDate) + expected := time.Date(1900, time.January, 1, 12, 0, 0, 0, time.UTC) + if !got.Equal(expected) { + t.Fatalf("expected %v, got %v", expected, got) + } +} From c4d4302cffbf26d463bb7ef0c2cb95aa5cf47d02 Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 11:51:27 +0100 Subject: [PATCH 07/11] test ConnectWithOptions --- manage_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/manage_test.go b/manage_test.go index 2f9245c..cb5cc20 100644 --- a/manage_test.go +++ b/manage_test.go @@ -4,6 +4,7 @@ package taskmaster import ( + "errors" "strings" "testing" "time" @@ -431,3 +432,13 @@ func TestDeleteFolder(t *testing.T) { } } } + +func TestConnectWithOptionsInvalidTarget(t *testing.T) { + _, err := ConnectWithOptions("invalid-taskmaster-host", "", "", "") + if err == nil { + t.Fatal("expected connection failure") + } + if !errors.Is(err, ErrConnectionFailure) { + t.Fatalf("expected ErrConnectionFailure, got %v", err) + } +} From be6259197c968221a4d1eb5731f38e2d31255603 Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 13:28:36 +0100 Subject: [PATCH 08/11] test more properties --- manage_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/manage_test.go b/manage_test.go index cb5cc20..ac1690e 100644 --- a/manage_test.go +++ b/manage_test.go @@ -8,6 +8,8 @@ import ( "strings" "testing" "time" + + "github.com/rickb777/date/period" ) func TestLocalConnect(t *testing.T) { @@ -77,8 +79,10 @@ func TestCreateTask(t *testing.T) { // test DailyTrigger dailyTriggerDef := taskService.NewTaskDefinition() dailyTriggerDef.AddAction(popCalc) + dailyRandomDelay := period.NewHMS(0, 15, 0) dailyTriggerDef.AddTrigger(DailyTrigger{ DayInterval: EveryDay, + RandomDelay: dailyRandomDelay, TaskTrigger: TaskTrigger{ StartBoundary: time.Now(), }, @@ -94,6 +98,9 @@ func TestCreateTask(t *testing.T) { if trigger.DayInterval != EveryDay { t.Fatalf("expected DayInterval %v, got %v", EveryDay, trigger.DayInterval) } + if trigger.RandomDelay.String() != dailyRandomDelay.String() { + t.Fatalf("expected random delay %s, got %s", dailyRandomDelay, trigger.RandomDelay) + } }) // test EventTrigger @@ -227,9 +234,16 @@ func TestCreateTask(t *testing.T) { // test TimeTrigger timeTriggerDef := taskService.NewTaskDefinition() timeTriggerDef.AddAction(popCalc) + repetitionInterval := period.NewHMS(0, 30, 0) + repetitionDuration := period.NewHMS(2, 0, 0) timeTriggerDef.AddTrigger(TimeTrigger{ TaskTrigger: TaskTrigger{ StartBoundary: time.Now(), + RepetitionPattern: RepetitionPattern{ + RepetitionInterval: repetitionInterval, + RepetitionDuration: repetitionDuration, + StopAtDurationEnd: true, + }, }, }) _, _, err = taskService.CreateTask(testTaskPath("TimeTrigger"), timeTriggerDef, true) @@ -243,6 +257,15 @@ func TestCreateTask(t *testing.T) { if trigger.TaskTrigger.StartBoundary.IsZero() { t.Fatal("expected time trigger to have a start boundary") } + if trigger.TaskTrigger.RepetitionInterval.String() != repetitionInterval.String() { + t.Fatalf("expected repetition interval %s, got %s", repetitionInterval, trigger.TaskTrigger.RepetitionInterval) + } + if trigger.TaskTrigger.RepetitionDuration.String() != repetitionDuration.String() { + t.Fatalf("expected repetition duration %s, got %s", repetitionDuration, trigger.TaskTrigger.RepetitionDuration) + } + if !trigger.TaskTrigger.StopAtDurationEnd { + t.Fatal("expected StopAtDurationEnd to be true") + } }) // test WeeklyTrigger From 4fc1420a9228ee2b8394b060830ba3ce5fe8a0f8 Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 13:33:06 +0100 Subject: [PATCH 09/11] test with principals --- manage_test.go | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/manage_test.go b/manage_test.go index ac1690e..c18d689 100644 --- a/manage_test.go +++ b/manage_test.go @@ -465,3 +465,71 @@ func TestConnectWithOptionsInvalidTarget(t *testing.T) { t.Fatalf("expected ErrConnectionFailure, got %v", err) } } + +func TestPrincipalSettingsRoundTrip(t *testing.T) { + taskService := setupTaskService(t) + + interactiveUser := taskService.GetConnectedDomain() + `\` + taskService.GetConnectedUser() + + testPrincipals := []struct { + name string + principal Principal + }{ + { + name: "Interactive", + principal: Principal{ + UserID: interactiveUser, + LogonType: TASK_LOGON_INTERACTIVE_TOKEN, + RunLevel: TASK_RUNLEVEL_HIGHEST, + }, + }, + { + name: "System", + principal: Principal{ + UserID: "SYSTEM", + LogonType: TASK_LOGON_SERVICE_ACCOUNT, + RunLevel: TASK_RUNLEVEL_HIGHEST, + }, + }, + } + + for _, tt := range testPrincipals { + def := taskService.NewTaskDefinition() + def.Actions = nil + def.AddAction(ExecAction{Path: "calc.exe"}) + def.Principal = tt.principal + def.Settings.MultipleInstances = TASK_INSTANCES_QUEUE + def.Settings.StopIfGoingOnBatteries = false + + path := testTaskPath("Principal", tt.name) + if _, _, err := taskService.CreateTask(path, def, true); err != nil { + if strings.Contains(err.Error(), "Access is denied") { + if tt.name == "System" { + t.Logf("skipping system principal test due to insufficient privileges: %v", err) + continue + } + t.Skipf("skipping principal test for %s: %v", tt.name, err) + } + t.Fatalf("failed to create task for %s: %v", tt.name, err) + } + + withRegisteredTask(t, taskService, path, func(task RegisteredTask) { + got := task.Definition.Principal + if got.UserID != tt.principal.UserID { + t.Fatalf("principal %s: expected UserID %s, got %s", tt.name, tt.principal.UserID, got.UserID) + } + if got.LogonType != tt.principal.LogonType { + t.Fatalf("principal %s: expected LogonType %d, got %d", tt.name, tt.principal.LogonType, got.LogonType) + } + if got.RunLevel != tt.principal.RunLevel { + t.Fatalf("principal %s: expected RunLevel %d, got %d", tt.name, tt.principal.RunLevel, got.RunLevel) + } + if task.Definition.Settings.MultipleInstances != TASK_INSTANCES_QUEUE { + t.Fatalf("principal %s: expected MultipleInstances %d, got %d", tt.name, TASK_INSTANCES_QUEUE, task.Definition.Settings.MultipleInstances) + } + if task.Definition.Settings.StopIfGoingOnBatteries { + t.Fatalf("principal %s: expected StopIfGoingOnBatteries false", tt.name) + } + }) + } +} From 514793e67b30a7b38861ca0f382501118504c9c0 Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 14:04:33 +0100 Subject: [PATCH 10/11] wrap errors --- manage.go | 74 +++++++++++++++++++++++++++---------------------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/manage.go b/manage.go index ea0a9f0..f398192 100644 --- a/manage.go +++ b/manage.go @@ -72,13 +72,13 @@ func ConnectWithOptions(serverName, domain, username, password string) (TaskServ if !taskService.isInitialized { err = taskService.initialize() if err != nil { - return TaskService{}, fmt.Errorf("error initializing ITaskService object: %v", err) + return TaskService{}, fmt.Errorf("error initializing ITaskService object: %w", err) } } _, err = oleutil.CallMethod(taskService.taskServiceObj, "Connect", serverName, username, domain, password) if err != nil { - return TaskService{}, fmt.Errorf("error connecting to Task Scheduler service: %v", getTaskSchedulerError(err)) + return TaskService{}, fmt.Errorf("error connecting to Task Scheduler service: %w", getTaskSchedulerError(err)) } if serverName == "" { @@ -103,7 +103,7 @@ func ConnectWithOptions(serverName, domain, username, password string) (TaskServ res, err := oleutil.CallMethod(taskService.taskServiceObj, "GetFolder", `\`) if err != nil { - return TaskService{}, fmt.Errorf("error getting the root folder: %v", getTaskSchedulerError(err)) + return TaskService{}, fmt.Errorf("error getting the root folder: %w", getTaskSchedulerError(err)) } taskService.rootFolderObj = res.ToIDispatch() taskService.isConnected = true @@ -133,7 +133,7 @@ func (t *TaskService) GetRunningTasks() (RunningTaskCollection, error) { res, err := oleutil.CallMethod(t.taskServiceObj, "GetRunningTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return nil, fmt.Errorf("error getting running tasks: %v", getTaskSchedulerError(err)) + return nil, fmt.Errorf("error getting running tasks: %w", getTaskSchedulerError(err)) } runningTasksObj := res.ToIDispatch() @@ -143,7 +143,7 @@ func (t *TaskService) GetRunningTasks() (RunningTaskCollection, error) { runningTask, err := parseRunningTask(task) if err != nil { - return fmt.Errorf("error parsing running task: %v", err) + return fmt.Errorf("error parsing running task: %w", err) } runningTasks = append(runningTasks, runningTask) @@ -166,7 +166,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { // get tasks from root folder res, err := oleutil.CallMethod(t.rootFolderObj, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return nil, fmt.Errorf("error getting tasks of root folder: %v", getTaskSchedulerError(err)) + return nil, fmt.Errorf("error getting tasks of root folder: %w", getTaskSchedulerError(err)) } rootTaskCollection := res.ToIDispatch() defer rootTaskCollection.Release() @@ -176,7 +176,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { registeredTask, path, err := parseRegisteredTask(task) if err != nil { - return fmt.Errorf("error parsing registered task %s: %v", path, err) + return fmt.Errorf("error parsing registered task %s: %w", path, err) } registeredTasks = append(registeredTasks, registeredTask) @@ -188,7 +188,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { res, err = oleutil.CallMethod(t.rootFolderObj, "GetFolders", 0) if err != nil { - return nil, fmt.Errorf("error getting task folders of root folder: %v", getTaskSchedulerError(err)) + return nil, fmt.Errorf("error getting task folders of root folder: %w", getTaskSchedulerError(err)) } taskFolderList := res.ToIDispatch() defer taskFolderList.Release() @@ -201,7 +201,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { res, err := oleutil.CallMethod(taskFolder, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return fmt.Errorf("error getting tasks of folder: %v", getTaskSchedulerError(err)) + return fmt.Errorf("error getting tasks of folder: %w", getTaskSchedulerError(err)) } taskCollection := res.ToIDispatch() defer taskCollection.Release() @@ -211,7 +211,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { registeredTask, path, err := parseRegisteredTask(task) if err != nil { - return fmt.Errorf("error parsing registered task %s: %v", path, err) + return fmt.Errorf("error parsing registered task %s: %w", path, err) } registeredTasks = append(registeredTasks, registeredTask) @@ -223,7 +223,7 @@ func (t *TaskService) GetRegisteredTasks() (RegisteredTaskCollection, error) { res, err = oleutil.CallMethod(taskFolder, "GetFolders", 0) if err != nil { - return fmt.Errorf("error getting subfolders of folder: %v", getTaskSchedulerError(err)) + return fmt.Errorf("error getting subfolders of folder: %w", getTaskSchedulerError(err)) } taskFolderList := res.ToIDispatch() defer taskFolderList.Release() @@ -254,12 +254,12 @@ func (t *TaskService) GetRegisteredTask(path string) (RegisteredTask, error) { taskObj, err := oleutil.CallMethod(t.rootFolderObj, "GetTask", path) if err != nil { - return RegisteredTask{}, fmt.Errorf("error getting registered task %s: %v", path, getTaskSchedulerError(err)) + return RegisteredTask{}, fmt.Errorf("error getting registered task %s: %w", path, getTaskSchedulerError(err)) } task, _, err := parseRegisteredTask(taskObj.ToIDispatch()) if err != nil { - return RegisteredTask{}, fmt.Errorf("error parsing registered task %s: %v", path, err) + return RegisteredTask{}, fmt.Errorf("error parsing registered task %s: %w", path, err) } return task, nil @@ -285,7 +285,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { } else { topFolder, err := oleutil.CallMethod(t.taskServiceObj, "GetFolder", path) if err != nil { - return TaskFolder{}, fmt.Errorf("error getting folder %s: %v", path, getTaskSchedulerError(err)) + return TaskFolder{}, fmt.Errorf("error getting folder %s: %w", path, getTaskSchedulerError(err)) } topFolderObj = topFolder.ToIDispatch() defer topFolderObj.Release() @@ -294,7 +294,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { // get tasks from the top folder res, err := oleutil.CallMethod(topFolderObj, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return TaskFolder{}, fmt.Errorf("error getting tasks of folder %s: %v", path, getTaskSchedulerError(err)) + return TaskFolder{}, fmt.Errorf("error getting tasks of folder %s: %w", path, getTaskSchedulerError(err)) } topFolderTaskCollection := res.ToIDispatch() defer topFolderTaskCollection.Release() @@ -304,7 +304,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { registeredTask, path, err := parseRegisteredTask(task) if err != nil { - return fmt.Errorf("error parsing registered task %s: %v", path, err) + return fmt.Errorf("error parsing registered task %s: %w", path, err) } topFolder.RegisteredTasks = append(topFolder.RegisteredTasks, registeredTask) @@ -316,7 +316,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { res, err = oleutil.CallMethod(topFolderObj, "GetFolders", 0) if err != nil { - return TaskFolder{}, fmt.Errorf("error getting subfolders of folder %s: %v", path, getTaskSchedulerError(err)) + return TaskFolder{}, fmt.Errorf("error getting subfolders of folder %s: %w", path, getTaskSchedulerError(err)) } taskFolderList := res.ToIDispatch() defer taskFolderList.Release() @@ -333,7 +333,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { path := oleutil.MustGetProperty(taskFolder, "Path").ToString() res, err := oleutil.CallMethod(taskFolder, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return fmt.Errorf("error getting tasks of folder %s: %v", path, getTaskSchedulerError(err)) + return fmt.Errorf("error getting tasks of folder %s: %w", path, getTaskSchedulerError(err)) } taskCollection := res.ToIDispatch() defer taskCollection.Release() @@ -348,7 +348,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { registeredTask, path, err := parseRegisteredTask(task) if err != nil { - return fmt.Errorf("error parsing registered task %s: %v", path, err) + return fmt.Errorf("error parsing registered task %s: %w", path, err) } taskSubFolder.RegisteredTasks = append(taskSubFolder.RegisteredTasks, registeredTask) @@ -362,7 +362,7 @@ func (t TaskService) GetTaskFolder(path string) (TaskFolder, error) { res, err = oleutil.CallMethod(taskFolder, "GetFolders", 0) if err != nil { - return fmt.Errorf("error getting subfolders of folder %s: %v", path, getTaskSchedulerError(err)) + return fmt.Errorf("error getting subfolders of folder %s: %w", path, getTaskSchedulerError(err)) } taskFolderList := res.ToIDispatch() defer taskFolderList.Release() @@ -445,7 +445,7 @@ func (t *TaskService) CreateTaskEx(path string, newTaskDef Definition, username, if !t.taskFolderExist(folderPath) { _, err = oleutil.CallMethod(t.rootFolderObj, "CreateFolder", folderPath, "") if err != nil { - return RegisteredTask{}, false, fmt.Errorf("error creating folder %s: %v", path, getTaskSchedulerError(err)) + return RegisteredTask{}, false, fmt.Errorf("error creating folder %s: %w", path, getTaskSchedulerError(err)) } } else { if t.registeredTaskExist(path) { @@ -459,19 +459,19 @@ func (t *TaskService) CreateTaskEx(path string, newTaskDef Definition, username, } _, err = oleutil.CallMethod(t.rootFolderObj, "DeleteTask", path, 0) if err != nil { - return RegisteredTask{}, false, fmt.Errorf("error deleting registered task %s: %v", path, getTaskSchedulerError(err)) + return RegisteredTask{}, false, fmt.Errorf("error deleting registered task %s: %w", path, getTaskSchedulerError(err)) } } } newTaskObj, err := t.modifyTask(path, newTaskDef, username, password, logonType, TASK_CREATE) if err != nil { - return RegisteredTask{}, false, fmt.Errorf("error creating registered task %s: %v", path, err) + return RegisteredTask{}, false, fmt.Errorf("error creating registered task %s: %w", path, err) } newTask, _, err := parseRegisteredTask(newTaskObj) if err != nil { - return RegisteredTask{}, false, fmt.Errorf("error parsing registered task %s: %v", path, err) + return RegisteredTask{}, false, fmt.Errorf("error parsing registered task %s: %w", path, err) } return newTask, true, nil @@ -494,13 +494,13 @@ func (t *TaskService) UpdateTaskEx(path string, newTaskDef Definition, username, newTaskObj, err := t.modifyTask(path, newTaskDef, username, password, logonType, TASK_UPDATE) if err != nil { - return RegisteredTask{}, fmt.Errorf("error updating %s task: %v", path, err) + return RegisteredTask{}, fmt.Errorf("error updating %s task: %w", path, err) } // update the internal database of registered tasks newTask, _, err := parseRegisteredTask(newTaskObj) if err != nil { - return RegisteredTask{}, fmt.Errorf("error parsing registered task %s: %v", path, err) + return RegisteredTask{}, fmt.Errorf("error parsing registered task %s: %w", path, err) } return newTask, nil @@ -514,19 +514,19 @@ func (t *TaskService) modifyTask(path string, newTaskDef Definition, username, p res, err := oleutil.CallMethod(t.taskServiceObj, "NewTask", 0) if err != nil { - return nil, fmt.Errorf("error creating new task: %v", getTaskSchedulerError(err)) + return nil, fmt.Errorf("error creating new task: %w", getTaskSchedulerError(err)) } newTaskDefObj := res.ToIDispatch() defer newTaskDefObj.Release() err = fillDefinitionObj(newTaskDef, newTaskDefObj) if err != nil { - return nil, fmt.Errorf("error filling ITaskDefinition: %v", err) + return nil, fmt.Errorf("error filling ITaskDefinition: %w", err) } newTaskObj, err := oleutil.CallMethod(t.rootFolderObj, "RegisterTaskDefinition", path, newTaskDefObj, int(flags), username, password, int(logonType), "") if err != nil { - return nil, fmt.Errorf("error registering task: %v", getTaskSchedulerError(err)) + return nil, fmt.Errorf("error registering task: %w", getTaskSchedulerError(err)) } return newTaskObj.ToIDispatch(), nil @@ -544,14 +544,14 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e taskFolder, err := oleutil.CallMethod(t.taskServiceObj, "GetFolder", path) if err != nil { - return false, fmt.Errorf("error getting folder: %v", getTaskSchedulerError(err)) + return false, fmt.Errorf("error getting folder: %w", getTaskSchedulerError(err)) } taskFolderObj := taskFolder.ToIDispatch() defer taskFolderObj.Release() res, err := oleutil.CallMethod(taskFolderObj, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return false, fmt.Errorf("error getting tasks of folder: %v", getTaskSchedulerError(err)) + return false, fmt.Errorf("error getting tasks of folder: %w", getTaskSchedulerError(err)) } taskCollection := res.ToIDispatch() defer taskCollection.Release() @@ -561,7 +561,7 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e res, err = oleutil.CallMethod(taskFolderObj, "GetFolders", int(TASK_ENUM_HIDDEN)) if err != nil { - return false, fmt.Errorf("error getting the subfolders: %v", getTaskSchedulerError(err)) + return false, fmt.Errorf("error getting the subfolders: %w", getTaskSchedulerError(err)) } folderCollection := res.ToIDispatch() defer folderCollection.Release() @@ -593,7 +593,7 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e res, err := oleutil.CallMethod(folderObj, "GetTasks", int(TASK_ENUM_HIDDEN)) if err != nil { - return fmt.Errorf("error getting tasks of folder: %v", getTaskSchedulerError(err)) + return fmt.Errorf("error getting tasks of folder: %w", getTaskSchedulerError(err)) } tasks := res.ToIDispatch() defer tasks.Release() @@ -605,7 +605,7 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e res, err = oleutil.CallMethod(folderObj, "GetFolders", int(TASK_ENUM_HIDDEN)) if err != nil { - return fmt.Errorf("error getting subfolders: %v", getTaskSchedulerError(err)) + return fmt.Errorf("error getting subfolders: %w", getTaskSchedulerError(err)) } subFolders := res.ToIDispatch() defer subFolders.Release() @@ -618,7 +618,7 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e currentFolderPath := oleutil.MustGetProperty(folderObj, "Path").ToString() _, err = oleutil.CallMethod(t.rootFolderObj, "DeleteFolder", currentFolderPath, 0) if err != nil { - return fmt.Errorf("error deleting task folder %s: %v", path, getTaskSchedulerError(err)) + return fmt.Errorf("error deleting task folder %s: %w", path, getTaskSchedulerError(err)) } return nil @@ -634,7 +634,7 @@ func (t *TaskService) DeleteFolder(path string, deleteRecursively bool) (bool, e // delete parent folder _, err = oleutil.CallMethod(t.rootFolderObj, "DeleteFolder", path, 0) if err != nil { - return false, fmt.Errorf("error deleting task folder %s: %v", path, getTaskSchedulerError(err)) + return false, fmt.Errorf("error deleting task folder %s: %w", path, getTaskSchedulerError(err)) } return true, nil @@ -650,7 +650,7 @@ func (t *TaskService) DeleteTask(path string) error { _, err = oleutil.CallMethod(t.rootFolderObj, "DeleteTask", path, 0) if err != nil { - return fmt.Errorf("error deleting task %s: %v", path, getTaskSchedulerError(err)) + return fmt.Errorf("error deleting task %s: %w", path, getTaskSchedulerError(err)) } return nil From f8c106aaa0b178489eee56769cce50a6161c76bb Mon Sep 17 00:00:00 2001 From: fredrik Date: Tue, 11 Nov 2025 14:11:57 +0100 Subject: [PATCH 11/11] accept user without domain --- manage_test.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/manage_test.go b/manage_test.go index c18d689..c8ec949 100644 --- a/manage_test.go +++ b/manage_test.go @@ -469,7 +469,12 @@ func TestConnectWithOptionsInvalidTarget(t *testing.T) { func TestPrincipalSettingsRoundTrip(t *testing.T) { taskService := setupTaskService(t) - interactiveUser := taskService.GetConnectedDomain() + `\` + taskService.GetConnectedUser() + connectedDomain := taskService.GetConnectedDomain() + connectedUser := taskService.GetConnectedUser() + interactiveUserID := connectedUser + if connectedDomain != "" { + interactiveUserID = connectedDomain + `\` + connectedUser + } testPrincipals := []struct { name string @@ -478,7 +483,7 @@ func TestPrincipalSettingsRoundTrip(t *testing.T) { { name: "Interactive", principal: Principal{ - UserID: interactiveUser, + UserID: interactiveUserID, LogonType: TASK_LOGON_INTERACTIVE_TOKEN, RunLevel: TASK_RUNLEVEL_HIGHEST, }, @@ -515,7 +520,11 @@ func TestPrincipalSettingsRoundTrip(t *testing.T) { withRegisteredTask(t, taskService, path, func(task RegisteredTask) { got := task.Definition.Principal - if got.UserID != tt.principal.UserID { + if tt.name == "Interactive" { + if !strings.EqualFold(got.UserID, interactiveUserID) && !strings.EqualFold(got.UserID, connectedUser) { + t.Fatalf("principal %s: expected UserID %s or %s, got %s", tt.name, interactiveUserID, connectedUser, got.UserID) + } + } else if got.UserID != tt.principal.UserID { t.Fatalf("principal %s: expected UserID %s, got %s", tt.name, tt.principal.UserID, got.UserID) } if got.LogonType != tt.principal.LogonType {