diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 109104f0a5..3f912dc191 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -200,9 +200,13 @@ private[spark] object TestUtils { /** * Returns the response code from an HTTP(S) URL. */ - def httpResponseCode(url: URL, method: String = "GET"): Int = { + def httpResponseCode( + url: URL, + method: String = "GET", + headers: Seq[(String, String)] = Nil): Int = { val connection = url.openConnection().asInstanceOf[HttpURLConnection] connection.setRequestMethod(method) + headers.foreach { case (k, v) => connection.setRequestProperty(k, v) } // Disable cert and host name validation for HTTPS tests. if (connection.isInstanceOf[HttpsURLConnection]) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 17bc04303f..67ccf43afa 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -18,6 +18,7 @@ package org.apache.spark.status.api.v1 import java.util.zip.ZipOutputStream import javax.servlet.ServletContext +import javax.servlet.http.HttpServletRequest import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} @@ -40,7 +41,7 @@ import org.apache.spark.ui.SparkUI * HistoryServerSuite. */ @Path("/v1") -private[v1] class ApiRootResource extends UIRootFromServletContext { +private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications") def getApplicationList(): ApplicationListResource = { @@ -56,21 +57,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getJobs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllJobsResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllJobsResource(ui) } } @Path("applications/{appId}/jobs") def getJobs(@PathParam("appId") appId: String): AllJobsResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllJobsResource(ui) } } @Path("applications/{appId}/jobs/{jobId: \\d+}") def getJob(@PathParam("appId") appId: String): OneJobResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneJobResource(ui) } } @@ -79,21 +80,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getJob( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneJobResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneJobResource(ui) } } @Path("applications/{appId}/executors") def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new ExecutorListResource(ui) } } @Path("applications/{appId}/allexecutors") def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllExecutorListResource(ui) } } @@ -102,7 +103,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getExecutors( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): ExecutorListResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new ExecutorListResource(ui) } } @@ -111,15 +112,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getAllExecutors( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllExecutorListResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllExecutorListResource(ui) } } - @Path("applications/{appId}/stages") def getStages(@PathParam("appId") appId: String): AllStagesResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllStagesResource(ui) } } @@ -128,14 +128,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getStages( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllStagesResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllStagesResource(ui) } } @Path("applications/{appId}/stages/{stageId: \\d+}") def getStage(@PathParam("appId") appId: String): OneStageResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneStageResource(ui) } } @@ -144,14 +144,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getStage( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneStageResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneStageResource(ui) } } @Path("applications/{appId}/storage/rdd") def getRdds(@PathParam("appId") appId: String): AllRDDResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllRDDResource(ui) } } @@ -160,14 +160,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getRdds( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllRDDResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllRDDResource(ui) } } @Path("applications/{appId}/storage/rdd/{rddId: \\d+}") def getRdd(@PathParam("appId") appId: String): OneRDDResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneRDDResource(ui) } } @@ -176,7 +176,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getRdd( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneRDDResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneRDDResource(ui) } } @@ -234,19 +234,6 @@ private[spark] trait UIRoot { .status(Response.Status.SERVICE_UNAVAILABLE) .build() } - - /** - * Get the spark UI with the given appID, and apply a function - * to it. If there is no such app, throw an appropriate exception - */ - def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { - val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) - getSparkUI(appKey) match { - case Some(ui) => - f(ui) - case None => throw new NotFoundException("no such app: " + appId) - } - } def securityManager: SecurityManager } @@ -263,13 +250,38 @@ private[v1] object UIRootFromServletContext { } } -private[v1] trait UIRootFromServletContext { +private[v1] trait ApiRequestContext { @Context - var servletContext: ServletContext = _ + protected var servletContext: ServletContext = _ + + @Context + protected var httpRequest: HttpServletRequest = _ def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext) + + + /** + * Get the spark UI with the given appID, and apply a function + * to it. If there is no such app, throw an appropriate exception + */ + def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { + val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + uiRoot.getSparkUI(appKey) match { + case Some(ui) => + val user = httpRequest.getRemoteUser() + if (!ui.securityManager.checkUIViewPermissions(user)) { + throw new ForbiddenException(raw"""user "$user" is not authorized""") + } + f(ui) + case None => throw new NotFoundException("no such app: " + appId) + } + } + } +private[v1] class ForbiddenException(msg: String) extends WebApplicationException( + Response.status(Response.Status.FORBIDDEN).entity(msg).build()) + private[v1] class NotFoundException(msg: String) extends WebApplicationException( new NoSuchElementException(msg), Response diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala index b4a991eda3..1cd37185d6 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala @@ -21,14 +21,14 @@ import javax.ws.rs.core.Response import javax.ws.rs.ext.Provider @Provider -private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext { +private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext { override def filter(req: ContainerRequestContext): Unit = { - val user = Option(req.getSecurityContext.getUserPrincipal).map { _.getName }.orNull + val user = httpRequest.getRemoteUser() if (!uiRoot.securityManager.checkUIViewPermissions(user)) { req.abortWith( Response .status(Response.Status.FORBIDDEN) - .entity(raw"""user "$user"is not authorized""") + .entity(raw"""user "$user" is not authorized""") .build() ) } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 7909821db9..bdbdba5780 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -90,9 +90,9 @@ private[spark] object JettyUtils extends Logging { response.setHeader("X-Frame-Options", xFrameOptionsValue) response.getWriter.print(servletParams.extractFn(result)) } else { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setStatus(HttpServletResponse.SC_FORBIDDEN) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + response.sendError(HttpServletResponse.SC_FORBIDDEN, "User is not authorized to access this page.") } } catch { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index b2eded43ba..dcf83cb530 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -20,7 +20,8 @@ import java.io.{File, FileInputStream, FileWriter, InputStream, IOException} import java.net.{HttpURLConnection, URL} import java.nio.charset.StandardCharsets import java.util.zip.ZipInputStream -import javax.servlet.http.{HttpServletRequest, HttpServletResponse} +import javax.servlet._ +import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse} import scala.concurrent.duration._ import scala.language.postfixOps @@ -68,11 +69,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers private var server: HistoryServer = null private var port: Int = -1 - def init(): Unit = { + def init(extraConf: (String, String)*): Unit = { val conf = new SparkConf() .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") + conf.setAll(extraConf) provider = new FsHistoryProvider(conf) provider.checkForLogs() val securityManager = HistoryServer.createSecurityManager(conf) @@ -566,6 +568,39 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } + test("ui and api authorization checks") { + val appId = "app-20161115172038-0000" + val owner = "jose" + val admin = "root" + val other = "alice" + + stop() + init( + "spark.ui.filters" -> classOf[FakeAuthFilter].getName(), + "spark.history.ui.acls.enable" -> "true", + "spark.history.ui.admin.acls" -> admin) + + val tests = Seq( + (owner, HttpServletResponse.SC_OK), + (admin, HttpServletResponse.SC_OK), + (other, HttpServletResponse.SC_FORBIDDEN), + // When the remote user is null, the code behaves as if auth were disabled. + (null, HttpServletResponse.SC_OK)) + + val port = server.boundPort + val testUrls = Seq( + s"http://localhost:$port/api/v1/applications/$appId/jobs", + s"http://localhost:$port/history/$appId/jobs/") + + tests.foreach { case (user, expectedCode) => + testUrls.foreach { url => + val headers = if (user != null) Seq(FakeAuthFilter.FAKE_HTTP_USER -> user) else Nil + val sc = TestUtils.httpResponseCode(new URL(url), headers = headers) + assert(sc === expectedCode, s"Unexpected status code $sc for $url (user = $user)") + } + } + } + def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = { HistoryServerSuite.getContentAndCode(new URL(s"http://localhost:$port/api/v1/$path")) } @@ -648,3 +683,26 @@ object HistoryServerSuite { } } } + +/** + * A filter used for auth tests; sets the request's user to the value of the "HTTP_USER" header. + */ +class FakeAuthFilter extends Filter { + + override def destroy(): Unit = { } + + override def init(config: FilterConfig): Unit = { } + + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + val hreq = req.asInstanceOf[HttpServletRequest] + val wrapped = new HttpServletRequestWrapper(hreq) { + override def getRemoteUser(): String = hreq.getHeader(FakeAuthFilter.FAKE_HTTP_USER) + } + chain.doFilter(wrapped, res) + } + +} + +object FakeAuthFilter { + val FAKE_HTTP_USER = "HTTP_USER" +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9d359427f2..511686fb4f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,10 @@ object MimaExcludes { // Exclude rules for 2.2.x lazy val v22excludes = v21excludes ++ Seq( + // [SPARK-19652][UI] Do auth checks for REST API access. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.withSparkUI"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.status.api.v1.UIRootFromServletContext"), + // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"), diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala index e64830a945..aea75d5a9c 100644 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala @@ -19,14 +19,14 @@ package org.apache.spark.status.api.v1.streaming import javax.ws.rs.{Path, PathParam} -import org.apache.spark.status.api.v1.UIRootFromServletContext +import org.apache.spark.status.api.v1.ApiRequestContext @Path("/v1") -private[v1] class ApiStreamingApp extends UIRootFromServletContext { +private[v1] class ApiStreamingApp extends ApiRequestContext { @Path("applications/{appId}/streaming") def getStreamingRoot(@PathParam("appId") appId: String): ApiStreamingRootResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new ApiStreamingRootResource(ui) } } @@ -35,7 +35,7 @@ private[v1] class ApiStreamingApp extends UIRootFromServletContext { def getStreamingRoot( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): ApiStreamingRootResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new ApiStreamingRootResource(ui) } }