Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 162 additions & 81 deletions cli/tri/src/depin/prove.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,58 @@
use crate::depin::phi_challenge::{compute_epoch_hash, derive_phi_challenge, verify_phi_response};
use crate::depin::phi_challenge::{
compute_epoch_hash, derive_phi_challenge, derive_phi_challenge_v2, verify_phi_response,
verify_phi_response_v2,
};
use crate::depin::types::{AppState, EpochChallengeResponse, ProveRequest, ProveResponse};
use crate::depin::types::sha2_hash;

fn err(reason: &str) -> axum::Json<ProveResponse> {
axum::Json(ProveResponse {
valid: false,
reward_lamports: 0,
epoch_hash: String::new(),
next_challenge: String::new(),
tokens_count: 0,
reason: Some(reason.into()),
})
}

pub async fn post_prove(
axum::extract::State(state): axum::extract::State<std::sync::Arc<tokio::sync::RwLock<AppState>>>,
axum::Json(req): axum::Json<ProveRequest>,
) -> axum::Json<ProveResponse> {
let node_id = match hex::decode(&req.node_id) {
Ok(v) if v.len() == 32 => v,
_ => {
return axum::Json(ProveResponse {
valid: false,
reward_lamports: 0,
epoch_hash: String::new(),
next_challenge: String::new(),
tokens_count: 0,
reason: Some("invalid_node_id".into()),
});
}
_ => return err("invalid_node_id"),
};

if req.version != 1 && req.version != 2 {
return err("unsupported_version");
}

let expected_resp_len = if req.version == 2 { 32 } else { 4 };
let phi_response = match hex::decode(&req.phi_response) {
Ok(v) if v.len() == 4 => v,
_ => {
return axum::Json(ProveResponse {
valid: false,
reward_lamports: 0,
epoch_hash: String::new(),
next_challenge: String::new(),
tokens_count: 0,
reason: Some("invalid_phi_response".into()),
});
}
Ok(v) if v.len() == expected_resp_len => v,
_ => return err("invalid_phi_response"),
};

let challenge = derive_phi_challenge(req.epoch, &node_id);
if !verify_phi_response(&challenge, &phi_response, &node_id) {
return axum::Json(ProveResponse {
valid: false,
reward_lamports: 0,
epoch_hash: String::new(),
next_challenge: String::new(),
tokens_count: 0,
reason: Some("phi_challenge_mismatch".into()),
});
let challenge_ok = if req.version == 2 {
let node_arr: [u8; 32] = match node_id.as_slice().try_into() {
Ok(a) => a,
Err(_) => return err("invalid_node_id"),
};
let resp_arr: [u8; 32] = match phi_response.as_slice().try_into() {
Ok(a) => a,
Err(_) => return err("invalid_phi_response"),
};
let challenge = derive_phi_challenge_v2(req.epoch, &node_arr);
verify_phi_response_v2(&challenge, &resp_arr, &node_arr)
} else {
let challenge = derive_phi_challenge(req.epoch, &node_id);
verify_phi_response(&challenge, &phi_response, &node_id)
};

if !challenge_ok {
return err("phi_challenge_mismatch");
}

let root = match hex::decode(&req.merkle_proof.root) {
Expand All @@ -52,16 +61,7 @@ pub async fn post_prove(
arr.copy_from_slice(&v);
arr
}
_ => {
return axum::Json(ProveResponse {
valid: false,
reward_lamports: 0,
epoch_hash: String::new(),
next_challenge: String::new(),
tokens_count: 0,
reason: Some("merkle_proof_invalid".into()),
});
}
_ => return err("merkle_proof_invalid"),
};

let leaf = match hex::decode(&req.merkle_proof.leaf) {
Expand All @@ -70,16 +70,7 @@ pub async fn post_prove(
arr.copy_from_slice(&v);
arr
}
_ => {
return axum::Json(ProveResponse {
valid: false,
reward_lamports: 0,
epoch_hash: String::new(),
next_challenge: String::new(),
tokens_count: 0,
reason: Some("merkle_proof_invalid".into()),
});
}
_ => return err("merkle_proof_invalid"),
};

let siblings: Vec<[u8; 32]> = req
Expand All @@ -99,50 +90,38 @@ pub async fn post_prove(
.collect();

if siblings.len() != req.merkle_proof.siblings.len() {
return axum::Json(ProveResponse {
valid: false,
reward_lamports: 0,
epoch_hash: String::new(),
next_challenge: String::new(),
tokens_count: 0,
reason: Some("merkle_proof_invalid".into()),
});
return err("merkle_proof_invalid");
}

if !crate::depin::merkle::verify_merkle(&root, &leaf, &siblings, req.merkle_leaf_index) {
return axum::Json(ProveResponse {
valid: false,
reward_lamports: 0,
epoch_hash: String::new(),
next_challenge: String::new(),
tokens_count: 0,
reason: Some("merkle_proof_invalid".into()),
});
return err("merkle_proof_invalid");
}

if !verify_ed25519_signature(&node_id, &phi_response, &req.peer_sample_sig) {
return axum::Json(ProveResponse {
valid: false,
reward_lamports: 0,
epoch_hash: String::new(),
next_challenge: String::new(),
tokens_count: 0,
reason: Some("peer_sample_sig_invalid".into()),
});
if !verify_ed25519_signature(&node_id, &phi_response, &req.peer_sample_sig, req.version) {
return err("peer_sample_sig_invalid");
}

let guard = state.read().await;
let epoch = &guard.epoch;
let reward = epoch.block_reward;

let epoch_hash = compute_epoch_hash(req.epoch, &node_id, &phi_response);
let next = derive_phi_challenge(req.epoch + 1, &node_id);
let next_challenge_hex = if req.version == 2 {
let node_arr: [u8; 32] = match node_id.as_slice().try_into() {
Ok(a) => a,
Err(_) => return err("invalid_node_id"),
};
let next_matrix = derive_phi_challenge_v2(req.epoch + 1, &node_arr);
hex::encode(crate::depin::phi_challenge::pack_gf16_matrix(&next_matrix))
} else {
hex::encode(derive_phi_challenge(req.epoch + 1, &node_id))
};

axum::Json(ProveResponse {
valid: true,
reward_lamports: reward,
epoch_hash: hex::encode(epoch_hash),
next_challenge: hex::encode(next),
next_challenge: next_challenge_hex,
tokens_count: reward / 1000,
reason: None,
})
Expand All @@ -169,7 +148,7 @@ pub async fn get_epoch_challenge(
})
}

fn verify_ed25519_signature(node_id: &[u8], phi_response: &[u8], sig_hex: &str) -> bool {
fn verify_ed25519_signature(node_id: &[u8], phi_response: &[u8], sig_hex: &str, version: u8) -> bool {
let sig_bytes = match hex::decode(sig_hex) {
Ok(v) => v,
Err(_) => return false,
Expand All @@ -178,8 +157,9 @@ fn verify_ed25519_signature(node_id: &[u8], phi_response: &[u8], sig_hex: &str)
return false;
}

let domain: &[u8] = if version == 2 { b"TRI_PROVE_V2" } else { b"TRI_PROVE_V1" };
let mut message = Vec::new();
message.extend_from_slice(b"TRI_PROVE_V1");
message.extend_from_slice(domain);
message.extend_from_slice(node_id);
message.extend_from_slice(phi_response);

Expand Down Expand Up @@ -277,6 +257,7 @@ mod tests {
},
merkle_leaf_index: 0,
peer_sample_sig: hex::encode(sig.to_bytes()),
version: 1,
}
}

Expand Down Expand Up @@ -374,6 +355,7 @@ mod tests {
},
merkle_leaf_index: 2,
peer_sample_sig: hex::encode(sig.to_bytes()),
version: 1,
};
let resp = call_prove(&app, req).await;
assert!(resp.valid, "expected valid proof with 4-leaf merkle tree, got reason: {:?}", resp.reason);
Expand Down Expand Up @@ -403,4 +385,103 @@ mod tests {
}
siblings
}

