Skip to content

Commit

Permalink
Refactoring.
Browse files Browse the repository at this point in the history
Signed-off-by: Ralph Gasser <[email protected]>
  • Loading branch information
ppanopticon committed Nov 25, 2024
1 parent 1f0d2f0 commit e483891
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package org.vitrivr.engine.module.torchserve.client

import com.google.auth.oauth2.AccessToken
import com.google.auth.oauth2.OAuth2Credentials
import io.grpc.Grpc
import io.grpc.InsecureChannelCredentials
import io.grpc.ManagedChannel
import io.grpc.auth.MoreCallCredentials
import java.io.Closeable
import java.util.*

/**
* An abstract client for connecting to a TorchServe instance via gRPC
*
* @author Ralph Gasser
* @version 1.0.0
*/
abstract class AbstractTorchServeClient(val host: String, val port: Int = 8080, private val token: String? = null) : Closeable {

/** Credentials used for connecting to TorchServe. */
protected val credentials = this.token?.let {
MoreCallCredentials.from(OAuth2Credentials.create(AccessToken(it, Date(Long.MAX_VALUE))))
}

/** The [ManagedChannel] used for gRPC communication. */
protected val channel: ManagedChannel by lazy { Grpc.newChannelBuilderForAddress(this.host, this.port, InsecureChannelCredentials.create()).build() }

/**
* Closes this [AbstractTorchServeClient].
*/
override fun close() {
this.channel.shutdownNow()
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,33 +1,18 @@
package org.vitrivr.engine.module.torchserve.client

import com.google.auth.oauth2.AccessToken
import com.google.auth.oauth2.OAuth2Credentials
import com.google.protobuf.ByteString
import io.grpc.Grpc
import io.grpc.InsecureChannelCredentials
import io.grpc.auth.MoreCallCredentials
import org.pytorch.serve.grpc.inference.InferenceAPIsServiceGrpc
import org.pytorch.serve.grpc.inference.predictionsRequest
import java.io.Closeable
import java.util.*


/**
* A client for connecting to a TorchServe instance and sending inference requests.
*
* @author Luca Rossetto
* @author Ralph Gasser
* @version 1.0.03
* @version 1.0.3
*/
class InferenceClient(val host: String, val port: Int = 8080, private val token: String? = null) : Closeable {

/** Credentials used for connecting to TorchServe. */
private val credentials = this.token?.let {
MoreCallCredentials.from(OAuth2Credentials.create(AccessToken(it, Date(Long.MAX_VALUE))))
}

/** */
private val channel by lazy { Grpc.newChannelBuilderForAddress(this.host, this.port, InsecureChannelCredentials.create()).build() }
class InferenceClient(host: String, port: Int = 8080, token: String? = null) : AbstractTorchServeClient(host, port, token) {

/** The stub used to communicate with TorchServe. */
private val blockingStub by lazy { InferenceAPIsServiceGrpc.newBlockingStub(this.channel) }
Expand Down Expand Up @@ -55,11 +40,4 @@ class InferenceClient(val host: String, val port: Int = 8080, private val token:
)
return response.prediction
}

/**
* Closes this [InferenceClient].
*/
override fun close() {
this.channel.shutdownNow()
}
}
Original file line number Diff line number Diff line change
@@ -1,38 +1,24 @@
package org.vitrivr.engine.module.torchserve.client

import com.google.auth.oauth2.AccessToken
import com.google.auth.oauth2.OAuth2Credentials
import io.grpc.Grpc
import io.grpc.InsecureChannelCredentials
import io.grpc.ManagedChannel
import io.grpc.auth.MoreCallCredentials
import org.pytorch.serve.grpc.management.ManagementAPIsServiceGrpc
import org.pytorch.serve.grpc.management.describeModelRequest
import org.pytorch.serve.grpc.management.listModelsRequest
import java.util.*

class ManagementClient(private val host: String, private val port: Int = 7071, private val token: String? = null) {
/**
* A client for connecting to a TorchServe instance and sending management requests via gRPC.
*
* @author Luca Rossetto
* @author Ralph Gasser
* @version 1.0.3
*/
class ManagementClient(host: String, port: Int = 7071, token: String? = null) : AbstractTorchServeClient(host, port, token) {

private val credentials = token?.let {
MoreCallCredentials.from(
OAuth2Credentials.create(
AccessToken(
"Bearer: $it",
Date(Long.MAX_VALUE)
)
)
)
}
private lateinit var channel: ManagedChannel
private lateinit var stub: ManagementAPIsServiceGrpc.ManagementAPIsServiceBlockingStub
fun connect() {

this.channel = Grpc.newChannelBuilderForAddress(host, port, InsecureChannelCredentials.create()).build()
this.stub = ManagementAPIsServiceGrpc.newBlockingStub(channel)

}


/**
*
*/
fun listModels(): String { //TODO do some parsing?

val response = this.stub.let {
Expand All @@ -48,6 +34,9 @@ class ManagementClient(private val host: String, private val port: Int = 7071, p

}

/**
*
*/
fun describeModel(name: String, version: String? = null): String {

return this.stub.let {
Expand Down

0 comments on commit e483891

Please sign in to comment.