[SPARK-19652][UI] Do auth checks for REST API access.

The REST API has a security filter that performs auth checks
based on the UI root's security manager. That works fine when
the UI root is the app's UI, but not when it's the history server.

In the SHS case, all users would be allowed to see all applications
through the REST API, even if the UI itself wouldn't be available
to them.

This change adds auth checks for each app access through the API
too, so that only authorized users can see the app's data.

The change also modifies the existing security filter to use
`HttpServletRequest.getRemoteUser()`, which is used in other
places. That is not necessarily the same as the principal's
name; for example, when using Hadoop's SPNEGO auth filter,
the remote user strips the realm information, which then matches
the user name registered as the owner of the application.

I also renamed the UIRootFromServletContext trait to a more generic
name since I'm using it to store more context information now.

Tested manually with an authentication filter enabled.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #16978 from vanzin/SPARK-19652.
This commit is contained in:
Marcelo Vanzin 2017-02-21 16:14:34 -08:00
parent 7363dde634
commit 17d83e1ee5
7 changed files with 123 additions and 45 deletions

View file

@ -200,9 +200,13 @@ private[spark] object TestUtils {
/** /**
* Returns the response code from an HTTP(S) URL. * Returns the response code from an HTTP(S) URL.
*/ */
def httpResponseCode(url: URL, method: String = "GET"): Int = { def httpResponseCode(
url: URL,
method: String = "GET",
headers: Seq[(String, String)] = Nil): Int = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection] val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method) connection.setRequestMethod(method)
headers.foreach { case (k, v) => connection.setRequestProperty(k, v) }
// Disable cert and host name validation for HTTPS tests. // Disable cert and host name validation for HTTPS tests.
if (connection.isInstanceOf[HttpsURLConnection]) { if (connection.isInstanceOf[HttpsURLConnection]) {

View file

@ -18,6 +18,7 @@ package org.apache.spark.status.api.v1
import java.util.zip.ZipOutputStream import java.util.zip.ZipOutputStream
import javax.servlet.ServletContext import javax.servlet.ServletContext
import javax.servlet.http.HttpServletRequest
import javax.ws.rs._ import javax.ws.rs._
import javax.ws.rs.core.{Context, Response} import javax.ws.rs.core.{Context, Response}
@ -40,7 +41,7 @@ import org.apache.spark.ui.SparkUI
* HistoryServerSuite. * HistoryServerSuite.
*/ */
@Path("/v1") @Path("/v1")
private[v1] class ApiRootResource extends UIRootFromServletContext { private[v1] class ApiRootResource extends ApiRequestContext {
@Path("applications") @Path("applications")
def getApplicationList(): ApplicationListResource = { def getApplicationList(): ApplicationListResource = {
@ -56,21 +57,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getJobs( def getJobs(
@PathParam("appId") appId: String, @PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllJobsResource = { @PathParam("attemptId") attemptId: String): AllJobsResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui => withSparkUI(appId, Some(attemptId)) { ui =>
new AllJobsResource(ui) new AllJobsResource(ui)
} }
} }
@Path("applications/{appId}/jobs") @Path("applications/{appId}/jobs")
def getJobs(@PathParam("appId") appId: String): AllJobsResource = { def getJobs(@PathParam("appId") appId: String): AllJobsResource = {
uiRoot.withSparkUI(appId, None) { ui => withSparkUI(appId, None) { ui =>
new AllJobsResource(ui) new AllJobsResource(ui)
} }
} }
@Path("applications/{appId}/jobs/{jobId: \\d+}") @Path("applications/{appId}/jobs/{jobId: \\d+}")
def getJob(@PathParam("appId") appId: String): OneJobResource = { def getJob(@PathParam("appId") appId: String): OneJobResource = {
uiRoot.withSparkUI(appId, None) { ui => withSparkUI(appId, None) { ui =>
new OneJobResource(ui) new OneJobResource(ui)
} }
} }
@ -79,21 +80,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getJob( def getJob(
@PathParam("appId") appId: String, @PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneJobResource = { @PathParam("attemptId") attemptId: String): OneJobResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui => withSparkUI(appId, Some(attemptId)) { ui =>
new OneJobResource(ui) new OneJobResource(ui)
} }
} }
@Path("applications/{appId}/executors") @Path("applications/{appId}/executors")
def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = { def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = {
uiRoot.withSparkUI(appId, None) { ui => withSparkUI(appId, None) { ui =>
new ExecutorListResource(ui) new ExecutorListResource(ui)
} }
} }
@Path("applications/{appId}/allexecutors") @Path("applications/{appId}/allexecutors")
def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = { def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = {
uiRoot.withSparkUI(appId, None) { ui => withSparkUI(appId, None) { ui =>
new AllExecutorListResource(ui) new AllExecutorListResource(ui)
} }
} }
@ -102,7 +103,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getExecutors( def getExecutors(
@PathParam("appId") appId: String, @PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): ExecutorListResource = { @PathParam("attemptId") attemptId: String): ExecutorListResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui => withSparkUI(appId, Some(attemptId)) { ui =>
new ExecutorListResource(ui) new ExecutorListResource(ui)
} }
} }
@ -111,15 +112,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getAllExecutors( def getAllExecutors(
@PathParam("appId") appId: String, @PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllExecutorListResource = { @PathParam("attemptId") attemptId: String): AllExecutorListResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui => withSparkUI(appId, Some(attemptId)) { ui =>
new AllExecutorListResource(ui) new AllExecutorListResource(ui)
} }
} }
@Path("applications/{appId}/stages") @Path("applications/{appId}/stages")
def getStages(@PathParam("appId") appId: String): AllStagesResource = { def getStages(@PathParam("appId") appId: String): AllStagesResource = {
uiRoot.withSparkUI(appId, None) { ui => withSparkUI(appId, None) { ui =>
new AllStagesResource(ui) new AllStagesResource(ui)
} }
} }
@ -128,14 +128,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getStages( def getStages(
@PathParam("appId") appId: String, @PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllStagesResource = { @PathParam("attemptId") attemptId: String): AllStagesResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui => withSparkUI(appId, Some(attemptId)) { ui =>
new AllStagesResource(ui) new AllStagesResource(ui)
} }
} }
@Path("applications/{appId}/stages/{stageId: \\d+}") @Path("applications/{appId}/stages/{stageId: \\d+}")
def getStage(@PathParam("appId") appId: String): OneStageResource = { def getStage(@PathParam("appId") appId: String): OneStageResource = {
uiRoot.withSparkUI(appId, None) { ui => withSparkUI(appId, None) { ui =>
new OneStageResource(ui) new OneStageResource(ui)
} }
} }
@ -144,14 +144,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getStage( def getStage(
@PathParam("appId") appId: String, @PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneStageResource = { @PathParam("attemptId") attemptId: String): OneStageResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui => withSparkUI(appId, Some(attemptId)) { ui =>
new OneStageResource(ui) new OneStageResource(ui)
} }
} }
@Path("applications/{appId}/storage/rdd") @Path("applications/{appId}/storage/rdd")
def getRdds(@PathParam("appId") appId: String): AllRDDResource = { def getRdds(@PathParam("appId") appId: String): AllRDDResource = {
uiRoot.withSparkUI(appId, None) { ui => withSparkUI(appId, None) { ui =>
new AllRDDResource(ui) new AllRDDResource(ui)
} }
} }
@ -160,14 +160,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getRdds( def getRdds(
@PathParam("appId") appId: String, @PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllRDDResource = { @PathParam("attemptId") attemptId: String): AllRDDResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui => withSparkUI(appId, Some(attemptId)) { ui =>
new AllRDDResource(ui) new AllRDDResource(ui)
} }
} }
@Path("applications/{appId}/storage/rdd/{rddId: \\d+}") @Path("applications/{appId}/storage/rdd/{rddId: \\d+}")
def getRdd(@PathParam("appId") appId: String): OneRDDResource = { def getRdd(@PathParam("appId") appId: String): OneRDDResource = {
uiRoot.withSparkUI(appId, None) { ui => withSparkUI(appId, None) { ui =>
new OneRDDResource(ui) new OneRDDResource(ui)
} }
} }
@ -176,7 +176,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getRdd( def getRdd(
@PathParam("appId") appId: String, @PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneRDDResource = { @PathParam("attemptId") attemptId: String): OneRDDResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui => withSparkUI(appId, Some(attemptId)) { ui =>
new OneRDDResource(ui) new OneRDDResource(ui)
} }
} }
@ -234,19 +234,6 @@ private[spark] trait UIRoot {
.status(Response.Status.SERVICE_UNAVAILABLE) .status(Response.Status.SERVICE_UNAVAILABLE)
.build() .build()
} }
/**
* Get the spark UI with the given appID, and apply a function
* to it. If there is no such app, throw an appropriate exception
*/
def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = {
val appKey = attemptId.map(appId + "/" + _).getOrElse(appId)
getSparkUI(appKey) match {
case Some(ui) =>
f(ui)
case None => throw new NotFoundException("no such app: " + appId)
}
}
def securityManager: SecurityManager def securityManager: SecurityManager
} }
@ -263,13 +250,38 @@ private[v1] object UIRootFromServletContext {
} }
} }
private[v1] trait UIRootFromServletContext { private[v1] trait ApiRequestContext {
@Context @Context
var servletContext: ServletContext = _ protected var servletContext: ServletContext = _
@Context
protected var httpRequest: HttpServletRequest = _
def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext) def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext)
/**
* Get the spark UI with the given appID, and apply a function
* to it. If there is no such app, throw an appropriate exception
*/
def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = {
val appKey = attemptId.map(appId + "/" + _).getOrElse(appId)
uiRoot.getSparkUI(appKey) match {
case Some(ui) =>
val user = httpRequest.getRemoteUser()
if (!ui.securityManager.checkUIViewPermissions(user)) {
throw new ForbiddenException(raw"""user "$user" is not authorized""")
}
f(ui)
case None => throw new NotFoundException("no such app: " + appId)
}
}
} }
private[v1] class ForbiddenException(msg: String) extends WebApplicationException(
Response.status(Response.Status.FORBIDDEN).entity(msg).build())
private[v1] class NotFoundException(msg: String) extends WebApplicationException( private[v1] class NotFoundException(msg: String) extends WebApplicationException(
new NoSuchElementException(msg), new NoSuchElementException(msg),
Response Response

View file

@ -21,14 +21,14 @@ import javax.ws.rs.core.Response
import javax.ws.rs.ext.Provider import javax.ws.rs.ext.Provider
@Provider @Provider
private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext { private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext {
override def filter(req: ContainerRequestContext): Unit = { override def filter(req: ContainerRequestContext): Unit = {
val user = Option(req.getSecurityContext.getUserPrincipal).map { _.getName }.orNull val user = httpRequest.getRemoteUser()
if (!uiRoot.securityManager.checkUIViewPermissions(user)) { if (!uiRoot.securityManager.checkUIViewPermissions(user)) {
req.abortWith( req.abortWith(
Response Response
.status(Response.Status.FORBIDDEN) .status(Response.Status.FORBIDDEN)
.entity(raw"""user "$user"is not authorized""") .entity(raw"""user "$user" is not authorized""")
.build() .build()
) )
} }

View file

@ -90,9 +90,9 @@ private[spark] object JettyUtils extends Logging {
response.setHeader("X-Frame-Options", xFrameOptionsValue) response.setHeader("X-Frame-Options", xFrameOptionsValue)
response.getWriter.print(servletParams.extractFn(result)) response.getWriter.print(servletParams.extractFn(result))
} else { } else {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) response.setStatus(HttpServletResponse.SC_FORBIDDEN)
response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, response.sendError(HttpServletResponse.SC_FORBIDDEN,
"User is not authorized to access this page.") "User is not authorized to access this page.")
} }
} catch { } catch {

View file

@ -20,7 +20,8 @@ import java.io.{File, FileInputStream, FileWriter, InputStream, IOException}
import java.net.{HttpURLConnection, URL} import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.util.zip.ZipInputStream import java.util.zip.ZipInputStream
import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import javax.servlet._
import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse}
import scala.concurrent.duration._ import scala.concurrent.duration._
import scala.language.postfixOps import scala.language.postfixOps
@ -68,11 +69,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
private var server: HistoryServer = null private var server: HistoryServer = null
private var port: Int = -1 private var port: Int = -1
def init(): Unit = { def init(extraConf: (String, String)*): Unit = {
val conf = new SparkConf() val conf = new SparkConf()
.set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.logDirectory", logDir)
.set("spark.history.fs.update.interval", "0") .set("spark.history.fs.update.interval", "0")
.set("spark.testing", "true") .set("spark.testing", "true")
conf.setAll(extraConf)
provider = new FsHistoryProvider(conf) provider = new FsHistoryProvider(conf)
provider.checkForLogs() provider.checkForLogs()
val securityManager = HistoryServer.createSecurityManager(conf) val securityManager = HistoryServer.createSecurityManager(conf)
@ -566,6 +568,39 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
} }
test("ui and api authorization checks") {
val appId = "app-20161115172038-0000"
val owner = "jose"
val admin = "root"
val other = "alice"
stop()
init(
"spark.ui.filters" -> classOf[FakeAuthFilter].getName(),
"spark.history.ui.acls.enable" -> "true",
"spark.history.ui.admin.acls" -> admin)
val tests = Seq(
(owner, HttpServletResponse.SC_OK),
(admin, HttpServletResponse.SC_OK),
(other, HttpServletResponse.SC_FORBIDDEN),
// When the remote user is null, the code behaves as if auth were disabled.
(null, HttpServletResponse.SC_OK))
val port = server.boundPort
val testUrls = Seq(
s"http://localhost:$port/api/v1/applications/$appId/jobs",
s"http://localhost:$port/history/$appId/jobs/")
tests.foreach { case (user, expectedCode) =>
testUrls.foreach { url =>
val headers = if (user != null) Seq(FakeAuthFilter.FAKE_HTTP_USER -> user) else Nil
val sc = TestUtils.httpResponseCode(new URL(url), headers = headers)
assert(sc === expectedCode, s"Unexpected status code $sc for $url (user = $user)")
}
}
}
def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = { def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = {
HistoryServerSuite.getContentAndCode(new URL(s"http://localhost:$port/api/v1/$path")) HistoryServerSuite.getContentAndCode(new URL(s"http://localhost:$port/api/v1/$path"))
} }
@ -648,3 +683,26 @@ object HistoryServerSuite {
} }
} }
} }
/**
* A filter used for auth tests; sets the request's user to the value of the "HTTP_USER" header.
*/
class FakeAuthFilter extends Filter {
override def destroy(): Unit = { }
override def init(config: FilterConfig): Unit = { }
override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = {
val hreq = req.asInstanceOf[HttpServletRequest]
val wrapped = new HttpServletRequestWrapper(hreq) {
override def getRemoteUser(): String = hreq.getHeader(FakeAuthFilter.FAKE_HTTP_USER)
}
chain.doFilter(wrapped, res)
}
}
object FakeAuthFilter {
val FAKE_HTTP_USER = "HTTP_USER"
}

