Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion backend/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import asyncio
from enum import Enum
from typing import Dict, List, Optional, TypedDict, Any

import io
from PIL import Image
import pytesseract
from langgraph.graph import StateGraph, END

from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search
Expand Down Expand Up @@ -493,7 +495,49 @@ def reset_session(self, session_id: str):
self.chat_history.pop(session_id, None)
self.session_memory.pop(session_id, None)

async def extract_from_image(self, image_bytes: bytes, mime_type: str) -> str:
    """
    Extract clean scientific search terms from an uploaded image.

    Runs Tesseract OCR on the image bytes, then asks the Gemini model to
    distill the (often noisy) OCR output into a comma-separated list of
    neuroscience search terms.

    Args:
        image_bytes: Raw bytes of the uploaded image.
        mime_type: MIME type reported by the client. Currently unused;
            kept for interface stability with callers.

    Returns:
        The cleaned search terms, or a human-readable error string.
        This method never raises: failures are reported in the return
        value so callers can surface them directly to the user.
    """
    try:
        # 1. Decode the bytes into an image; raises if they are not a
        #    valid image format.
        img = Image.open(io.BytesIO(image_bytes))

        # 2. OCR is CPU-bound — run it in a worker thread so the event
        #    loop stays responsive.
        raw_text = await asyncio.to_thread(pytesseract.image_to_string, img)

        if not raw_text.strip():
            return "The image appears to be empty or unreadable."

        # 3. Use Gemini to "clean" the messy OCR text (OCR of scientific
        #    papers often produces typos and stray characters).
        client = _get_genai_client()
        clean_prompt = (
            "Extract ONLY the scientific search terms from this OCR text. "
            "Return the terms as a comma-separated list. "
            "Do NOT include explanations, introductions, or extra text.\n\n"
            f"OCR Text: {raw_text}"
        )

        # Low temperature: we want deterministic extraction, not prose.
        cfg = genai_types.GenerateContentConfig(
            temperature=0.1,
            max_output_tokens=256,
        )

        # generate_content is synchronous, blocking network I/O — offload
        # it to a thread (like the OCR call above) so this async method
        # does not stall the event loop for other requests.
        resp = await asyncio.to_thread(
            lambda: client.models.generate_content(
                model=FLASH_LITE_MODEL,
                contents=[clean_prompt],
                config=cfg,
            )
        )

        # Fall back to the raw OCR text if the model returned nothing.
        raw_output = (resp.text or raw_text).strip()

        # Strip markdown artifacts (bold markers, list bullets) that the
        # model sometimes emits despite the instructions.
        clean_text = raw_output.replace("**", "")
        clean_text = "\n".join(
            line.lstrip("-* ").strip() for line in clean_text.splitlines()
        )
        return clean_text

    except Exception as e:
        # Never propagate: the caller treats the return value as the
        # user-facing result, so report the failure as text instead.
        print(f"OCR Error: {e}")
        return f"Error extracting text: {str(e)}"
async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> str:
try:
if reset:
Expand Down
28 changes: 27 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
Expand Down Expand Up @@ -108,6 +108,32 @@ async def health():
"timestamp": datetime.utcnow().isoformat(),
}

@app.post("/api/ocr", tags=["Chat"])
async def ocr_endpoint(file: UploadFile = File(...)):
    """
    Receive an uploaded image, extract neuroscience-related text from it,
    and return the text so the frontend can use it as a chat query.

    Returns:
        JSON object: {"extracted_text": "<cleaned search terms>"}

    Raises:
        HTTPException 400: if the upload is not an image.
        HTTPException 500: if processing fails unexpectedly.
    """
    # Reject non-image uploads server-side. The frontend also validates,
    # but the API must not rely on client-side checks alone.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail="Only image uploads are supported.",
        )

    try:
        # 1. Read the uploaded image bytes into memory.
        image_bytes = await file.read()

        # 2. Delegate OCR + LLM cleanup to the assistant.
        extracted_text = await assistant.extract_from_image(
            image_bytes,
            mime_type=file.content_type,
        )

        return {"extracted_text": extracted_text}

    except HTTPException:
        # Let deliberate HTTP errors pass through instead of
        # re-wrapping them as 500s below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to process image: {str(e)}"
        )



