diff --git a/client/src/main/scala/tmt/views/SubscriptionView.scala b/client/src/main/scala/tmt/views/SubscriptionView.scala index fbcd4e2..9126ae7 100644 --- a/client/src/main/scala/tmt/views/SubscriptionView.scala +++ b/client/src/main/scala/tmt/views/SubscriptionView.scala @@ -5,7 +5,7 @@ import rx._ import tmt.framework.Framework._ import tmt.framework.Helpers._ import tmt.shared.models.{Connection, ConnectionSet, RoleMappings} - +import monifu.concurrent.Implicits.globalScheduler import scalatags.JsDom.all._ class SubscriptionView(roleMappings: RoleMappings, connectionSet: ConnectionSet) { @@ -50,8 +50,9 @@ class SubscriptionView(roleMappings: RoleMappings, connectionSet: ConnectionSet) } def addConnection() = { - subscribe(connection()) - connections() = connections() + connection() + subscribe(connection()).onSuccess { + case _ => connections() = connections() + connection() + } } def removeConnection(connection: Connection) = { diff --git a/common/src/main/scala/tmt/library/Role.scala b/common/src/main/scala/tmt/library/Role.scala index 379fe4e..db8c5e0 100644 --- a/common/src/main/scala/tmt/library/Role.scala +++ b/common/src/main/scala/tmt/library/Role.scala @@ -2,7 +2,7 @@ package tmt.library import enumeratum.{Enum, EnumEntry} -sealed abstract class Role(maybeConsumes: Option[ItemType], maybeProduces: Option[ItemType]) extends EnumEntry +sealed abstract class Role(val maybeConsumes: Option[ItemType], val maybeProduces: Option[ItemType]) extends EnumEntry sealed abstract class SourceRole(override val entryName: String, val produces: ItemType) extends Role(None, Some(produces)) sealed abstract class SinkRole(override val entryName: String, val consumes: ItemType) extends Role(Some(consumes), None) sealed abstract class FlowRole(override val entryName: String, val consumes: ItemType, val produces: ItemType) extends Role(Some(consumes), Some(produces)) diff --git a/frontend/app/controllers/StreamController.scala b/frontend/app/controllers/StreamController.scala index bcf4808..df34866 100644 --- a/frontend/app/controllers/StreamController.scala +++ b/frontend/app/controllers/StreamController.scala @@ -4,7 +4,7 @@ import javax.inject.{Inject, Singleton} import common.AppSettings import play.api.mvc.{Action, Controller} -import services.{ClusterClientService, ConnectionSetService, RoleMappingsService} +import services.{ValidationFailedException, ClusterClientService, ConnectionSetService, RoleMappingsService} import templates.Page import upickle.default._ @@ -46,7 +46,10 @@ class StreamController @Inject()( def subscribe(serverName: String, topic: String) = Action { clusterClientService.subscribe(serverName, topic) - Accepted("ok") + .map(_ => Accepted("ok")) + .recover { + case ValidationFailedException(msg) => BadGateway(msg) + }.get } def unsubscribe(serverName: String, topic: String) = Action { diff --git a/frontend/app/services/ClusterClientService.scala b/frontend/app/services/ClusterClientService.scala index 68f46ee..3113d3a 100644 --- a/frontend/app/services/ClusterClientService.scala +++ b/frontend/app/services/ClusterClientService.scala @@ -9,13 +9,15 @@ import akka.cluster.pubsub.DistributedPubSubMediator.Publish import akka.pattern.ask import akka.util.Timeout import tmt.common.Messages +import tmt.library.Role import tmt.shared.Topics import tmt.shared.models.ConnectionSet import scala.concurrent.duration.{DurationInt, FiniteDuration} +import scala.util.{Failure, Success, Try} @Singleton -class ClusterClientService @Inject()(system: ActorSystem) { +class ClusterClientService @Inject()(system: ActorSystem, roleMappingsService: RoleMappingsService) { implicit val timeout = Timeout(2.seconds) @@ -26,8 +28,13 @@ class ClusterClientService @Inject()(system: ActorSystem) { mediator ! Publish(Topics.Throttle, Messages.UpdateDelay(serverName, delay)) } - def subscribe(serverName: String, topic: String) = { - mediator ! Publish(Topics.Subscription, Messages.Subscribe(serverName, topic)) + def subscribe(serverName: String, topic: String): Try[Unit] = { + validate(serverName, topic) match { + case Some(true) => + Success(mediator ! Publish(Topics.Subscription, Messages.Subscribe(serverName, topic))) + case _ => + Failure(new ValidationFailedException("Bad request")) + } } def unsubscribe(serverName: String, topic: String) = { @@ -35,4 +42,17 @@ class ClusterClientService @Inject()(system: ActorSystem) { } def allConnections = (connectionStore ? ConnectionStore.GetConnections).mapTo[ConnectionSet] + + private def validate(sourceServerName: String, destinationServerName: String) = { + val roleMappings = roleMappingsService.onlineRoleMappings + + for { + sourceRoleName <- roleMappings.roleOf(sourceServerName) + destinationRoleName <- roleMappings.roleOf(destinationServerName) + sourceRole = Role.withName(sourceRoleName) + destinationRole = Role.withName(destinationRoleName) + } yield sourceRole.maybeConsumes.isDefined && sourceRole.maybeConsumes == destinationRole.maybeProduces + } } + +case class ValidationFailedException(msg: String) extends RuntimeException(msg) diff --git a/shared/src/main/scala/tmt/shared/models/RoleMappings.scala b/shared/src/main/scala/tmt/shared/models/RoleMappings.scala index 684eef6..e79f40d 100644 --- a/shared/src/main/scala/tmt/shared/models/RoleMappings.scala +++ b/shared/src/main/scala/tmt/shared/models/RoleMappings.scala @@ -8,4 +8,10 @@ case class RoleMappings(mappings: Map[String, Seq[String]]) { role -> serverNames.filter(onlineRoles) } } + def roleOf(serverName: String) = { + mappings.find { mapping => + val servers = mapping._2 + servers.contains(serverName) + }.map(_._1) + } }