use crate::depin::phi_challenge::{compute_phi_response_v2, derive_phi_challenge_v2};

fn make_valid_proof_request_v2(epoch: u64) -> (ProveRequest, [u8; 32]) {
let signing_key_bytes = [0xCC; 32];
let signing_key = SigningKey::from_bytes(&signing_key_bytes);
let verifying_key = signing_key.verifying_key();
let node_id = verifying_key.to_bytes();

let challenge = derive_phi_challenge_v2(epoch, &node_id);
let phi_response = compute_phi_response_v2(&challenge);

let leaf = crate::depin::types::sha2_hash(&[&node_id, &phi_response]);
let leaves = vec![leaf];
let root = merkle_root(&leaves);

let mut message = Vec::new();
message.extend_from_slice(b"TRI_PROVE_V2");
message.extend_from_slice(&node_id);
message.extend_from_slice(&phi_response);
let sig = signing_key.sign(&message);

let req = ProveRequest {
node_id: hex::encode(node_id),
epoch,
phi_response: hex::encode(phi_response),
merkle_proof: MerkleProof {
root: hex::encode(root),
leaf: hex::encode(leaf),
siblings: vec![],
},
merkle_leaf_index: 0,
peer_sample_sig: hex::encode(sig.to_bytes()),
version: 2,
};
(req, phi_response)
}

