Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 82 additions & 25 deletions packages/wallet/core/src/signers/session-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
SessionSigner,
SessionSignerInvalidReason,
isImplicitSessionSigner,
isIncrementCall,
UsageLimit,
} from './session/index.js'

Expand Down Expand Up @@ -130,21 +131,16 @@ export class SessionManager implements SapientSigner {
}))
}

async findSignersForCalls(wallet: Address.Address, chainId: number, calls: Payload.Call[]): Promise<SessionSigner[]> {
// Only use signers that match the topology
const topology = await this.topology
const identitySigners = SessionConfig.getIdentitySigners(topology)
if (identitySigners.length === 0) {
throw new Error('Identity signers not found')
}

// Prioritize implicit signers
const availableSigners = [...this._implicitSigners, ...this._explicitSigners]
if (availableSigners.length === 0) {
throw new Error('No signers match the topology')
}

// Find supported signers for each call
/**
* Find one signer per call from the given candidate list (first that supports each call).
*/
private async findSignersForCallsWithCandidates(
wallet: Address.Address,
chainId: number,
calls: Payload.Call[],
topology: SessionConfig.SessionsTopology,
availableSigners: SessionSigner[],
): Promise<SessionSigner[]> {
const signers: SessionSigner[] = []
for (const call of calls) {
let supported = false
Expand Down Expand Up @@ -173,9 +169,67 @@ export class SessionManager implements SapientSigner {
if (expiredSupportedSigner) {
throw new Error(`Signer supporting call is expired: ${expiredSupportedSigner.address}`)
}
throw new Error(
`No signer supported for call. ` + `Call: to=${call.to}, data=${call.data}, value=${call.value}, `,
)
throw new Error(`No signer supported for call. Call: to=${call.to}, data=${call.data}, value=${call.value}, `)
}
}
return signers
}

async findSignersForCalls(wallet: Address.Address, chainId: number, calls: Payload.Call[]): Promise<SessionSigner[]> {
const topology = await this.topology
const identitySigners = SessionConfig.getIdentitySigners(topology)
if (identitySigners.length === 0) {
throw new Error('Identity signers not found')
}

const availableSigners = [...this._implicitSigners, ...this._explicitSigners]
if (availableSigners.length === 0) {
throw new Error('No signers match the topology')
}

const nonIncrementCalls: Payload.Call[] = []
const incrementCalls: Payload.Call[] = []
for (const call of calls) {
if (isIncrementCall(call, this.address)) {
incrementCalls.push(call)
} else {
nonIncrementCalls.push(call)
}
}

// Find signers for non-increment calls
const nonIncrementSigners =
nonIncrementCalls.length > 0
? await this.findSignersForCallsWithCandidates(wallet, chainId, nonIncrementCalls, topology, availableSigners)
: []

let incrementSigners: SessionSigner[] = []
if (incrementCalls.length > 0) {
// Find signers for increment calls, preferring signers that signed non-increment calls
const incrementCandidates = [
...nonIncrementSigners,
...availableSigners.filter((s) => !nonIncrementSigners.includes(s)),
]
incrementSigners = await this.findSignersForCallsWithCandidates(
wallet,
chainId,
incrementCalls,
topology,
incrementCandidates,
)
}

// Merge back in original call order
const signers: SessionSigner[] = []
let nonIncrementIndex = 0
let incrementIndex = 0
for (const call of calls) {
if (isIncrementCall(call, this.address)) {
signers.push(incrementSigners[incrementIndex]!)
incrementIndex++
} else {
signers.push(nonIncrementSigners[nonIncrementIndex]!)
nonIncrementIndex++
}
}
return signers
Expand All @@ -191,20 +245,23 @@ export class SessionManager implements SapientSigner {
}
const signers = await this.findSignersForCalls(wallet, chainId, calls)

// Create a map of signers to their associated calls
const signerToCalls = new Map<SessionSigner, Payload.Call[]>()
// Map each signer to only their non-increment calls
const signerToNonIncrementCalls = new Map<SessionSigner, Payload.Call[]>()
signers.forEach((signer, index) => {
const call = calls[index]!
const existingCalls = signerToCalls.get(signer) || []
signerToCalls.set(signer, [...existingCalls, call])
if (isIncrementCall(call, this.address)) {
return
}
const existing = signerToNonIncrementCalls.get(signer) || []
signerToNonIncrementCalls.set(signer, [...existing, call])
})

// Prepare increments for each explicit signer with their associated calls
// Prepare increments for each explicit signer from their non-increment calls only
const increments: UsageLimit[] = (
await Promise.all(
Array.from(signerToCalls.entries()).map(async ([signer, associatedCalls]) => {
Array.from(signerToNonIncrementCalls.entries()).map(async ([signer, nonIncrementCalls]) => {
if (isExplicitSessionSigner(signer)) {
return signer.prepareIncrements(wallet, chainId, associatedCalls, this.address, this._provider!)
return signer.prepareIncrements(wallet, chainId, nonIncrementCalls, this.address, this._provider!)
}
return []
}),
Expand Down
8 changes: 2 additions & 6 deletions packages/wallet/core/src/signers/session/explicit.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Constants, Payload, Permission, SessionConfig, SessionSignature } from '@0xsequence/wallet-primitives'
import { AbiFunction, AbiParameters, Address, Bytes, Hash, Hex, Provider } from 'ox'
import { MemoryPkStore, PkStore } from '../pk/index.js'
import { ExplicitSessionSigner, SessionSignerValidity, UsageLimit } from './session.js'
import { ExplicitSessionSigner, isIncrementCall, SessionSignerValidity, UsageLimit } from './session.js'

export type ExplicitParams = Omit<Permission.SessionPermissions, 'signer'>

Expand Down Expand Up @@ -208,11 +208,7 @@ export class Explicit implements ExplicitSessionSigner {
sessionManagerAddress: Address.Address,
provider?: Provider.Provider,
): Promise<boolean> {
if (
Address.isEqual(call.to, sessionManagerAddress) &&
Hex.size(call.data) > 4 &&
Hex.isEqual(Hex.slice(call.data, 0, 4), AbiFunction.getSelector(Constants.INCREMENT_USAGE_LIMIT))
) {
if (isIncrementCall(call, sessionManagerAddress)) {
// Can sign increment usage calls
return true
}
Expand Down
12 changes: 10 additions & 2 deletions packages/wallet/core/src/signers/session/session.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Payload, SessionConfig, SessionSignature } from '@0xsequence/wallet-primitives'
import { Address, Hex, Provider } from 'ox'
import { Constants, Payload, SessionConfig, SessionSignature } from '@0xsequence/wallet-primitives'
import { AbiFunction, Address, Hex, Provider } from 'ox'

export type SessionSignerInvalidReason =
| 'Expired'
Expand Down Expand Up @@ -68,3 +68,11 @@ export function isExplicitSessionSigner(signer: SessionSigner): signer is Explic
export function isImplicitSessionSigner(signer: SessionSigner): signer is ImplicitSessionSigner {
return 'identitySigner' in signer
}

export function isIncrementCall(call: Payload.Call, sessionManagerAddress: Address.Address): boolean {
return (
Address.isEqual(call.to, sessionManagerAddress) &&
Hex.size(call.data) >= 4 &&
Hex.isEqual(Hex.slice(call.data, 0, 4), AbiFunction.getSelector(Constants.INCREMENT_USAGE_LIMIT))
)
}
2 changes: 2 additions & 0 deletions packages/wallet/core/test/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { Abi, AbiEvent, Address } from 'ox'
const envFile = process.env.CI ? '.env.test' : '.env.test.local'
dotenvConfig({ path: envFile })

// Contracts are deployed on Arbitrum

// Requires https://example.com redirectUrl
export const EMITTER_ADDRESS1: Address.Address = '0xad90eB52BC180Bd9f66f50981E196f3E996278D3'
// Requires https://another-example.com redirectUrl
Expand Down
Loading