Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unable to handle trailing slash as part of a multi part dsl route #148

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package io.github.smiley4.ktorswaggerui.builder.route

import io.github.smiley4.ktorswaggerui.data.PluginConfigData
import io.github.smiley4.ktorswaggerui.dsl.routing.DocumentedRouteSelector
import io.github.smiley4.ktorswaggerui.dsl.routes.OpenApiRoute
import io.github.smiley4.ktorswaggerui.dsl.routing.DocumentedRouteSelector
import io.ktor.http.HttpMethod
import io.ktor.server.auth.AuthenticationRouteSelector
import io.ktor.server.routing.ConstantParameterRouteSelector
Expand Down Expand Up @@ -45,7 +45,8 @@ class RouteCollector(
private fun getDocumentation(route: Route, base: OpenApiRoute): OpenApiRoute {
var documentation = base
if (route.selector is DocumentedRouteSelector) {
documentation = routeDocumentationMerger.merge(documentation, (route.selector as DocumentedRouteSelector).documentation)
documentation =
routeDocumentationMerger.merge(documentation, (route.selector as DocumentedRouteSelector).documentation)
}
return if (route.parent != null) {
getDocumentation(route.parent!!, documentation)
Expand All @@ -61,13 +62,13 @@ class RouteCollector(


@Suppress("CyclomaticComplexMethod")
private fun getPath(route: Route, config: PluginConfigData): String {
internal fun getPath(route: Route, config: PluginConfigData): String {
val selector = route.selector
return if (isIgnoredSelector(selector, config)) {
route.parent?.let { getPath(it, config) } ?: ""
} else {
when (route.selector) {
is TrailingSlashRouteSelector -> "/"
is TrailingSlashRouteSelector -> route.parent?.let { getPath(it, config) } ?: "/"
is RootRouteSelector -> ""
is DocumentedRouteSelector -> route.parent?.let { getPath(it, config) } ?: ""
is HttpMethodRouteSelector -> route.parent?.let { getPath(it, config) } ?: ""
Expand All @@ -77,9 +78,14 @@ class RouteCollector(
is OptionalParameterRouteSelector -> route.parent?.let { getPath(it, config) } ?: ""
else -> (route.parent?.let { getPath(it, config) } ?: "") + "/" + route.selector.toString()
}
}
}.dropTrailingSlash()
}

private fun String.dropTrailingSlash(): String = if (length > 1 && endsWith("/")) {
dropLast(1).dropTrailingSlash()
} else {
this
}

private fun isIgnoredSelector(selector: RouteSelector, config: PluginConfigData): Boolean {
return when (selector) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.github.smiley4.ktorswaggerui.builder.route

import io.github.smiley4.ktorswaggerui.data.PluginConfigData
import io.github.smiley4.ktorswaggerui.dsl.routes.OpenApiRoute
import io.github.smiley4.ktorswaggerui.dsl.routing.DocumentedRouteSelector
import io.github.smiley4.ktorswaggerui.dsl.routing.get
import io.ktor.server.application.call
import io.ktor.server.response.respond
import io.ktor.server.routing.Route
import io.ktor.server.routing.route
import org.hamcrest.CoreMatchers.equalTo
import org.hamcrest.MatcherAssert.assertThat
import org.junit.jupiter.api.Test

class RouteCollectorTest {

private val tests: List<Pair<List<String>, String>> = listOf(
listOf("/") to "",
listOf("/api") to "/api",
listOf("/nested", "/routing") to "/nested/routing",
listOf("/trailing/", "/slashes") to "/trailing/slashes",
)

@Test
fun `should be able to get nested route`() {
// Given
for ((path, expected) in tests) {
val rootR = Route(null, DocumentedRouteSelector(OpenApiRoute()), true)
var parent = rootR
for (p in path) {
parent = parent.route(p) {}
}
parent = parent.get("/") {
call.respond("Hello")
}
val routeCollector = RouteCollector(RouteDocumentationMerger())

// When
val nestedRoute = routeCollector.getPath(parent, PluginConfigData.DEFAULT)

// Then
assertThat(nestedRoute, equalTo(expected))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,20 @@ import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.string.shouldNotBeEmpty
import io.ktor.client.HttpClient
import io.ktor.client.request.get
import io.ktor.client.request.header
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.path
import io.ktor.server.application.call
import io.ktor.server.auth.Authentication
import io.ktor.server.auth.UserIdPrincipal
import io.ktor.server.auth.authenticate
import io.ktor.server.auth.basic
import io.ktor.server.response.respondText
import io.ktor.server.routing.Route
import io.ktor.server.routing.Routing
import io.ktor.server.routing.get
import io.ktor.server.routing.route
import io.ktor.server.testing.testApplication
Expand Down Expand Up @@ -50,6 +58,10 @@ class RoutingTests {
it.contentType shouldBe ContentType.Application.Json
it.body.shouldNotBeEmpty()
}
get("/level1/level2/hello", auth = true).also {
it.status shouldBe HttpStatusCode.OK
it.body shouldBe "Hello Nested"
}
}

private fun swaggerUITestApplication(block: suspend TestContext.() -> Unit) {
Expand All @@ -58,6 +70,17 @@ class RoutingTests {
this.followRedirects = followRedirects
}
install(SwaggerUI)
install(Authentication) {
basic("test") {
validate { credentials ->
if (credentials.name == "test" && credentials.password == "test") {
UserIdPrincipal(credentials.name)
} else {
null
}
}
}
}
routing {
route("api.json") {
openApiSpec()
Expand All @@ -68,15 +91,40 @@ class RoutingTests {
get("hello") {
call.respondText("Hello Test")
}
nestedPath()
}
TestContext(client).apply { block() }
}
}

private fun Routing.nestedPath() {
route("/level1/"){
nestedPath()
}
}
private fun Route.nestedPath() {
route("/level2/"){
nestedHelloController()
}
}

private fun Route.nestedHelloController() {
authenticate("test") {
get("/hello") {
call.respondText("Hello Nested")
}
}
}

class TestContext(private val client: HttpClient) {

suspend fun get(path: String): GetResult {
return client.get(path)
suspend fun get(path: String, auth: Boolean = false): GetResult {
return (if (auth) {
client.get {
url.path(path)
header("Authorization", "Basic dGVzdDp0ZXN0")
}
} else client.get(path))
.let {
GetResult(
path = path,
Expand Down