#[tokio::test]
async fn test_v2_e2e_valid_proof() {
let app = build_test_app();
let (req, _) = make_valid_proof_request_v2(0);
let resp = call_prove(&app, req).await;
assert!(resp.valid, "expected valid v2 proof, got reason: {:?}", resp.reason);
assert_eq!(resp.reward_lamports, 50_000_000);
assert_eq!(resp.tokens_count, 50_000);
assert!(resp.reason.is_none());
assert_eq!(resp.next_challenge.len(), 256, "v2 next_challenge is 128 bytes = 256 hex chars");
}

#[tokio::test]
async fn test_v2_e2e_invalid_response_flipped_bit() {
let app = build_test_app();
let (mut req, _) = make_valid_proof_request_v2(0);
let mut bytes = hex::decode(&req.phi_response).unwrap();
bytes[0] ^= 0x01;
req.phi_response = hex::encode(&bytes);
let resp = call_prove(&app, req).await;
assert!(!resp.valid);
assert_eq!(resp.reason.as_deref(), Some("phi_challenge_mismatch"));
}

#[tokio::test]
async fn test_v2_e2e_wrong_response_length() {
let app = build_test_app();
let (mut req, _) = make_valid_proof_request_v2(0);
req.phi_response = hex::encode([0xAA; 4]);
let resp = call_prove(&app, req).await;
assert!(!resp.valid);
assert_eq!(resp.reason.as_deref(), Some("invalid_phi_response"));
}

#[tokio::test]
async fn test_v2_e2e_unsupported_version() {
let app = build_test_app();
let (mut req, _) = make_valid_proof_request_v2(0);
req.version = 99;
let resp = call_prove(&app, req).await;
assert!(!resp.valid);
assert_eq!(resp.reason.as_deref(), Some("unsupported_version"));
}

#[test]
fn test_v2_kat_pinned_response() {
let mut node_id = [0u8; 32];
for (i, b) in node_id.iter_mut().enumerate() {
*b = (i + 1) as u8;
}
let challenge = derive_phi_challenge_v2(42, &node_id);
let response = compute_phi_response_v2(&challenge);
let hex_resp = hex::encode(response);
assert_eq!(hex_resp.len(), 64);
let challenge_again = derive_phi_challenge_v2(42, &node_id);
let response_again = compute_phi_response_v2(&challenge_again);
assert_eq!(response, response_again, "v2 KAT must be deterministic");
let other = derive_phi_challenge_v2(43, &node_id);
let other_resp = compute_phi_response_v2(&other);
assert_ne!(response, other_resp, "different epoch must yield different response");
}
}
6 changes: 6 additions & 0 deletions cli/tri/src/depin/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ pub struct ProveRequest {
pub merkle_proof: MerkleProof,
pub merkle_leaf_index: usize,
pub peer_sample_sig: String,
#[serde(default = "default_version")]
pub version: u8,
}

fn default_version() -> u8 {
1
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
Loading
Loading