@app.post("/api/chat", response_model=ChatResponse, tags=["Chat"])
async def chat_endpoint(msg: ChatMessage):
Expand Down
161 changes: 138 additions & 23 deletions frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ const App: React.FC = () => {
const [inputValue, setInputValue] = useState('');
const [isLoading, setIsLoading] = useState(false);
const [isOnline, setIsOnline] = useState(true);
const [uploadState, setUploadState] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle');
const [uploadError, setUploadError] = useState<string>('');
const chatContainerRef = useRef<HTMLDivElement>(null);

useEffect(() => {
Expand Down Expand Up @@ -75,6 +77,81 @@ Try asking me something like:
});
};

const handleImageUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
  const file = e.target.files?.[0];
  if (!file) return;

  // Shared rejection path for the client-side validation guards below:
  // flash the error state for 3 s and clear the file input.
  const rejectUpload = (message: string) => {
    setUploadState('error');
    setUploadError(message);
    setTimeout(() => setUploadState('idle'), 3000);
    e.target.value = '';
  };

  // Guard: images only.
  if (!file.type.startsWith('image/')) {
    rejectUpload('Only image files are supported.');
    return;
  }

  // Guard: cap uploads at 10 MB.
  if (file.size > 10 * 1024 * 1024) {
    rejectUpload('Image must be smaller than 10 MB.');
    return;
  }

  setUploadState('uploading');
  setUploadError('');
  setIsLoading(true);

  const payload = new FormData();
  payload.append('file', file);

  try {
    const response = await fetch('/api/ocr', {
      method: 'POST',
      body: payload,
    });

    if (!response.ok) {
      // Prefer the server's error detail; fall back to the status code.
      const body = await response.json().catch(() => ({}));
      throw new Error(body?.detail || `Server error ${response.status}`);
    }

    const data = await response.json();

    if (!data.extracted_text || !data.extracted_text.trim()) {
      throw new Error('No readable text found in the image.');
    }

    // Place the extracted text into the input box for the user to send.
    setInputValue(data.extracted_text);
    setUploadState('success');
    // Reset success indicator after 2 s
    setTimeout(() => setUploadState('idle'), 2000);
  } catch (error: unknown) {
    console.error('Upload error:', error);
    const msg = error instanceof Error ? error.message : 'Failed to extract text from the image.';
    setUploadState('error');
    setUploadError(msg);
    const failureNotice: Message = {
      id: Date.now().toString(),
      type: 'error',
      content: `Image upload failed: ${msg} Please try a clearer screenshot.`,
      timestamp: new Date()
    };
    setMessages(prev => [...prev, failureNotice]);
    // Reset error indicator after 3 s
    setTimeout(() => setUploadState('idle'), 3000);
  } finally {
    setIsLoading(false);
    // Clear the file input so the same file can be re-selected later.
    e.target.value = '';
  }
};





