アプリ開発備忘録

PlayStationMobile、Android、UWPの開発備忘録

【JVM/Kotlin】サーバー間でgRPC接続が特定インスタンスに偏る問題を修正する

サーバー間でのやりとりにgRPC接続を使用していますが、1つのインスタンスに接続が偏ってしまいました。これを解消します。

現在の実装

grpc-javaのバージョン1.53.0を使用しています。
https://github.com/grpc/grpc-java

現在、クライアント側は以下のような実装になっています。1つのインスタンスだけにアクセスが行き、高負荷で死ぬと次のインスタンスに切り替わるといった挙動になっていました。

import io.grpc.ManagedChannel

private val channel: ManagedChannel = ManagedChannelBuilder
    .forAddress(host, port)
    .usePlaintext()
    .build()

続いて、ロードバランスポリシーを設定してみましたが、スケールアウトしたものにアクセスが行きませんでした。

import io.grpc.ManagedChannel

private val channel: ManagedChannel = ManagedChannelBuilder
    .forAddress(host, port)
    .defaultLoadBalancingPolicy("round_robin")
    .usePlaintext()
    .build()

ロードバランサRoundRobinLoadBalancerの実装はこちら。
https://github.com/grpc/grpc-java/blob/136665f00ee74e5eabe6c846894e00b748cfb253/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java#L73

修正

NameResolver

RoundRobinLoadBalancer の実装を見ると、DnsNameResolverListener2.onResult() を呼ぶとコネクションが更新されていました。定期的にDNSを確認し、IPに変更があったらこれを呼びます。

DnsNameResolverを参考に実装していきます。
https://github.com/grpc/grpc-java/blob/136665f00ee74e5eabe6c846894e00b748cfb253/core/src/main/java/io/grpc/internal/DnsNameResolver.java

DnsNameResolver では java.net.InetAddress.getAllByName() が使用されていましたが、うまく更新されていなかったので dnsjava を使用しました。ttlとかそのあたり調べれば、別に java.net.InetAddress でも良さそう。
https://github.com/dnsjava/dnsjava

定期的にDNSのAレコードを監視してIPを更新するようにします。

import io.grpc.EquivalentAddressGroup
import io.grpc.NameResolver
import io.grpc.Status
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.xbill.DNS.Name
import org.xbill.DNS.Type
import org.xbill.DNS.lookup.LookupSession
import java.net.InetSocketAddress
import java.net.URI
import kotlin.time.Duration.Companion.minutes

internal class ARecordResolver(
    name: String,
) : NameResolver() {
    private val host: String
    private val port: Int
    private val authority: String

    private var listener: Listener2? = null
    private var shutdown = false

    private var resolvedResult: List<String> = listOf()
    private val coroutineScope = CoroutineScope(Dispatchers.Default)

    init {
        val nameUri = URI.create("//$name")
        host = nameUri.host
        port = nameUri.port
        authority = nameUri.authority

        coroutineScope.launch {
            while (true) {
                delay(0.5.minutes)
                withContext(Dispatchers.IO) {
                    resolve()
                }
            }
        }
    }

    override fun getServiceAuthority(): String = authority

    override fun start(listener: Listener2?) {
        listener ?: return
        this.listener = listener
        resolve()
    }

    override fun shutdown() {
        if (shutdown) return
        coroutineScope.cancel()
        shutdown = true
    }

    override fun refresh() {
        resolve()
    }

    private fun resolve() {
        if (shutdown) return
        val listener = listener ?: return
        coroutineScope.launch {
            runCatching {
                val newResult = getARecord().sorted()
                if (newResult != resolvedResult) {
                    println("IP Changed: $host=$newResult")
                    resolvedResult = newResult
                    listener.onResult(
                        ResolutionResult.newBuilder()
                            .setAddresses(
                                newResult.map { ip ->
                                    EquivalentAddressGroup(InetSocketAddress(ip, port))
                                },
                            )
                            .build(),
                    )
                }
            }.onFailure { e ->
                listener.onError(
                    Status.UNAVAILABLE
                        .withDescription("Unable to resolve host $host")
                        .withCause(e),
                )
            }
        }
    }

    private fun getARecord(): List<String> {
        return LookupSession.defaultBuilder().build()
            .lookupAsync(
                Name.fromString(host),
                Type.A,
            )
            .toCompletableFuture()
            .get()
            .records
            .map { it.rdataToString() }
    }
}

Provider

Providerを作成します。priorityはデフォルトの5より小さくしました。

import com.google.common.base.Preconditions
import io.grpc.NameResolver
import io.grpc.NameResolverProvider
import java.net.URI

public class ARecordResolverProvider : NameResolverProvider() {
    override fun newNameResolver(targetUri: URI?, args: NameResolver.Args?): NameResolver? {
        if (targetUri?.scheme == SCHEME) {
            val targetPath = targetUri.path
            Preconditions.checkArgument(
                targetPath.startsWith("/"),
                "the path component (%s) of the target (%s) must start with '/'",
                targetPath,
                targetUri,
            )
            val name: String = targetPath.substring(1)
            return ARecordResolver(
                name = name,
            )
        }
        return null
    }

    override fun getDefaultScheme(): String = SCHEME

    override fun isAvailable(): Boolean = true

    override fun priority(): Int = 1

    public companion object {
        private const val SCHEME = "a"
    }
}

後は最初に呼びます。

NameResolverRegistry
    .getDefaultRegistry()
    .register(ARecordResolverProvider())

channelを作成する部分はforTargetに変えます。Providerで指定した形式のスキームを設定します。

import io.grpc.ManagedChannel

private val channel: ManagedChannel = ManagedChannelBuilder
    .forTarget("a:///$host:$port")
    .defaultLoadBalancingPolicy("round_robin")
    .usePlaintext()
    .build()

おわりに

モジュール分割している場合は ManagedChannel を作成しているモジュールで io.grpc:grpc-core が無いと動きませんでした。
これで複数のインスタンスにアクセスが行きます。