diff --git a/build.gradle b/build.gradle index 4e61045..d8e7d7d 100644 --- a/build.gradle +++ b/build.gradle @@ -89,12 +89,14 @@ dependencies { implementation 'com.google.code.gson:gson:2.8.9' implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.5.0' + implementation 'com.squareup.okhttp3:okhttp:4.10.0' testImplementation 'junit:junit:4.12' testImplementation 'org.jetbrains.kotlinx:kotlinx-coroutines-test:1.5.0' testImplementation "io.mockk:mockk:1.12.0" testImplementation 'com.github.tomakehurst:wiremock:2.27.2' testImplementation "org.slf4j:slf4j-simple:1.8.0-beta4" + testImplementation("com.squareup.okhttp3:mockwebserver:4.10.0") } artifacts { diff --git a/src/main/java/com/statsig/androidsdk/ErrorBoundary.kt b/src/main/java/com/statsig/androidsdk/ErrorBoundary.kt index 145b7ac..22fe857 100644 --- a/src/main/java/com/statsig/androidsdk/ErrorBoundary.kt +++ b/src/main/java/com/statsig/androidsdk/ErrorBoundary.kt @@ -2,9 +2,11 @@ package com.statsig.androidsdk import com.google.gson.Gson import kotlinx.coroutines.CoroutineExceptionHandler -import java.io.DataOutputStream +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.RequestBody +import okhttp3.RequestBody.Companion.toRequestBody import java.lang.RuntimeException -import java.net.HttpURLConnection import java.net.URL import kotlin.math.floor @@ -107,15 +109,19 @@ internal class ErrorBoundary() { ) val postData = Gson().toJson(body) - val conn = url.openConnection() as HttpURLConnection - conn.requestMethod = "POST" - conn.doOutput = true - conn.setRequestProperty("Content-Type", "application/json") - conn.setRequestProperty("STATSIG-API-KEY", apiKey) - conn.useCaches = false + val clientBuilder = OkHttpClient.Builder() - DataOutputStream(conn.outputStream).use { it.writeBytes(postData) } - conn.responseCode // triggers request + clientBuilder.addInterceptor(RequestHeaderInterceptor(apiKey!!)) + clientBuilder.addInterceptor(ResponseInterceptor()) + + val httpClient = clientBuilder.build() + + val requestBody: RequestBody = postData.toRequestBody(JSON) + val request: Request = Request.Builder() + .url(url) + .post(requestBody) + .build() + httpClient.newCall(request).execute() } catch (e: Exception) { // noop } diff --git a/src/main/java/com/statsig/androidsdk/HttpUtils.kt b/src/main/java/com/statsig/androidsdk/HttpUtils.kt new file mode 100644 index 0000000..51cc08c --- /dev/null +++ b/src/main/java/com/statsig/androidsdk/HttpUtils.kt @@ -0,0 +1,73 @@ +package com.statsig.androidsdk + +import okhttp3.Interceptor +import okhttp3.MediaType +import okhttp3.MediaType.Companion.toMediaType +import okhttp3.Response +import okhttp3.ResponseBody.Companion.toResponseBody +import java.net.HttpURLConnection + + +private val RETRY_CODES: IntArray = intArrayOf( + HttpURLConnection.HTTP_CLIENT_TIMEOUT, + HttpURLConnection.HTTP_INTERNAL_ERROR, + HttpURLConnection.HTTP_BAD_GATEWAY, + HttpURLConnection.HTTP_UNAVAILABLE, + HttpURLConnection.HTTP_GATEWAY_TIMEOUT, + 522, + 524, + 599, +) +private const val CONTENT_TYPE_HEADER_KEY = "Content-Type" +private const val CONTENT_TYPE_HEADER_VALUE = "application/json; charset=UTF-8" +private const val STATSIG_API_HEADER_KEY = "STATSIG-API-KEY" +private const val STATSIG_CLIENT_TIME_HEADER_KEY = "STATSIG-CLIENT-TIME" +private const val STATSIG_SDK_TYPE_KEY = "STATSIG-SDK-TYPE" +private const val STATSIG_SDK_VERSION_KEY = "STATSIG-SDK-VERSION" +private const val ACCEPT_HEADER_KEY = "Accept" +private const val ACCEPT_HEADER_VALUE = "application/json" +internal val JSON: MediaType = "application/json; charset=utf-8".toMediaType(); + +class RequestHeaderInterceptor(private val sdkKey: String) : Interceptor { + @Throws(Exception::class) + override fun intercept(chain: Interceptor.Chain): Response { + val original = chain.request() + val request = original.newBuilder() + .addHeader(CONTENT_TYPE_HEADER_KEY, CONTENT_TYPE_HEADER_VALUE) + .addHeader(STATSIG_API_HEADER_KEY, sdkKey) + .addHeader(STATSIG_SDK_TYPE_KEY, "android-client") + .addHeader(STATSIG_SDK_VERSION_KEY, BuildConfig.VERSION_NAME) + .addHeader(STATSIG_CLIENT_TIME_HEADER_KEY, System.currentTimeMillis().toString()) + .addHeader(ACCEPT_HEADER_KEY, ACCEPT_HEADER_VALUE) + .method(original.method, original.body) + .build() + return chain.proceed(request) + } +} + +class ResponseInterceptor : Interceptor { + @Throws(Exception::class) + override fun intercept(chain: Interceptor.Chain): Response { + val request = chain.request() + var response = chain.proceed(request) + + var attempt = 1 + var retries = 0 + if (LOGGING_ENDPOINT in request.url.pathSegments) { + retries = 3 + } + while (!response.isSuccessful && attempt <= retries && response.code in RETRY_CODES) { + attempt++ + + response.close() + response = chain.proceed(request) + } + + val bodyString = response.body?.string() + + return response.newBuilder() + .body(bodyString?.toResponseBody(response.body?.contentType())) + .addHeader("attempt", attempt.toString()) + .build() + } +} \ No newline at end of file diff --git a/src/main/java/com/statsig/androidsdk/InitializeResponse.kt b/src/main/java/com/statsig/androidsdk/InitializeResponse.kt index 415f4ac..d30ad15 100644 --- a/src/main/java/com/statsig/androidsdk/InitializeResponse.kt +++ b/src/main/java/com/statsig/androidsdk/InitializeResponse.kt @@ -1,7 +1,9 @@ package com.statsig.androidsdk +import com.google.gson.JsonObject +import com.google.gson.JsonSerializer import com.google.gson.annotations.SerializedName -import java.lang.Exception + enum class InitializeFailReason { CoroutineTimeout, @@ -33,7 +35,7 @@ internal data class APIFeatureGate( @SerializedName("secondary_exposures") val secondaryExposures: Array> = arrayOf(), ) -internal data class APIDynamicConfig( +internal data class APIDynamicConfig ( @SerializedName("name") val name: String, @SerializedName("value") val value: Map, @SerializedName("rule_id") val ruleID: String?, diff --git a/src/main/java/com/statsig/androidsdk/Layer.kt b/src/main/java/com/statsig/androidsdk/Layer.kt index 25b090f..ce58847 100644 --- a/src/main/java/com/statsig/androidsdk/Layer.kt +++ b/src/main/java/com/statsig/androidsdk/Layer.kt @@ -6,7 +6,7 @@ package com.statsig.androidsdk class Layer internal constructor( private val client: StatsigClient?, private val name: String, - private val jsonValue: Map, + public val jsonValue: Map, private val rule: String, private val details: EvaluationDetails, private val secondaryExposures: Array> = arrayOf(), diff --git a/src/main/java/com/statsig/androidsdk/StatsigClient.kt b/src/main/java/com/statsig/androidsdk/StatsigClient.kt index 30b069f..0920401 100644 --- a/src/main/java/com/statsig/androidsdk/StatsigClient.kt +++ b/src/main/java/com/statsig/androidsdk/StatsigClient.kt @@ -29,7 +29,6 @@ internal class StatsigClient() { private lateinit var user: StatsigUser private lateinit var application: Application private lateinit var sdkKey: String - private lateinit var options: StatsigOptions private lateinit var lifecycleListener: StatsigActivityLifecycleListener private lateinit var logger: StatsigLogger private lateinit var statsigMetadata: StatsigMetadata @@ -44,7 +43,9 @@ internal class StatsigClient() { private val isBootstrapped = AtomicBoolean(false) @VisibleForTesting - internal var statsigNetwork: StatsigNetwork = StatsigNetwork() + internal lateinit var statsigNetwork: StatsigNetwork + @VisibleForTesting + internal lateinit var options: StatsigOptions fun initializeAsync( application: Application, @@ -98,7 +99,6 @@ internal class StatsigClient() { } val initResponse = statsigNetwork.initialize( this@StatsigClient.options.api, - this@StatsigClient.sdkKey, user, this@StatsigClient.store.getLastUpdateTime(this@StatsigClient.user), this@StatsigClient.statsigMetadata, @@ -118,7 +118,7 @@ internal class StatsigClient() { this@StatsigClient.pollForUpdates() - this@StatsigClient.statsigNetwork.apiRetryFailedLogs(this@StatsigClient.options.api, this@StatsigClient.sdkKey) + this@StatsigClient.statsigNetwork.apiRetryFailedLogs(this@StatsigClient.options.api) this@StatsigClient.diagnostics.markEnd(KeyType.OVERALL, success) logger.logDiagnostics() InitializationDetails(duration, success, if (initResponse is InitializeResponse.FailedInitializeResponse) initResponse else null) @@ -157,6 +157,11 @@ internal class StatsigClient() { exceptionHandler = Statsig.errorBoundary.getExceptionHandler() statsigScope = CoroutineScope(statsigJob + dispatcherProvider.main + exceptionHandler) + // Prevent overwriting mocked network in tests + if (!this::statsigNetwork.isInitialized) { + statsigNetwork = StatsigNetwork(sdkKey) + } + lifecycleListener = StatsigActivityLifecycleListener() application.registerActivityLifecycleCallbacks(lifecycleListener) logger = StatsigLogger( @@ -408,7 +413,6 @@ internal class StatsigClient() { val initResponse = statsigNetwork.initialize( options.api, - sdkKey, this@StatsigClient.user, sinceTime, statsigMetadata, @@ -569,7 +573,7 @@ internal class StatsigClient() { } pollingJob?.cancel() val sinceTime = store.getLastUpdateTime(user) - pollingJob = statsigNetwork.pollForChanges(options.api, sdkKey, user, sinceTime, statsigMetadata).onEach { + pollingJob = statsigNetwork.pollForChanges(options.api, user, sinceTime, statsigMetadata).onEach { if (it?.hasUpdates == true) { store.save(it, user) } diff --git a/src/main/java/com/statsig/androidsdk/StatsigLogger.kt b/src/main/java/com/statsig/androidsdk/StatsigLogger.kt index 4734216..e19ee8b 100644 --- a/src/main/java/com/statsig/androidsdk/StatsigLogger.kt +++ b/src/main/java/com/statsig/androidsdk/StatsigLogger.kt @@ -70,7 +70,7 @@ internal class StatsigLogger( } val flushEvents = ArrayList(events) events = ConcurrentLinkedQueue() - statsigNetwork.apiPostLogs(api, sdkKey, gson.toJson(LogEventData(flushEvents, statsigMetadata))) + statsigNetwork.apiPostLogs(api, gson.toJson(LogEventData(flushEvents, statsigMetadata))) } } diff --git a/src/main/java/com/statsig/androidsdk/StatsigNetwork.kt b/src/main/java/com/statsig/androidsdk/StatsigNetwork.kt index 42a85ca..0ae66c9 100644 --- a/src/main/java/com/statsig/androidsdk/StatsigNetwork.kt +++ b/src/main/java/com/statsig/androidsdk/StatsigNetwork.kt @@ -1,30 +1,21 @@ package com.statsig.androidsdk import android.content.SharedPreferences +import com.google.gson.annotations.SerializedName import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.delay import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.isActive import kotlinx.coroutines.withContext import kotlinx.coroutines.withTimeout +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.RequestBody +import okhttp3.RequestBody.Companion.toRequestBody import java.net.ConnectException -import java.net.HttpURLConnection import java.net.SocketTimeoutException -import java.net.URL import java.util.concurrent.TimeUnit -import kotlin.math.pow -private val RETRY_CODES: IntArray = intArrayOf( - HttpURLConnection.HTTP_CLIENT_TIMEOUT, - HttpURLConnection.HTTP_INTERNAL_ERROR, - HttpURLConnection.HTTP_BAD_GATEWAY, - HttpURLConnection.HTTP_UNAVAILABLE, - HttpURLConnection.HTTP_GATEWAY_TIMEOUT, - 522, - 524, - 599, -) // Constants private val MAX_LOG_PERIOD = TimeUnit.DAYS.toMillis(3) @@ -41,94 +32,67 @@ private const val HASH = "hash" private const val OFFLINE_LOGS_KEY: String = "StatsigNetwork.OFFLINE_LOGS" // Endpoints -private const val LOGGING_ENDPOINT: String = "log_event" -private const val INITIALIZE_ENDPOINT: String = "initialize" - -// HTTP -private const val POST = "POST" -private const val CONTENT_TYPE_HEADER_KEY = "Content-Type" -private const val CONTENT_TYPE_HEADER_VALUE = "application/json; charset=UTF-8" -private const val STATSIG_API_HEADER_KEY = "STATSIG-API-KEY" -private const val STATSIG_CLIENT_TIME_HEADER_KEY = "STATSIG-CLIENT-TIME" -private const val STATSIG_SDK_TYPE_KEY = "STATSIG-SDK-TYPE" -private const val STATSIG_SDK_VERSION_KEY = "STATSIG-SDK-VERSION" -private const val ACCEPT_HEADER_KEY = "Accept" -private const val ACCEPT_HEADER_VALUE = "application/json" - -internal interface StatsigNetwork { - - suspend fun initialize( - api: String, - sdkKey: String, - user: StatsigUser?, - sinceTime: Long?, - metadata: StatsigMetadata, - initTimeoutMs: Long, - sharedPrefs: SharedPreferences, - diagnostics: Diagnostics? = null, - hashUsed: HashAlgorithm, - ): InitializeResponse? - - fun pollForChanges( - api: String, - sdkKey: String, - user: StatsigUser?, - sinceTime: Long?, - metadata: StatsigMetadata, - ): Flow - - suspend fun apiPostLogs(api: String, sdkKey: String, bodyString: String) - - suspend fun apiRetryFailedLogs(api: String, sdkKey: String) - - suspend fun addFailedLogRequest(requestBody: String) -} - -internal fun StatsigNetwork(): StatsigNetwork = StatsigNetworkImpl() - -private class StatsigNetworkImpl : StatsigNetwork { +internal const val LOGGING_ENDPOINT: String = "log_event" +internal const val INITIALIZE_ENDPOINT: String = "initialize" + +internal data class InitializeRequestBody( + @SerializedName("user") val user: StatsigUser?, + @SerializedName("statsigMetadata") val statsigMetadata: StatsigMetadata, + @SerializedName("sinceTime") val sinceTime: Long?, + @SerializedName("hash") val hash: HashAlgorithm, +) +internal class StatsigNetwork( + sdkKey: String, +) { private val gson = StatsigUtil.getGson() private val dispatcherProvider = CoroutineDispatcherProvider() private var sharedPrefs: SharedPreferences? = null + private val httpClient: OkHttpClient - override suspend fun initialize( + init { + val clientBuilder = OkHttpClient.Builder() + + clientBuilder.addInterceptor(RequestHeaderInterceptor(sdkKey)) + clientBuilder.addInterceptor(ResponseInterceptor()) + + httpClient = clientBuilder.build() + } + + suspend fun initialize( api: String, - sdkKey: String, user: StatsigUser?, sinceTime: Long?, metadata: StatsigMetadata, initTimeoutMs: Long, sharedPrefs: SharedPreferences, - diagnostics: Diagnostics?, + diagnostics: Diagnostics? = null, hashUsed: HashAlgorithm, ): InitializeResponse { this.sharedPrefs = sharedPrefs if (initTimeoutMs == 0L) { - return initializeImpl(api, sdkKey, user, sinceTime, metadata, diagnostics, hashUsed = hashUsed) + return initializeImpl(api, user, sinceTime, metadata, diagnostics, hashUsed = hashUsed) } return withTimeout(initTimeoutMs) { - initializeImpl(api, sdkKey, user, sinceTime, metadata, diagnostics, initTimeoutMs.toInt(), hashUsed = hashUsed) + initializeImpl(api, user, sinceTime, metadata, diagnostics, initTimeoutMs, hashUsed = hashUsed) } } private suspend fun initializeImpl( api: String, - sdkKey: String, user: StatsigUser?, sinceTime: Long?, metadata: StatsigMetadata, diagnostics: Diagnostics?, - timeoutMs: Int? = null, + timeoutMs: Long? = null, hashUsed: HashAlgorithm, ): InitializeResponse { - val retries = 0 return try { val userCopy = user?.getCopyForEvaluation() val metadataCopy = metadata.copy() - val body = mapOf(USER to userCopy, STATSIG_METADATA to metadataCopy, SINCE_TIME to sinceTime, HASH to hashUsed) + val body = InitializeRequestBody(userCopy, metadataCopy, sinceTime, hashUsed) var statusCode: Int? = null - val response = postRequest(api, INITIALIZE_ENDPOINT, sdkKey, gson.toJson(body), retries, ContextType.INITIALIZE, diagnostics, timeoutMs) { status: Int? -> statusCode = status } + val response = postRequest(api, INITIALIZE_ENDPOINT, gson.toJson(body), ContextType.INITIALIZE, diagnostics, timeoutMs) { status: Int? -> statusCode = status } response ?: InitializeResponse.FailedInitializeResponse(InitializeFailReason.NetworkError, null, statusCode) } catch (e: Exception) { Statsig.errorBoundary.logException(e) @@ -147,9 +111,8 @@ private class StatsigNetworkImpl : StatsigNetwork { } } - override fun pollForChanges( + fun pollForChanges( api: String, - sdkKey: String, user: StatsigUser?, sinceTime: Long?, metadata: StatsigMetadata, @@ -168,28 +131,28 @@ private class StatsigNetworkImpl : StatsigNetwork { HASH to HashAlgorithm.DJB2.value, ) try { - emit(postRequest(api, INITIALIZE_ENDPOINT, sdkKey, gson.toJson(body), 0, ContextType.CONFIG_SYNC)) + emit(postRequest(api, INITIALIZE_ENDPOINT, gson.toJson(body), ContextType.CONFIG_SYNC)) } catch (_: Exception) {} } } } - override suspend fun apiPostLogs(api: String, sdkKey: String, bodyString: String) { + suspend fun apiPostLogs(api: String, bodyString: String) { try { - postRequest(api, LOGGING_ENDPOINT, sdkKey, bodyString, 3, ContextType.EVENT_LOGGING) + postRequest(api, LOGGING_ENDPOINT, bodyString, ContextType.EVENT_LOGGING) } catch (_: Exception) {} } - override suspend fun apiRetryFailedLogs(api: String, sdkKey: String) { + suspend fun apiRetryFailedLogs(api: String) { val savedLogs = getSavedLogs() if (savedLogs.isEmpty()) { return } StatsigUtil.removeFromSharedPrefs(sharedPrefs, OFFLINE_LOGS_KEY) - savedLogs.map { apiPostLogs(api, sdkKey, it.requestBody) } + savedLogs.map { apiPostLogs(api, it.requestBody) } } - override suspend fun addFailedLogRequest(requestBody: String) { + suspend fun addFailedLogRequest(requestBody: String) { withContext(dispatcherProvider.io) { val savedLogs = getSavedLogs() + StatsigOfflineRequest(System.currentTimeMillis(), requestBody) try { @@ -226,82 +189,57 @@ private class StatsigNetworkImpl : StatsigNetwork { private suspend inline fun postRequest( api: String, endpoint: String, - sdkKey: String, bodyString: String, - retries: Int, contextType: ContextType, diagnostics: Diagnostics? = null, - timeout: Int? = null, + timeout: Long? = null, crossinline callback: ((statusCode: Int?) -> Unit) = { _: Int? -> }, ): T? { return withContext(dispatcherProvider.io) { // Perform network calls in IO thread - var retryAttempt = 1 - while (isActive) { - val url = if (api.endsWith("/")) "$api$endpoint" else "$api/$endpoint" - val connection: HttpURLConnection = URL(url).openConnection() as HttpURLConnection + val url = if (api.endsWith("/")) "$api$endpoint" else "$api/$endpoint" + diagnostics?.markStart(KeyType.INITIALIZE, StepType.NETWORK_REQUEST, Marker(attempt = 1), contextType) + try { + var client = httpClient + val requestBody: RequestBody = bodyString.toRequestBody(JSON) + val request: Request = Request.Builder() + .url(url) + .post(requestBody) + .build() - connection.requestMethod = POST if (timeout != null) { - connection.connectTimeout = timeout - connection.readTimeout = timeout + client = httpClient.newBuilder() + .callTimeout(timeout, TimeUnit.MILLISECONDS) + .build() } - connection.setRequestProperty(CONTENT_TYPE_HEADER_KEY, CONTENT_TYPE_HEADER_VALUE) - connection.setRequestProperty(STATSIG_API_HEADER_KEY, sdkKey) - connection.setRequestProperty(STATSIG_SDK_TYPE_KEY, "android-client") - connection.setRequestProperty(STATSIG_SDK_VERSION_KEY, BuildConfig.VERSION_NAME) - connection.setRequestProperty(STATSIG_CLIENT_TIME_HEADER_KEY, System.currentTimeMillis().toString()) - connection.setRequestProperty(ACCEPT_HEADER_KEY, ACCEPT_HEADER_VALUE) - diagnostics?.markStart(KeyType.INITIALIZE, StepType.NETWORK_REQUEST, Marker(attempt = retryAttempt), contextType) - try { - connection.outputStream.bufferedWriter(Charsets.UTF_8) - .use { it.write(bodyString) } - val code = connection.responseCode - val inputStream = if (code < HttpURLConnection.HTTP_BAD_REQUEST) { - connection.inputStream - } else { - connection.errorStream - } - endDiagnostics(diagnostics, contextType, code, connection.headerFields["x-statsig-region"]?.get(0), retryAttempt) - when (code) { - in 200..299 -> { - if (code == 204 && endpoint == INITIALIZE_ENDPOINT) { - return@withContext gson.fromJson("{has_updates: false}", T::class.java) - } - return@withContext inputStream.bufferedReader(Charsets.UTF_8) - .use { gson.fromJson(it, T::class.java) } - } - in RETRY_CODES -> { - if (retries > 0 && retryAttempt++ < retries) { - // Don't return, just allow the loop to happen - delay(100.0.pow(retryAttempt + 1).toLong()) - } else if (endpoint == LOGGING_ENDPOINT) { - addFailedLogRequest(bodyString) - callback(code) - return@withContext null - } else { - callback(code) - return@withContext null - } - } - else -> { - callback(code) - return@withContext null + var response = client.newCall(request).execute() + var code = response.code + endDiagnostics(diagnostics, contextType, code, + response.headers["x-statsig-region"], response.headers["attempt"]?.toInt()) + when (code) { + in 200..299 -> { + if (code == 204 && endpoint == INITIALIZE_ENDPOINT) { + return@withContext gson.fromJson("{has_updates: false}", T::class.java) } + return@withContext gson.fromJson(response.body?.string(), T::class.java) } - } catch (e: Exception) { - if (endpoint == LOGGING_ENDPOINT) { - addFailedLogRequest(bodyString) + else -> { + if (endpoint == LOGGING_ENDPOINT) { + addFailedLogRequest(bodyString) + } + callback(code) + return@withContext null } - throw e - } finally { - connection.disconnect() } + } catch (e: Exception) { + if (endpoint == LOGGING_ENDPOINT) { + addFailedLogRequest(bodyString) + } + throw e } - return@withContext null } } - private fun endDiagnostics(diagnostics: Diagnostics?, diagnosticsContext: ContextType, statusCode: Int, sdkRegion: String?, attempt: Int) { + private fun endDiagnostics(diagnostics: Diagnostics?, diagnosticsContext: ContextType, statusCode: Int, sdkRegion: String?, attempt: Int?) { if (diagnostics == null) { return } diff --git a/src/test/java/com/statsig/androidsdk/AsyncInitVsUpdateTest.kt b/src/test/java/com/statsig/androidsdk/AsyncInitVsUpdateTest.kt index 88c082a..7dca9bb 100644 --- a/src/test/java/com/statsig/androidsdk/AsyncInitVsUpdateTest.kt +++ b/src/test/java/com/statsig/androidsdk/AsyncInitVsUpdateTest.kt @@ -6,6 +6,7 @@ import io.mockk.coEvery import io.mockk.mockk import io.mockk.spyk import kotlinx.coroutines.* +import okhttp3.mockwebserver.MockWebServer import org.junit.Assert.assertEquals import org.junit.Before import org.junit.Test @@ -45,11 +46,12 @@ class AsyncInitVsUpdateTest { testSharedPrefs = TestUtil.stubAppFunctions(app) TestUtil.mockStatsigUtil() - val network = TestUtil.mockNetwork() + // Cannot use mockWebServer to mimic network delay. Must mock StatsigNetwork directly. + var network = mockk() coEvery { - network.initialize(any(), any(), any(), any(), any(), any(), any(), any(), any()) + network.initialize(any(), any(), any(), any(), any(), any(), any(), any()) } coAnswers { - val user = thirdArg() + val user = secondArg() getResponseForUser(user) } Statsig.client = spyk() @@ -89,8 +91,14 @@ class AsyncInitVsUpdateTest { } } - Statsig.initializeAsync(app, "client-key", userA, callback) - Statsig.updateUserAsync(userB, callback) + var elapsed = kotlin.system.measureTimeMillis { + Statsig.initializeAsync(app, "client-key", userA, callback) + } + print("initializeAsync $elapsed") + elapsed = kotlin.system.measureTimeMillis { + Statsig.updateUserAsync(userB, callback) + } + print("updateUserAsync $elapsed") // Since updateUserAsync has been called, we void values for user_a var config = Statsig.getConfig("a_config") diff --git a/src/test/java/com/statsig/androidsdk/ErrorBoundaryTest.kt b/src/test/java/com/statsig/androidsdk/ErrorBoundaryTest.kt index cb6a049..a27eeae 100644 --- a/src/test/java/com/statsig/androidsdk/ErrorBoundaryTest.kt +++ b/src/test/java/com/statsig/androidsdk/ErrorBoundaryTest.kt @@ -3,6 +3,7 @@ package com.statsig.androidsdk import android.app.Application import com.github.tomakehurst.wiremock.client.WireMock.* import com.github.tomakehurst.wiremock.junit.WireMockRule +import io.mockk.coEvery import io.mockk.mockk import io.mockk.unmockkAll import kotlinx.coroutines.runBlocking @@ -28,9 +29,8 @@ class ErrorBoundaryTest { app = mockk() TestUtil.mockDispatchers() TestUtil.stubAppFunctions(app) - val network = TestUtil.mockBrokenNetwork() Statsig.client = StatsigClient() - Statsig.client.statsigNetwork = network + TestUtil.mockBrokenServer() Statsig.errorBoundary = boundary } @@ -92,7 +92,18 @@ class ErrorBoundaryTest { fun testInitializeIsCaptured() { try { runBlocking { - Statsig.client.initialize(app, "client-key", null) + Statsig.client.statsigNetwork = mockk() + coEvery { + Statsig.client.statsigNetwork.initialize(any(), any(), any(), any(), any(), any(), any(), any()) + } answers { + throw IOException("Example exception in StatsigNetwork initialize") + } + Statsig.client.initialize( + app, + "client-key", + null, + options = StatsigOptions(disableDiagnosticsLogging = true) + ) Statsig.shutdown() } } catch (e: Throwable) { @@ -111,7 +122,18 @@ class ErrorBoundaryTest { fun testInitializeAsyncIsCaptured() { try { runBlocking { - Statsig.client.initializeAsync(app, "client-key", null) + Statsig.client.statsigNetwork = mockk() + coEvery { + Statsig.client.statsigNetwork.initialize(any(), any(), any(), any(), any(), any(), any(), any()) + } answers { + throw IOException("Example exception in StatsigNetwork initialize") + } + Statsig.client.initializeAsync( + app, + "client-key", + null, + options = StatsigOptions(disableDiagnosticsLogging = true) + ) Statsig.shutdown() } } catch (e: Throwable) { diff --git a/src/test/java/com/statsig/androidsdk/LayerConfigTest.kt b/src/test/java/com/statsig/androidsdk/LayerConfigTest.kt index 92af137..9318d83 100644 --- a/src/test/java/com/statsig/androidsdk/LayerConfigTest.kt +++ b/src/test/java/com/statsig/androidsdk/LayerConfigTest.kt @@ -320,11 +320,7 @@ class LayerConfigTest { } private fun initClient() = runBlocking { - val statsigNetwork = TestUtil.mockNetwork() - - client = StatsigClient() - client.statsigNetwork = statsigNetwork - + client = TestUtil.mockClientWithServer(client, TestUtil.mockServer()) client.initialize(app, "test-key") } } diff --git a/src/test/java/com/statsig/androidsdk/LayerExposureTest.kt b/src/test/java/com/statsig/androidsdk/LayerExposureTest.kt index a7430a3..a87a288 100644 --- a/src/test/java/com/statsig/androidsdk/LayerExposureTest.kt +++ b/src/test/java/com/statsig/androidsdk/LayerExposureTest.kt @@ -310,16 +310,15 @@ class LayerExposureTest { } private fun start(layers: Map, user: StatsigUser = StatsigUser(userID = "jkw")) { - val network = TestUtil.mockNetwork( + val server = TestUtil.mockServer( layerConfigs = layers, + onLog = { result -> + // filter out diagnostics data + val events = result.events.filter { event -> event.eventName != "statsig::diagnostics" } + logs = if (events.isEmpty())null else LogEventData(events as ArrayList, result.statsigMetadata) + } ) - - TestUtil.captureLogs(network) { result -> - // filter out diagnostics data - val events = result.events.filter { event -> event.eventName != "statsig::diagnostics" } - logs = if (events.isEmpty())null else LogEventData(events as ArrayList, result.statsigMetadata) - } initTime = System.currentTimeMillis() - TestUtil.startStatsigAndWait(app, user = user, network = network) + TestUtil.startStatsigAndWait(app, user = user, server = server) } } diff --git a/src/test/java/com/statsig/androidsdk/StatsigCacheTest.kt b/src/test/java/com/statsig/androidsdk/StatsigCacheTest.kt index 7d41e79..3665355 100644 --- a/src/test/java/com/statsig/androidsdk/StatsigCacheTest.kt +++ b/src/test/java/com/statsig/androidsdk/StatsigCacheTest.kt @@ -58,7 +58,7 @@ class StatsigCacheTest { assertTrue(client.checkGate("always_on")) runBlocking { - client.statsigNetwork.apiRetryFailedLogs("https://statsigapi.net/v1", "client-test") + client.statsigNetwork.apiRetryFailedLogs("https://statsigapi.net/v1") client.statsigNetwork.addFailedLogRequest("{}") } val config = client.getConfig("test_config") @@ -81,11 +81,11 @@ class StatsigCacheTest { testSharedPrefs.edit().putString("Statsig.CACHE_BY_USER", gson.toJson(cacheById)).apply() TestUtil.startStatsigAndDontWait(app, user, StatsigOptions(loadCacheAsync = true)) + TestUtil.mockServer() client = Statsig.client - client.statsigNetwork = TestUtil.mockNetwork() assertFalse(client.checkGate("always_on")) runBlocking { - client.statsigNetwork.apiRetryFailedLogs("https://statsigapi.net/v1", "client-test") + client.statsigNetwork.apiRetryFailedLogs("https://statsigapi.net/v1") client.statsigNetwork.addFailedLogRequest("{}") } val config = client.getConfig("test_config") diff --git a/src/test/java/com/statsig/androidsdk/StatsigFromJavaTest.java b/src/test/java/com/statsig/androidsdk/StatsigFromJavaTest.java index 6110675..4fa8447 100644 --- a/src/test/java/com/statsig/androidsdk/StatsigFromJavaTest.java +++ b/src/test/java/com/statsig/androidsdk/StatsigFromJavaTest.java @@ -16,6 +16,7 @@ import kotlin.Unit; import kotlin.jvm.functions.Function1; +import okhttp3.mockwebserver.MockWebServer; public class StatsigFromJavaTest { private Application app; @@ -142,27 +143,26 @@ public void testLogging() { } private void start() { - StatsigNetwork network = TestUtil.Companion.mockNetwork( + MockWebServer server = TestUtil.Companion.mockServer( gates, configs, layers, null, true, + new Function1() { + @Override + public Unit invoke(LogEventData logEventData) { + logs = logEventData; + return null; + } + }, null); - TestUtil.Companion.captureLogs(network, new Function1() { - @Override - public Unit invoke(LogEventData logEventData) { - logs = logEventData; - return null; - } - }); - TestUtil.Companion.startStatsigAndWait( app, new StatsigUser("dloomb"), new StatsigOptions(), - network); + server); } private APIFeatureGate makeGate(String name, Boolean value) { diff --git a/src/test/java/com/statsig/androidsdk/StatsigOverridesTest.kt b/src/test/java/com/statsig/androidsdk/StatsigOverridesTest.kt index 9b9d610..a0100a6 100644 --- a/src/test/java/com/statsig/androidsdk/StatsigOverridesTest.kt +++ b/src/test/java/com/statsig/androidsdk/StatsigOverridesTest.kt @@ -31,10 +31,8 @@ class StatsigOverridesTest { app = mockk() TestUtil.stubAppFunctions(app) - val statsigNetwork = TestUtil.mockNetwork() - Statsig.client = StatsigClient() - Statsig.client.statsigNetwork = statsigNetwork + TestUtil.mockServer() Statsig.initialize(app, "test-key") } diff --git a/src/test/java/com/statsig/androidsdk/StatsigStickyExperimentTest.kt b/src/test/java/com/statsig/androidsdk/StatsigStickyExperimentTest.kt index d3563b1..0868bf2 100644 --- a/src/test/java/com/statsig/androidsdk/StatsigStickyExperimentTest.kt +++ b/src/test/java/com/statsig/androidsdk/StatsigStickyExperimentTest.kt @@ -242,7 +242,7 @@ class StatsigStickyExperimentTest { ) = runBlocking { TestUtil.startStatsigAndWait( app, - network = TestUtil.mockNetwork( + server = TestUtil.mockServer( featureGates = mapOf(), dynamicConfigs = configs, layerConfigs = layers, diff --git a/src/test/java/com/statsig/androidsdk/StatsigTest.kt b/src/test/java/com/statsig/androidsdk/StatsigTest.kt index 023dd02..b7f4121 100644 --- a/src/test/java/com/statsig/androidsdk/StatsigTest.kt +++ b/src/test/java/com/statsig/androidsdk/StatsigTest.kt @@ -6,6 +6,7 @@ import io.mockk.coEvery import io.mockk.mockk import io.mockk.unmockkAll import kotlinx.coroutines.runBlocking +import okhttp3.mockwebserver.MockWebServer import org.junit.After import org.junit.Assert.* import org.junit.Before @@ -17,9 +18,8 @@ class StatsigTest { private lateinit var app: Application private var flushedLogs: String = "" - private var initUser: StatsigUser? = null private var client: StatsigClient = StatsigClient() - private lateinit var network: StatsigNetwork + private lateinit var server: MockWebServer private lateinit var testSharedPrefs: TestSharedPreferences private val gson = Gson() @@ -32,15 +32,9 @@ class StatsigTest { TestUtil.mockStatsigUtil() - network = TestUtil.mockNetwork() { user -> - initUser = user - } - - coEvery { - network.apiPostLogs(any(), any(), any()) - } answers { - flushedLogs = thirdArg() - } + server = TestUtil.mockServer(onLog = { + flushedLogs = gson.toJson(it) + }) } @After @@ -70,13 +64,9 @@ class StatsigTest { val now = System.currentTimeMillis() user.customIDs = mapOf("random_id" to "abcde") - TestUtil.startStatsigAndWait(app, user, StatsigOptions(overrideStableID = "custom_stable_id"), network = network) + TestUtil.startStatsigAndWait(app, user, StatsigOptions(overrideStableID = "custom_stable_id"), server = server) client = Statsig.client - assertEquals( - Gson().toJson(initUser?.customIDs), - Gson().toJson(mapOf("random_id" to "abcde")), - ) assertTrue(client.checkGate("always_on")) assertTrue(client.checkGateWithExposureLoggingDisabled("always_on_v2")) assertFalse(client.checkGateWithExposureLoggingDisabled("a_different_gate")) @@ -247,30 +237,23 @@ class StatsigTest { } } - var user: StatsigUser? = null - Statsig.client.statsigNetwork = TestUtil.mockNetwork() { - user = it - } + TestUtil.useServer(Statsig.client, TestUtil.mockServer()) Statsig.initializeAsync(app, "client-sdkkey", StatsigUser("jkw"), callback) countdown.await(1L, TimeUnit.SECONDS) countdown = CountDownLatch(1) assertTrue(Statsig.client.isInitialized()) - assertEquals("jkw", user?.userID) Statsig.shutdown() assertFalse(Statsig.client.isInitialized()) - Statsig.client.statsigNetwork = TestUtil.mockNetwork() { - user = it - } + TestUtil.useServer(Statsig.client, TestUtil.mockServer()) Statsig.initializeAsync(app, "client-sdkkey", StatsigUser("dloomb"), callback) countdown.await(1L, TimeUnit.SECONDS) - assertEquals("dloomb", user?.userID) return@runBlocking } } diff --git a/src/test/java/com/statsig/androidsdk/StoreTest.kt b/src/test/java/com/statsig/androidsdk/StoreTest.kt index d5bc469..271ec31 100644 --- a/src/test/java/com/statsig/androidsdk/StoreTest.kt +++ b/src/test/java/com/statsig/androidsdk/StoreTest.kt @@ -107,7 +107,7 @@ class StoreTest { @Test fun testParsingNumberPrecision() = runBlocking { - var network: StatsigNetwork = TestUtil.mockNetwork( + var server = TestUtil.mockServer( dynamicConfigs = mapOf( "long!" to APIDynamicConfig( @@ -121,11 +121,11 @@ class StoreTest { time = 1621637839, hasUpdates = true, ) - TestUtil.startStatsigAndWait(app, userJkw, StatsigOptions(), network) + TestUtil.startStatsigAndWait(app, userJkw, StatsigOptions(), server) assertEquals(Long.MAX_VALUE, Statsig.getConfig("long").getLong("key", 0L)) assertEquals(Double.MIN_VALUE, Statsig.getConfig("double").getDouble("key", 0.0), 0.0) - network = TestUtil.mockBrokenNetwork() - TestUtil.startStatsigAndWait(app, userJkw, StatsigOptions(loadCacheAsync = true), network) + server = TestUtil.mockBrokenServer() + TestUtil.startStatsigAndWait(app, userJkw, StatsigOptions(loadCacheAsync = true), server = server) assertEquals(EvaluationReason.Cache, Statsig.getConfig("long").getEvaluationDetails().reason) assertEquals(Long.MAX_VALUE, Statsig.getConfig("long").getLong("key", 0L)) assertEquals(Double.MIN_VALUE, Statsig.getConfig("double").getDouble("key", 0.0), 0.0) @@ -391,7 +391,7 @@ class StoreTest { @Test fun testStoreUpdatesOnlyWithUpdatedValues() { val networkTime = 123456789L - var network: StatsigNetwork = TestUtil.mockNetwork( + var server = TestUtil.mockServer( dynamicConfigs = mapOf( "test_config!" to APIDynamicConfig( "test_config!", @@ -403,11 +403,11 @@ class StoreTest { hasUpdates = true, ) val user = StatsigUser("123") - TestUtil.startStatsigAndWait(app, user, StatsigOptions(), network) - coVerify { network.initialize(any(), any(), any(), null, any(), any(), any(), any(), any()) } + TestUtil.startStatsigAndWait(app, user, StatsigOptions(), server = server) + assertEquals(1, server.requestCount) assertEquals(networkTime, Statsig.client.getStore().getLastUpdateTime(user)) assertEquals("first", Statsig.getConfig("test_config").getString("key", "")) - network = TestUtil.mockNetwork( + TestUtil.mockServer( dynamicConfigs = mapOf( "test_config!" to APIDynamicConfig( "test_config!", @@ -418,9 +418,8 @@ class StoreTest { time = networkTime - 1, hasUpdates = false, ) - Statsig.client.statsigNetwork = network runBlocking { Statsig.updateUser(user) } - coVerify { network.initialize(any(), any(), any(), networkTime, any(), any(), any(), any(), any()) } + assertEquals(1, server.requestCount) assertEquals(networkTime, Statsig.client.getStore().getLastUpdateTime(user)) assertEquals("first", Statsig.getConfig("test_config").getString("key", "")) } diff --git a/src/test/java/com/statsig/androidsdk/TestUtil.kt b/src/test/java/com/statsig/androidsdk/TestUtil.kt index 3629561..52a7efa 100644 --- a/src/test/java/com/statsig/androidsdk/TestUtil.kt +++ b/src/test/java/com/statsig/androidsdk/TestUtil.kt @@ -3,17 +3,35 @@ package com.statsig.androidsdk import android.app.Application import android.content.SharedPreferences import com.google.gson.Gson +import com.google.gson.GsonBuilder +import com.google.gson.JsonArray +import com.google.gson.JsonElement +import com.google.gson.JsonObject +import com.google.gson.JsonPrimitive +import com.google.gson.JsonSerializationContext +import com.google.gson.JsonSerializer +import com.google.gson.TypeAdapter +import com.google.gson.TypeAdapterFactory +import com.google.gson.annotations.SerializedName +import com.google.gson.reflect.TypeToken +import com.google.gson.stream.JsonReader +import com.google.gson.stream.JsonWriter import io.mockk.* import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.TestCoroutineDispatcher import kotlinx.coroutines.test.setMain +import okhttp3.mockwebserver.Dispatcher +import okhttp3.mockwebserver.MockResponse +import okhttp3.mockwebserver.MockWebServer +import okhttp3.mockwebserver.RecordedRequest import org.junit.Assert -import java.io.IOException +import java.lang.reflect.Type import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit + @OptIn(ExperimentalCoroutinesApi::class) class TestUtil { companion object { @@ -198,7 +216,7 @@ class TestUtil { } @JvmName("startStatsigAndWait") - internal fun startStatsigAndWait(app: Application, user: StatsigUser = StatsigUser("jkw"), options: StatsigOptions = StatsigOptions(), network: StatsigNetwork? = null) = runBlocking { + internal fun startStatsigAndWait(app: Application, user: StatsigUser = StatsigUser("jkw"), options: StatsigOptions = StatsigOptions(), server: MockWebServer? = null) = runBlocking { val countdown = CountDownLatch(1) val callback = object : IStatsigCallback { override fun onStatsigInitialize() { @@ -211,8 +229,8 @@ class TestUtil { } Statsig.client = StatsigClient() - if (network != null) { - Statsig.client.statsigNetwork = network + if (server != null) { + useServer(Statsig.client, server) } Statsig.initializeAsync(app, "client-apikey", user, callback, options) countdown.await(1L, TimeUnit.SECONDS) @@ -237,15 +255,6 @@ class TestUtil { return mockk() } - @JvmName("captureLogs") - internal fun captureLogs(network: StatsigNetwork, onLog: ((LogEventData) -> Unit)? = null) { - coEvery { - network.apiPostLogs(any(), any(), any()) - } answers { - onLog?.invoke(Gson().fromJson(thirdArg(), LogEventData::class.java)) - } - } - fun stubAppFunctions(app: Application): TestSharedPreferences { val sharedPrefs = TestSharedPreferences() @@ -293,66 +302,126 @@ class TestUtil { } } - @JvmName("mockBrokenNetwork") - internal fun mockBrokenNetwork(): StatsigNetwork { - val statsigNetwork = mockk() - coEvery { - statsigNetwork.apiRetryFailedLogs(any(), any()) - } answers { - throw IOException("Example exception in StatsigNetwork apiRetryFailedLogs") - } - - coEvery { - statsigNetwork.initialize(any(), any(), any(), any(), any(), any(), any(), any(), any()) - } answers { - throw IOException("Example exception in StatsigNetwork initialize") - } - - coEvery { - statsigNetwork.addFailedLogRequest(any()) - } answers { - throw IOException("Example exception in StatsigNetwork addFailedLogRequest") - } - - coEvery { - statsigNetwork.apiPostLogs(any(), any(), any()) - } answers { - throw IOException("Example exception in StatsigNetwork apiPostLogs") + @JvmName("mockBrokenServer") + internal fun mockBrokenServer(): MockWebServer { + var server = MockWebServer() + server.apply { + dispatcher = object : Dispatcher() { + @Throws(InterruptedException::class) + override fun dispatch(request: RecordedRequest): MockResponse { + return MockResponse().setResponseCode(404) + } + } } - return statsigNetwork + useServer(server = server) + return server } - @JvmName("mockNetwork") - internal fun mockNetwork( + @JvmName("mockServer") + internal fun mockServer( featureGates: Map = dummyFeatureGates, dynamicConfigs: Map = dummyDynamicConfigs, layerConfigs: Map = dummyLayerConfigs, time: Long? = null, hasUpdates: Boolean = true, - captureUser: ((StatsigUser) -> Unit)? = null, - ): StatsigNetwork { - val statsigNetwork = mockk() - - coEvery { - statsigNetwork.apiRetryFailedLogs(any(), any()) - } returns Unit - - coEvery { - statsigNetwork.initialize(any(), any(), any(), any(), any(), any(), any(), any(), any()) - } coAnswers { - captureUser?.invoke(thirdArg()) - makeInitializeResponse(featureGates, dynamicConfigs, layerConfigs, time, hasUpdates) + onLog: ((LogEventData) -> Unit)? = null, + getInitializeResponse: ((InitializeRequestBody) -> InitializeResponse)? = null + ): MockWebServer { + var server = MockWebServer() + server.apply { + dispatcher = object : Dispatcher() { + @Throws(InterruptedException::class) + override fun dispatch(request: RecordedRequest): MockResponse { + when (request.path) { + "/v1/initialize" -> { + val requestBody = request.body.readUtf8() + var response: InitializeResponse = makeInitializeResponse(featureGates, dynamicConfigs, layerConfigs, time, hasUpdates) + if (getInitializeResponse != null) { + response = getInitializeResponse.invoke(Gson().fromJson(requestBody, InitializeRequestBody::class.java)) + } + val type = object : TypeToken>() {}.type + val gson = GsonBuilder().registerTypeAdapter( + type, PolymorphicSerializer() + ).create() + var stringified = gson.toJson(response) + return MockResponse().setResponseCode(200).setBody(stringified) + } + "/v1/log_event" -> { + val requestBody = request.body.readUtf8() + onLog?.invoke(Gson().fromJson(requestBody, LogEventData::class.java)) + return MockResponse().setResponseCode(200) + } + } + return MockResponse().setResponseCode(200) + } + } } + useServer(server = server) + return server + } - coEvery { - statsigNetwork.addFailedLogRequest(any()) - } answers {} + internal fun mockClientWithServer(client: StatsigClient? = null, server: MockWebServer): StatsigClient { + var mockClient: StatsigClient = if (client == null) spyk() else spyk(client) + every { + mockClient.options + } answers { + callOriginal().apply { api = server.url("/v1").toString() } + } + return mockClient + } - coEvery { - statsigNetwork.apiPostLogs(any(), any(), any()) - } answers {} + internal fun useServer(client: StatsigClient? = null, server: MockWebServer) { + Statsig.client = mockClientWithServer(client ?: Statsig.client, server) + } - return statsigNetwork + // Because Gson can't handle serializing polymorphic objects in Java + // we need a custom serializer to handle nested object values in APIDynamicConfig + internal class PolymorphicSerializer : JsonSerializer { + override fun serialize( + src: Any?, + typeOfSrc: Type?, + context: JsonSerializationContext? + ): JsonElement { + return src.let { + if (it is Map<*, *>) { + var obj = JsonObject() + for (item in it) { + val key = item.key + val value = item.value + if (key is String) { + obj.add(key, serialize(value, null, null)) + } + } + return@let obj + } else if (it is Map<*, *>) { + var obj = JsonObject() + for (item in it) { + val key = item.key + val value = item.value + if (key is String) { + obj.add(key, serialize(value, null, null)) + } + } + return@let obj + } else if (it is String) { + return@let JsonPrimitive(it) + } else if (it is Number) { + return@let JsonPrimitive(it) + } else if (it is Boolean) { + return@let JsonPrimitive(it) + } else if (it is Char) { + return@let JsonPrimitive(it) + } else if (it is Array<*>) { + var arr = JsonArray() + for (item in it) { + arr.add(serialize(item, null, null)) + } + return@let arr + } else { + return@let JsonPrimitive(it.toString()) + } + } + } } } }