diff --git a/Swift/Sources/StateMachine/StateMachine.swift b/Swift/Sources/StateMachine/StateMachine.swift index eb2faca..598eb18 100644 --- a/Swift/Sources/StateMachine/StateMachine.swift +++ b/Swift/Sources/StateMachine/StateMachine.swift @@ -13,15 +13,13 @@ open class StateMachine Void + + private var onEnterActions: [State.HashableIdentifier: EnterExitAction] + private var onExitActions: [State.HashableIdentifier: EnterExitAction] + private var isNotifying: Bool = false public init(@DefinitionBuilder build: () -> Definition) { let definition: Definition = build() state = definition.initialState.state - states = definition.states.reduce(into: States()) { - $0[$1.state] = $1.events.reduce(into: Events()) { - $0[$1.event] = $1.action + var enterActions: [State.HashableIdentifier: EnterExitAction] = [:] + var exitActions: [State.HashableIdentifier: EnterExitAction] = [:] + states = definition.states.reduce(into: States()) { result, tuple in + let (state, events) = tuple + result[state] = events.reduce(into: Events()) { + switch $1.eventType { + case .onEnter(let action): + enterActions[state] = action + case .onExit(let action): + exitActions[state] = action + case .normal(let event, let action): + $0[event] = action + } } } + onEnterActions = enterActions + onExitActions = exitActions observers = definition.callbacks.map { Observer(object: self, callback: $0) } @@ -105,11 +120,19 @@ open class StateMachine Void) -> [EventHandler] { + [EventHandler(eventType: .onEnter(perform))] + } + + public static func onExit(_ perform: @escaping (State) throws -> Void) -> [EventHandler] { + [EventHandler(eventType: .onExit(perform))] + } + + public static func onEnter(_ perform: @escaping () throws -> Void) -> [EventHandler] { + [EventHandler(eventType: .onEnter({ _ in try perform() }))] + } + + public static func onExit(_ perform: @escaping () throws -> Void) -> [EventHandler] { + [EventHandler(eventType: .onExit({ _ in try perform() }))] + } + public static func on( _ event: Event.HashableIdentifier, perform: @escaping (State, Event) throws -> Action ) -> [EventHandler] { - [EventHandler(event: event, action: perform)] + [EventHandler(eventType: .normal(event, perform))] } public static func on( _ event: Event.HashableIdentifier, perform: @escaping (State) throws -> Action ) -> [EventHandler] { - [EventHandler(event: event) { state, _ in try perform(state) }] + [EventHandler(eventType: .normal(event, { state, _ in try perform(state) }))] } public static func on( _ event: Event.HashableIdentifier, perform: @escaping () throws -> Action ) -> [EventHandler] { - [EventHandler(event: event) { _, _ in try perform() }] + [EventHandler(eventType: .normal(event, { _, _ in try perform() }))] } public static func transition( to state: State, - emit sideEffect: SideEffect? = nil + emit sideEffect: SideEffect... ) -> Action { - Action(toState: state, sideEffect: sideEffect) + Action(toState: state, sideEffects: sideEffect) } public static func dontTransition( - emit sideEffect: SideEffect? = nil + emit sideEffect: SideEffect... ) -> Action { - Action(toState: nil, sideEffect: sideEffect) + Action(toState: nil, sideEffects: sideEffect) } public static func onTransition( @@ -279,8 +318,16 @@ public enum StateMachineTypes { public struct EventHandler { - fileprivate let event: Event.HashableIdentifier - fileprivate let action: Action.Factory + fileprivate var eventType: EventType + + fileprivate enum EventType { + + fileprivate typealias EnterExitAction = (State) throws -> Void + + case normal(Event.HashableIdentifier, Action.Factory) + case onEnter(EnterExitAction) + case onExit(EnterExitAction) + } } public struct Action { @@ -288,7 +335,7 @@ public enum StateMachineTypes { fileprivate typealias Factory = (State, Event) throws -> Self fileprivate let toState: State? - fileprivate let sideEffect: SideEffect? + fileprivate let sideEffects: [SideEffect] } public struct IncorrectTypeError: Error, CustomDebugStringConvertible { diff --git a/Swift/Tests/StateMachineTests/StateMachineTests.swift b/Swift/Tests/StateMachineTests/StateMachineTests.swift index 9813b34..5c3aade 100644 --- a/Swift/Tests/StateMachineTests/StateMachineTests.swift +++ b/Swift/Tests/StateMachineTests/StateMachineTests.swift @@ -11,17 +11,17 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder { enum State: StateMachineHashable { - case stateOne, stateTwo + case stateOne, stateTwo, stateThree } enum Event: StateMachineHashable { - case eventOne, eventTwo + case eventOne, eventTwo, eventThree } - enum SideEffect { + enum SideEffect: Equatable { - case commandOne, commandTwo, commandThree + case commandOne, commandTwo, commandThree, commandFour(Int) } typealias TestStateMachine = StateMachine @@ -33,7 +33,7 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder { initialState(_state) state(.stateOne) { on(.eventOne) { - dontTransition(emit: .commandOne) + dontTransition(emit: .commandOne, .commandTwo) } on(.eventTwo) { transition(to: .stateTwo, emit: .commandTwo) @@ -43,7 +43,11 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder { on(.eventTwo) { dontTransition(emit: .commandThree) } + on(.eventThree) { _, event in + transition(to: .stateThree, emit: .commandFour(try event.string())) + } } + state(.stateThree) } } @@ -66,7 +70,7 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .stateOne, event: .eventOne, toState: .stateOne, - sideEffect: .commandOne))) + sideEffects: [.commandOne, .commandTwo]))) } func testTransition() throws { @@ -82,7 +86,7 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .stateOne, event: .eventTwo, toState: .stateTwo, - sideEffect: .commandTwo))) + sideEffects: [.commandTwo]))) } func testInvalidTransition() throws { @@ -131,16 +135,16 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder { .success(ValidTransition(fromState: .stateOne, event: .eventOne, toState: .stateOne, - sideEffect: .commandOne)), + sideEffects: [.commandOne, .commandTwo])), .success(ValidTransition(fromState: .stateOne, event: .eventTwo, toState: .stateTwo, - sideEffect: .commandTwo)), + sideEffects: [.commandTwo])), .failure(InvalidTransition()), .success(ValidTransition(fromState: .stateTwo, event: .eventTwo, toState: .stateTwo, - sideEffect: .commandThree)) + sideEffects: [.commandThree])) ])) } @@ -191,6 +195,14 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder { // Then expect(error).to(equal(.recursionDetected)) } + + func testGettingNonExistingValue() throws { + // Given + let stateMachine: TestStateMachine = givenState(is: .stateTwo) + + // Then + XCTAssertThrowsError(try stateMachine.transition(.eventThree)) + } } final class Logger { @@ -212,3 +224,13 @@ func log(_ expectedMessages: String...) -> Predicate { return PredicateResult(bool: actualMessages == expectedMessages, message: message) } } + +func noLog() -> Predicate { + return Predicate { + let actualMessages: [String]? = try $0.evaluate()?.messages + let actualString: String = stringify(actualMessages?.joined(separator: "\\n")) + let message: ExpectationMessage = .expectedCustomValueTo("no logs", + actual: "<\(actualString)>") + return PredicateResult(bool: actualString.count == 0, message: message) + } +} diff --git a/Swift/Tests/StateMachineTests/StateMachine_Matter_Tests.swift b/Swift/Tests/StateMachineTests/StateMachine_Matter_Tests.swift index d2d926e..a64e0a2 100644 --- a/Swift/Tests/StateMachineTests/StateMachine_Matter_Tests.swift +++ b/Swift/Tests/StateMachineTests/StateMachine_Matter_Tests.swift @@ -28,23 +28,41 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder { typealias ValidTransition = MatterStateMachine.Transition.Valid typealias InvalidTransition = MatterStateMachine.Transition.Invalid - enum Message { - - static let melted: String = "I melted" - static let frozen: String = "I froze" - static let vaporized: String = "I vaporized" - static let condensed: String = "I condensed" + enum Message: String { + + case melted = "I melted" + case frozen = "I froze" + case vaporized = "I vaporized" + case condensed = "I condensed" + case enteredSolid + case exitedSolid + case enteredLiquid + case exitedLiquid + case enteredGas + case exitedGas } static func matterStateMachine(withInitialState _state: State, logger: Logger) -> MatterStateMachine { MatterStateMachine { initialState(_state) state(.solid) { + onEnter { _ in + logger.log(Message.enteredSolid.rawValue) + } + onExit { _ in + logger.log(Message.exitedSolid.rawValue) + } on(.melt) { transition(to: .liquid, emit: .logMelted) } } state(.liquid) { + onEnter { _ in + logger.log(Message.enteredLiquid.rawValue) + } + onExit { _ in + logger.log(Message.exitedLiquid.rawValue) + } on(.freeze) { transition(to: .solid, emit: .logFrozen) } @@ -53,17 +71,25 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder { } } state(.gas) { + onEnter { _ in + logger.log(Message.enteredGas.rawValue) + } + onExit { _ in + logger.log(Message.exitedGas.rawValue) + } on(.condense) { transition(to: .liquid, emit: .logCondensed) } } onTransition { - guard case let .success(transition) = $0, let sideEffect = transition.sideEffect else { return } - switch sideEffect { - case .logMelted: logger.log(Message.melted) - case .logFrozen: logger.log(Message.frozen) - case .logVaporized: logger.log(Message.vaporized) - case .logCondensed: logger.log(Message.condensed) + guard case let .success(transition) = $0 else { return } + transition.sideEffects.forEach { sideEffect in + switch sideEffect { + case .logMelted: logger.log(Message.melted.rawValue) + case .logFrozen: logger.log(Message.frozen.rawValue) + case .logVaporized: logger.log(Message.vaporized.rawValue) + case .logCondensed: logger.log(Message.condensed.rawValue) + } } } } @@ -100,8 +126,8 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .solid, event: .melt, toState: .liquid, - sideEffect: .logMelted))) - expect(self.logger).to(log(Message.melted)) + sideEffects: [.logMelted]))) + expect(self.logger).to(log(Message.exitedSolid.rawValue, Message.enteredLiquid.rawValue, Message.melted.rawValue)) } func test_givenStateIsSolid_whenFrozen_shouldThrowInvalidTransitionError() throws { @@ -133,8 +159,8 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .liquid, event: .freeze, toState: .solid, - sideEffect: .logFrozen))) - expect(self.logger).to(log(Message.frozen)) + sideEffects: [.logFrozen]))) + expect(self.logger).to(log(Message.exitedLiquid.rawValue, Message.enteredSolid.rawValue, Message.frozen.rawValue)) } func test_givenStateIsLiquid_whenVaporized_shouldTransitionToGasState() throws { @@ -150,8 +176,8 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .liquid, event: .vaporize, toState: .gas, - sideEffect: .logVaporized))) - expect(self.logger).to(log(Message.vaporized)) + sideEffects: [.logVaporized]))) + expect(self.logger).to(log(Message.exitedLiquid.rawValue, Message.enteredGas.rawValue, Message.vaporized.rawValue)) } func test_givenStateIsGas_whenCondensed_shouldTransitionToLiquidState() throws { @@ -167,7 +193,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .gas, event: .condense, toState: .liquid, - sideEffect: .logCondensed))) - expect(self.logger).to(log(Message.condensed)) + sideEffects: [.logCondensed]))) + expect(self.logger).to(log(Message.exitedGas.rawValue, Message.enteredLiquid.rawValue, Message.condensed.rawValue)) } } diff --git a/Swift/Tests/StateMachineTests/StateMachine_Turnstile_Tests.swift b/Swift/Tests/StateMachineTests/StateMachine_Turnstile_Tests.swift index 38fbf33..4b8f4a9 100644 --- a/Swift/Tests/StateMachineTests/StateMachine_Turnstile_Tests.swift +++ b/Swift/Tests/StateMachineTests/StateMachine_Turnstile_Tests.swift @@ -32,10 +32,25 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder { typealias TurnstileStateMachine = StateMachine typealias ValidTransition = TurnstileStateMachine.Transition.Valid + enum Message: String { + case enteredLocked + case exitedLocked + case enteredUnlocked + case exitedUnlocked + case enteredBroken + case exitedBroken + } + static func turnstileStateMachine(withInitialState _state: State, logger: Logger) -> TurnstileStateMachine { TurnstileStateMachine { initialState(_state) state(.locked) { + onEnter { state in + logger.log("\(Message.enteredLocked.rawValue) \(try state.credit() as Int)") + } + onExit { + logger.log(Message.exitedLocked.rawValue) + } on(.insertCoin) { locked, insertCoin in let newCredit: Int = try locked.credit() + insertCoin.value() if newCredit >= Constant.farePrice { @@ -52,11 +67,23 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder { } } state(.unlocked) { + onEnter { + logger.log(Message.enteredUnlocked.rawValue) + } + onExit { + logger.log(Message.exitedUnlocked.rawValue) + } on(.admitPerson) { transition(to: .locked(credit: 0), emit: .closeDoors) } } state(.broken) { + onEnter { + logger.log(Message.enteredBroken.rawValue) + } + onExit { + logger.log(Message.exitedBroken.rawValue) + } on(.machineRepairDidComplete) { broken in transition(to: try broken.oldState()) } @@ -95,7 +122,8 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .locked(credit: 0), event: .insertCoin(10), toState: .locked(credit: 10), - sideEffect: nil))) + sideEffects: []))) + expect(self.logger).to(log(Message.exitedLocked.rawValue, "\(Message.enteredLocked.rawValue) 10")) } func test_givenStateIsLocked_whenInsertCoin_andCreditEqualsFarePrice_shouldTransitionToUnlockedStateAndOpenDoors() throws { @@ -111,7 +139,8 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .locked(credit: 35), event: .insertCoin(15), toState: .unlocked, - sideEffect: .openDoors))) + sideEffects: [.openDoors]))) + expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredUnlocked.rawValue)) } func test_givenStateIsLocked_whenInsertCoin_andCreditMoreThanFarePrice_shouldTransitionToUnlockedStateAndOpenDoors() throws { @@ -127,7 +156,8 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .locked(credit: 35), event: .insertCoin(20), toState: .unlocked, - sideEffect: .openDoors))) + sideEffects: [.openDoors]))) + expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredUnlocked.rawValue)) } func test_givenStateIsLocked_whenAdmitPerson_shouldTransitionToLockedStateAndSoundAlarm() throws { @@ -143,7 +173,8 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .locked(credit: 35), event: .admitPerson, toState: .locked(credit: 35), - sideEffect: .soundAlarm))) + sideEffects: [.soundAlarm]))) + expect(self.logger).to(noLog()) } func test_givenStateIsLocked_whenMachineDidFail_shouldTransitionToBrokenStateAndOrderRepair() throws { @@ -159,7 +190,8 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .locked(credit: 15), event: .machineDidFail, toState: .broken(oldState: .locked(credit: 15)), - sideEffect: .orderRepair))) + sideEffects: [.orderRepair]))) + expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredBroken.rawValue)) } func test_givenStateIsUnlocked_whenAdmitPerson_shouldTransitionToLockedStateAndCloseDoors() throws { @@ -175,7 +207,8 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .unlocked, event: .admitPerson, toState: .locked(credit: 0), - sideEffect: .closeDoors))) + sideEffects: [.closeDoors]))) + expect(self.logger).to(log(Message.exitedUnlocked.rawValue, "\(Message.enteredLocked.rawValue) 0")) } func test_givenStateIsBroken_whenMachineRepairDidComplete_shouldTransitionToLockedState() throws { @@ -191,7 +224,8 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder { expect(transition).to(equal(ValidTransition(fromState: .broken(oldState: .locked(credit: 15)), event: .machineRepairDidComplete, toState: .locked(credit: 15), - sideEffect: nil))) + sideEffects: []))) + expect(self.logger).to(log(Message.exitedBroken.rawValue, "\(Message.enteredLocked.rawValue) 15")) } }