Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: sort out inference code after refactor for finetune #9

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src.ts/broker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export async function createZGComputeNetworkBroker(
try {
const ledger = await createLedgerBroker(signer, ledgerCA)
// TODO: Adapts the usage of the ledger broker to initialize the inference broker.
const inferenceBroker = await createInferenceBroker(signer, inferenceCA)
const inferenceBroker = await createInferenceBroker(signer, inferenceCA, ledger)
const fineTuningBroker = await createFineTuningBroker(
signer,
fineTuningCA,
Expand Down
62 changes: 6 additions & 56 deletions src.ts/inference/broker/account.ts
Original file line number Diff line number Diff line change
@@ -1,73 +1,23 @@
import { ZGServingUserBrokerBase } from './base'
import { genKeyPair } from '../../common/settle-signer'
import { AddressLike } from 'ethers'
import { encryptData, privateKeyToStr } from '../../common/utils'
import {ZGServingUserBrokerBase} from './base'
import {genKeyPair} from '../../common/settle-signer'
import {AddressLike} from 'ethers'
import {encryptData, privateKeyToStr} from '../../common/utils'

/**
* AccountProcessor contains methods for creating, depositing funds, and retrieving 0G Serving Accounts.
*/
export class AccountProcessor extends ZGServingUserBrokerBase {
async getAccount(provider: AddressLike) {
try {
const account = await this.contract.getAccount(provider)
return account
return await this.contract.getAccount(provider)
} catch (error) {
throw error
}
}

async listAccount() {
try {
const accounts = await this.contract.listAccount()
return accounts
} catch (error) {
throw error
}
}

async addAccount(providerAddress: string, balance: number) {
try {
try {
const account = await this.getAccount(providerAddress)
if (account) {
throw new Error(
'Account already exists, with balance: ' +
this.neuronToA0gi(account.balance) +
' A0GI'
)
}
} catch (error) {
if (!(error as any).message.includes('AccountNotExists')) {
throw error
}
}

const { settleSignerPublicKey, settleSignerEncryptedPrivateKey } =
await this.createSettleSignerKey()

await this.contract.addAccount(
providerAddress,
settleSignerPublicKey,
this.a0giToNeuron(balance),
settleSignerEncryptedPrivateKey
)
} catch (error) {
throw error
}
}

async deleteAccount(provider: AddressLike) {
try {
await this.contract.deleteAccount(provider)
} catch (error) {
throw error
}
}

async depositFund(providerAddress: string, balance: number) {
try {
const amount = this.a0giToNeuron(balance).toString()
await this.contract.depositFund(providerAddress, amount)
return await this.contract.listAccount()
} catch (error) {
throw error
}
Expand Down
39 changes: 27 additions & 12 deletions src.ts/inference/broker/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export abstract class ZGServingUserBrokerBase {
constructor(
contract: InferenceServingContract,
metadata: Metadata,
cache: Cache
cache: Cache,
) {
this.contract = contract
this.metadata = metadata
Expand All @@ -33,7 +33,7 @@ export abstract class ZGServingUserBrokerBase {
protected async getService(
providerAddress: string,
svcName: string,
useCache = true
useCache = true,
): Promise<ServiceStructOutput> {
const key = providerAddress + svcName
const cachedSvc = await this.cache.getItem(key)
Expand All @@ -47,7 +47,7 @@ export abstract class ZGServingUserBrokerBase {
key,
svc,
1 * 60 * 1000,
CacheValueTypeEnum.Service
CacheValueTypeEnum.Service,
)
return svc
} catch (error) {
Expand All @@ -58,13 +58,13 @@ export abstract class ZGServingUserBrokerBase {
protected async getExtractor(
providerAddress: string,
svcName: string,
useCache = true
useCache = true,
): Promise<Extractor> {
try {
const svc = await this.getService(
providerAddress,
svcName,
useCache
useCache,
)
const extractor = this.createExtractor(svc)
return extractor
Expand Down Expand Up @@ -112,20 +112,19 @@ export abstract class ZGServingUserBrokerBase {
const integerPart = value / divisor
const remainder = value % divisor
const decimalPart = Number(remainder) / Number(divisor)

return Number(integerPart) + decimalPart
}

async getHeader(
providerAddress: string,
svcName: string,
content: string,
outputFee: bigint
outputFee: bigint,
): Promise<ServingRequestHeaders> {
try {
const extractor = await this.getExtractor(providerAddress, svcName)
const { settleSignerPrivateKey } = await this.getProviderData(
providerAddress
providerAddress,
)
const key = `${this.contract.getUserAddress()}_${providerAddress}`

Expand All @@ -134,7 +133,7 @@ export abstract class ZGServingUserBrokerBase {
const account = await this.contract.getAccount(providerAddress)
const privateKeyStr = await decryptData(
this.contract.signer,
account.additionalInfo
account.additionalInfo,
)
privateKey = strToPrivateKey(privateKeyStr)
this.metadata.storeSettleSignerPrivateKey(key, privateKey)
Expand All @@ -149,11 +148,11 @@ export abstract class ZGServingUserBrokerBase {
nonce.toString(),
fee.toString(),
this.contract.getUserAddress(),
providerAddress
providerAddress,
)
const settleSignature = await signData(
[request],
privateKey as PackedPrivkey
privateKey as PackedPrivkey,
)
const sig = JSON.stringify(Array.from(settleSignature[0]))

Expand All @@ -172,10 +171,26 @@ export abstract class ZGServingUserBrokerBase {
}
}

private async calculateInputFees(extractor: Extractor, content: string) {
async calculateInputFees(extractor: Extractor, content: string) {
const svc = await extractor.getSvcInfo()
const inputCount = await extractor.getInputCount(content)
const inputFee = BigInt(inputCount) * svc.inputPrice
return inputFee
}

getCachedFeeKey(provider: string, svcName: string) {
return provider + '_' + svcName + '_cachedFee'
}

async updateCachedFee(provider: string, svcName: string, fee: bigint) {
try {
const curFee = await this.cache.getItem(
this.getCachedFeeKey(provider, svcName)) || BigInt(0)
await this.cache.setItem(provider, BigInt(curFee) + fee,
1 * 60 * 1000, CacheValueTypeEnum.Service)
} catch (error) {
throw error
}
}

}
92 changes: 37 additions & 55 deletions src.ts/inference/broker/broker.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import { AccountStructOutput, InferenceServingContract } from '../contract'
import { JsonRpcSigner, Wallet } from 'ethers'
import { RequestProcessor } from './request'
import { ResponseProcessor } from './response'
import { Verifier } from './verifier'
import { AccountProcessor } from './account'
import { ModelProcessor } from './model'
import { Metadata } from '../../common/storage'
import { Cache } from '../storage'
import {AccountStructOutput, InferenceServingContract, ServiceStructOutput} from '../contract'
import {JsonRpcSigner, Wallet} from 'ethers'
import {RequestProcessor} from './request'
import {ResponseProcessor} from './response'
import {Verifier} from './verifier'
import {AccountProcessor} from './account'
import {ModelProcessor} from './model'
import {Metadata} from '../../common/storage'
import {Cache} from '../storage'
import {LedgerBroker} from '../../ledger'

export class InferenceBroker {
public requestProcessor!: RequestProcessor
Expand All @@ -17,10 +18,12 @@ export class InferenceBroker {

private signer: JsonRpcSigner | Wallet
private contractAddress: string
private ledger: LedgerBroker

constructor(signer: JsonRpcSigner | Wallet, contractAddress: string) {
constructor(signer: JsonRpcSigner | Wallet, contractAddress: string, ledger: LedgerBroker) {
this.signer = signer
this.contractAddress = contractAddress
this.ledger = ledger
}

async initialize() {
Expand All @@ -37,12 +40,8 @@ export class InferenceBroker {
)
const metadata = new Metadata()
const cache = new Cache()
this.requestProcessor = new RequestProcessor(contract, metadata, cache)
this.responseProcessor = new ResponseProcessor(
contract,
metadata,
cache
)
this.requestProcessor = new RequestProcessor(contract, metadata, cache, this.ledger)
this.responseProcessor = new ResponseProcessor(contract, metadata, cache)
this.accountProcessor = new AccountProcessor(contract, metadata, cache)
this.modelProcessor = new ModelProcessor(contract, metadata, cache)
this.verifier = new Verifier(contract, metadata, cache)
Expand All @@ -62,28 +61,6 @@ export class InferenceBroker {
}
}

/**
* Adds a new account to the contract.
*
* @param {string} providerAddress - The address of the provider for whom the account is being created.
* @param {number} balance - The initial balance to be assigned to the new account. Units are in A0GI.
*
* @throws An error if the account creation fails.
*
* @remarks
* When creating an account, a key pair is also created to sign the request.
*/
public addAccount = async (providerAddress: string, balance: number) => {
try {
return await this.accountProcessor.addAccount(
providerAddress,
balance
)
} catch (error) {
throw error
}
}

/**
* Retrieves the account information for a given provider address.
*
Expand All @@ -103,21 +80,6 @@ export class InferenceBroker {
}
}

/**
* Deposits a specified amount of funds into the given account.
*
* @param {string} account - The account identifier where the funds will be deposited.
* @param {string} amount - The amount of funds to be deposited. Units are in A0GI.
* @throws An error if the deposit fails.
*/
public depositFund = async (account: string, amount: number) => {
try {
return await this.accountProcessor.depositFund(account, amount)
} catch (error) {
throw error
}
}

/**
* Generates request metadata for the provider service.
* Includes:
Expand Down Expand Up @@ -356,6 +318,25 @@ export class InferenceBroker {
throw error
}
}

/**
* retrieve fund from all inference account back to ledger
*/
public retrieveFund = async () => {
try {
const ledger = await this.ledger.getLedger()
let retrieveProviders = []
for (const provider of ledger.inferenceProviders) {
const acc = await this.getAccount(provider)
if (acc.balance > 0) {
retrieveProviders.push(provider)
}
}
await this.ledger.retrieveFund(retrieveProviders, "inference")
} catch (error) {
throw error
}
}
}

/**
Expand All @@ -370,9 +351,10 @@ export class InferenceBroker {
*/
export async function createInferenceBroker(
signer: JsonRpcSigner | Wallet,
contractAddress = ''
contractAddress = '',
ledger: LedgerBroker
): Promise<InferenceBroker> {
const broker = new InferenceBroker(signer, contractAddress)
const broker = new InferenceBroker(signer, contractAddress, ledger)
try {
await broker.initialize()
return broker
Expand Down
5 changes: 3 additions & 2 deletions src.ts/inference/broker/model.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { ServiceStructOutput } from '../contract'
import { ZGServingUserBrokerBase } from './base'
import {ServiceStructOutput} from '../contract'
import {ZGServingUserBrokerBase} from './base'

export enum VerifiabilityEnum {
OpML = 'OpML',
Expand All @@ -23,6 +23,7 @@ export class ModelProcessor extends ZGServingUserBrokerBase {
}
}


export function isVerifiability(value: string): value is Verifiability {
return Object.values(VerifiabilityEnum).includes(value as VerifiabilityEnum)
}
Loading