169 changes: 169 additions & 0 deletions MiniCLIP-ViT
@@ -0,0 +1,169 @@
# In this file we build a mini CLIP, with the goal of understanding the architecture
# of the real CLIP, which is the core of AGI

import torch
import torch.nn as nn
import torch.nn.functional as F

#---create patch embeddings---#
class PathEmbedding(nn.Module):

def __init__(self,img_size=32,paths_size=8,in_channels=3,embd_dim=128):
super().__init__()

self.num_paths = (img_size // paths_size) ** 2

self.project = nn.Conv2d(
in_channels,embd_dim,
kernel_size=paths_size,stride=paths_size
)


def forward(self,x):

x = self.project(x)
x = x.flatten(2)
x = x.transpose(1,2)
return x
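    # Shape walk-through (a sketch assuming the defaults img_size=32, paths_size=8, embd_dim=128):
    #   input x:         (B, 3, 32, 32)
    #   after project:   (B, 128, 4, 4)   -> 4x4 grid of patches
    #   after flatten:   (B, 128, 16)
    #   after transpose: (B, 16, 128)     -> 16 patch embeddings of dimension 128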


#---ViTBlock---#
class VitBlock(nn.Module):

def __init__(self,embd_dim,num_heads):
super().__init__()

self.att_head = nn.MultiheadAttention(embd_dim,num_heads,batch_first=True)

        self.mlp = nn.Sequential(
            # Corrected: the MLP expands embd_dim -> 4*embd_dim, then projects back to embd_dim
            nn.Linear(embd_dim, embd_dim * 4),
            nn.GELU(),
            nn.Linear(embd_dim * 4, embd_dim)
        )

self.norm1 = nn.LayerNorm(embd_dim)
self.norm2 = nn.LayerNorm(embd_dim)


    def forward(self, x):
        # Pre-norm self-attention with a residual connection
        h = self.norm1(x)
        att, _ = self.att_head(h, h, h)
        x = x + att
        # Pre-norm MLP with a residual connection
        x = x + self.mlp(self.norm2(x))
        return x


#---create MiniViT---#
class VitImageEncoder(nn.Module):

    # Corrected: num_heads must evenly divide embd_dim, so the default is 4 rather than 3
    def __init__(self, embd_dim=128, path_size=8, num_heads=4, depth=4, img_size=32):
super().__init__()

self.path_embedding = PathEmbedding(img_size,path_size,3,embd_dim)

self.cls = nn.Parameter(torch.randn(1,1,embd_dim))

# Corrected: use num_paths from the instantiated path_embedding
self.pos_path = nn.Parameter(torch.randn(1,1+self.path_embedding.num_paths,embd_dim))

self.blocks = nn.Sequential(
*[VitBlock(embd_dim,num_heads)
for _ in range(depth)]
)

self.norm = nn.LayerNorm(embd_dim)


def forward(self,x):

B = x.size(0)

x = self.path_embedding(x)

cls = self.cls.expand(B,-1,-1)

# Corrected: concatenate cls token at the beginning of the sequence
x = torch.cat((cls, x),dim=1)

x = x + self.pos_path

# Apply blocks
for block in self.blocks:
x = block(x)

out = self.norm(x[:,0]) # Take the class token representation
return F.normalize(out,dim=-1)


#---create TextEncoder---#
class TextEncoder(nn.Module):

def __init__(self,embd_dim,vocab_size):
super().__init__()

self.embedding = nn.Embedding(vocab_size,embd_dim)


def forward(self,x): # x is expected to be a tensor of token indices (batch_size, sequence_length)

embd = self.embedding(x) # Output (batch_size, sequence_length, embd_dim)
# Assuming we want to average embeddings across the sequence length for a single text representation
embd = embd.mean(dim=1) # Output (batch_size, embd_dim)
x = F.normalize(embd,dim=-1)
return x
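# Note: unlike the transformer text encoder in the real CLIP, this mini version simply
# mean-pools token embeddings, so word order is ignored; a deliberate simplification.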


#---create MiniCLIP---#
class MiniCLIP(nn.Module):

def __init__(self,image_encoder_instance,text_encoder_instance,embd_dim):
super().__init__()

self.ViTimage_encoder = image_encoder_instance
self.text_encoder = text_encoder_instance
        # Corrected: learnable temperature, a scalar parameter initialized to log(1/0.07)
        # as in the CLIP paper; logits are scaled by its exponential.
        self.logits_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))

def forward(self,image_input,text_input):

image_features = self.ViTimage_encoder(image_input)
text_features = self.text_encoder(text_input)

        # Both encoders already return L2-normalized features; normalizing again is a harmless safeguard
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

scale = self.logits_scale.exp()
logits = scale * image_features @ text_features.T
return logits


def lossClip(logits):

# Corrected: Missing parenthesis for size(0)
labels = torch.arange(logits.size(0)).to(logits.device)

i_loss = F.cross_entropy(logits,labels)
# Corrected: apply cross_entropy on transposed logits for text-to-image similarity
t_loss = F.cross_entropy(logits.T,labels)

out = (i_loss + t_loss) / 2
return out
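# Why arange labels: for a batch of N matched image-text pairs, logits is an N x N similarity
# matrix whose correct pairings lie on the diagonal, so the target class for row/column i is i
# in both the image-to-text and text-to-image directions.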


#------Create Model------#
embd_dim = 128
vocab_size= 512

# Corrected: Instantiate the encoders before passing them to MiniCLIP
# num_heads=8 so that embd_dim=128 is evenly divisible by the number of attention heads
image_encoder_instance = VitImageEncoder(embd_dim=embd_dim, num_heads=8)  # default img_size, path_size, depth
text_encoder_instance = TextEncoder(embd_dim,vocab_size)

# Corrected: Pass instances to MiniCLIP constructor
model = MiniCLIP(image_encoder_instance,text_encoder_instance,embd_dim)
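

#------Sanity check (sketch)------#
# A minimal, hypothetical smoke test with random pixels and random token indices, just to
# confirm tensor shapes and that the symmetric loss runs. The batch size, sequence length,
# and image size below are illustrative assumptions, not values fixed anywhere in this file.
if __name__ == "__main__":
    batch_size, seq_len = 4, 16
    dummy_images = torch.randn(batch_size, 3, 32, 32)                   # (B, C, H, W)
    dummy_tokens = torch.randint(0, vocab_size, (batch_size, seq_len))  # (B, L) token ids

    logits = model(dummy_images, dummy_tokens)  # (B, B) image-text similarity matrix
    loss = lossClip(logits)
    print(logits.shape, loss.item())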