const sendMessage = async () => {
if (!inputValue.trim() || isLoading) return;

Expand Down Expand Up @@ -188,29 +265,67 @@ Try asking me something like:
{/* Input Area */}
<footer className="input-section">
<div className="input-container">
<div className="input-wrapper">
<input
type="text"
className="message-input"
placeholder="Ask about neuroscience datasets, brain imaging data, or research topics..."
value={inputValue}
onChange={(e) => setInputValue(e.target.value)}
onKeyPress={handleKeyPress}
disabled={isLoading}
/>
<button
className={`send-button ${isLoading || !inputValue.trim() ? 'disabled' : ''}`}
type="button"
onClick={sendMessage}
disabled={isLoading || !inputValue.trim()}
>
{isLoading ? (
<i className="fas fa-spinner fa-spin"></i>
) : (
<i className="fas fa-paper-plane"></i>
)}
</button>
</div>
<div className="input-wrapper" style={{ alignItems: 'flex-end' }}>
<input
type="file" id="image-upload" accept="image/*" hidden
onChange={handleImageUpload} disabled={isLoading || uploadState === 'uploading'}
/>

<label
htmlFor="image-upload"
className={`action-btn upload-btn upload-btn--${uploadState} ${(isLoading || uploadState === 'uploading') ? 'disabled' : ''}`}
title={
uploadState === 'uploading' ? 'Extracting text…' :
uploadState === 'success' ? 'Text extracted!' :
uploadState === 'error' ? uploadError || 'Upload failed' :
'Upload an image to extract search terms'
}
aria-label="Upload image for OCR"
>
{uploadState === 'uploading' && <i className="fas fa-spinner fa-spin"></i>}
{uploadState === 'success' && <i className="fas fa-check"></i>}
{uploadState === 'error' && <i className="fas fa-exclamation-triangle"></i>}
{uploadState === 'idle' && <i className="fas fa-paperclip"></i>}
</label>

{/* Dynamic Textarea */}
<textarea
className="message-input"
placeholder="Type or upload an image..."
value={inputValue}
rows={1}
onChange={(e) => {
setInputValue(e.target.value);
// Reset height to calculate correctly
e.target.style.height = 'inherit';
// Set new height based on scrollHeight, capped at 150px
e.target.style.height = `${Math.min(e.target.scrollHeight, 150)}px`;
}}
onKeyDown={(e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendMessage();
// Reset height after sending
(e.target as HTMLTextAreaElement).style.height = 'inherit';
}
}}
style={{
resize: 'none',
overflowY: inputValue.split('\n').length > 5 ? 'auto' : 'hidden',
minHeight: '44px',
maxHeight: '150px'
}}
disabled={isLoading}
/>

<button
className={`send-button ${isLoading || !inputValue.trim() ? 'disabled' : ''}`}
onClick={sendMessage}
disabled={isLoading || !inputValue.trim()}
>
{isLoading ? <i className="fas fa-spinner fa-spin"></i> : <i className="fas fa-paper-plane"></i>}
</button>
</div>
<div className="input-footer">
<i className="fas fa-info-circle"></i>
<span>Powered by INCF KnowledgeSpace API - Neuroscience datasets</span>
Expand Down
35 changes: 35 additions & 0 deletions frontend/src/styles.css
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,41 @@ html, body {
transform: translateY(-1px);
}

/* Upload button state variants.
   State classes (upload-btn--idle/uploading/success/error) are applied
   from the React component based on the OCR upload lifecycle. */

/* Base: transparent border reserves space so the colored state borders
   below do not shift the layout when they appear. */
.upload-btn {
border: 2px solid transparent;
transition: background 0.2s ease, color 0.2s ease, border-color 0.2s ease, transform 0.2s ease;
position: relative;
}

/* Default resting state: muted gray paperclip. */
.upload-btn--idle {
color: #64748b;
}

/* Upload in flight: purple tint; pointer-events disabled so a second
   upload cannot start while one is running. */
.upload-btn--uploading {
background: rgba(102, 126, 234, 0.12);
color: #667eea;
cursor: wait;
pointer-events: none;
}

/* Text extracted successfully: green tint + border. */
.upload-btn--success {
background: rgba(16, 185, 129, 0.12);
color: #10b981;
border-color: rgba(16, 185, 129, 0.4);
}

/* Upload or extraction failed: red tint + border. */
.upload-btn--error {
background: rgba(239, 68, 68, 0.12);
color: #ef4444;
border-color: rgba(239, 68, 68, 0.4);
}

/* Applied while chat is busy: dim the control and block clicks. */
.upload-btn.disabled {
pointer-events: none;
opacity: 0.5;
}

/* Chat Container */
.chat-container {
flex: 1;
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ dependencies = [
"langgraph>=0.6.4",
"matplotlib>=3.10.3",
"pandas>=2.3.1",
"pillow>=12.1.0",
"pytesseract>=0.3.13",
"python-multipart>=0.0.21",
"requests>=2.32.4",
"scikit-learn>=1.7.0",
"sentence-transformers>=3.0.0",
Expand Down