View file

@ -36,6 +36,10 @@ object MimaExcludes {
// Exclude rules for 2.2.x // Exclude rules for 2.2.x
lazy val v22excludes = v21excludes ++ Seq( lazy val v22excludes = v21excludes ++ Seq(
// [SPARK-19652][UI] Do auth checks for REST API access.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.withSparkUI"),
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.status.api.v1.UIRootFromServletContext"),
// [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"),

View file

@ -19,14 +19,14 @@ package org.apache.spark.status.api.v1.streaming
import javax.ws.rs.{Path, PathParam} import javax.ws.rs.{Path, PathParam}
import org.apache.spark.status.api.v1.UIRootFromServletContext import org.apache.spark.status.api.v1.ApiRequestContext
@Path("/v1") @Path("/v1")
private[v1] class ApiStreamingApp extends UIRootFromServletContext { private[v1] class ApiStreamingApp extends ApiRequestContext {
@Path("applications/{appId}/streaming") @Path("applications/{appId}/streaming")
def getStreamingRoot(@PathParam("appId") appId: String): ApiStreamingRootResource = { def getStreamingRoot(@PathParam("appId") appId: String): ApiStreamingRootResource = {
uiRoot.withSparkUI(appId, None) { ui => withSparkUI(appId, None) { ui =>
new ApiStreamingRootResource(ui) new ApiStreamingRootResource(ui)
} }
} }
@ -35,7 +35,7 @@ private[v1] class ApiStreamingApp extends UIRootFromServletContext {
def getStreamingRoot( def getStreamingRoot(
@PathParam("appId") appId: String, @PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): ApiStreamingRootResource = { @PathParam("attemptId") attemptId: String): ApiStreamingRootResource = {
uiRoot.withSparkUI(appId, Some(attemptId)) { ui => withSparkUI(appId, Some(attemptId)) { ui =>
new ApiStreamingRootResource(ui) new ApiStreamingRootResource(ui)
} }
} }