Training the CNN

Tangles

  • The timeline for these tangles has ended.
  • No more updates from this org.
  • Development has slowed.
  • The details are preserved here for historical reasons.
  • The name of each tangle has been preserved, but tangling is no longer possible.

Functions

functions.py

<<imports>>
<<fn_imports>>
<<functions>>

if __name__ == "__main__" :
  <<main_imports>>
  <<function_tests>>

  pass

Networks

networks.py

<<imports>>
<<nw_imports>>

<<networks>>

if __name__ == "__main__" :
  <<main_imports>>
  <<network_test>>
  pass

Losses

losses.py

<<imports>>
<<loss_imports>>

<<losses>>

if __name__ == "__main__" :
  <<main_imports>>
  <<loss_test>>
  pass

Trainer

trainer.py

<<imports>>
<<tr_imports>>

<<trainer>>

if __name__ == "__main__" :
  <<main_imports>>
  <<tr_test>>
  pass

Reporter

reporter.py

<<imports>>

<<report_imports>>

<<reporter>>

if __name__ == '__main__' :
  <<main_imports>>
  <<report_test>>
  pass

Data

dataset.py

<<cd_include>>
<<cd_class>>
<<cd_functions>>
if __name__ == '__main__' :
  <<cd_main>>

Imports

import logging as lg
lg.basicConfig(level=lg.DEBUG, format="%(levelname)-8s: %(message)s")

import numpy as np
import yajl
from argparse import Namespace

Options

def load_options(options_file) :
  with open(options_file, 'r') as J :
    options = yajl.load(J)

  options = Namespace(**options)
  return options

Usage:

options_file = "/home/bvr/code/rivet/sample_options.json"
options = load_options(options_file)

lg.info(options)

Distance Functions

Pytorch has:

  • Euclidean Distance: torch.nn.PairwiseDistance
  • Cosine Similarity: torch.nn.CosineSimilarity
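
A minimal sketch (not part of the tangles) showing both modules on random feature batches; the shapes are illustrative only.

import torch

# Two batches of 4 feature vectors, 128 dims each.
a = torch.randn(4, 128)
b = torch.randn(4, 128)

pdist = torch.nn.PairwiseDistance(p=2)  # Euclidean distance per row
cos = torch.nn.CosineSimilarity(dim=1)  # cosine similarity per row

print(pdist(a, b).size())  # torch.Size([4])
print(cos(a, b).size())    # torch.Size([4])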

Loss Functions

Pytorch has:

  • L1 Loss: torch.nn.L1Loss
  • Classical mean squared error loss: torch.nn.MSELoss
  • Cross Entropy: torch.nn.CrossEntropyLoss
  • Negative Log Likelihood Loss: torch.nn.NLLLoss
  • KL Divergence Loss: torch.nn.KLDivLoss
  • Binary Cross Entropy: torch.nn.BCELoss
  • Binary Cross Entropy (with logits): torch.nn.BCEWithLogitsLoss
  • Triplet Loss: torch.nn.TripletMarginLoss
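
For orientation, a small sketch (not part of the tangles) of torch.nn.TripletMarginLoss on random embeddings; the batch size and embedding width are illustrative.

import torch

triplet = torch.nn.TripletMarginLoss(margin=1.0, p=2)

anchor   = torch.randn(8, 128, requires_grad=True)
positive = torch.randn(8, 128)
negative = torch.randn(8, 128)

loss = triplet(anchor, positive, negative)
loss.backward()            # gradients flow back into anchor
print(loss.item())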

The train.py, as earlier

import logging as lg
lg.basicConfig(level=lg.INFO, format='%(levelname)-8s: %(message)s')

import argparse
import os
import shutil
import time
import math
import yajl

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader

from combinations_dataset import combinations_dataset as CD
from contrastive_loss import ContrastiveLoss

from helpers import resnet_siamese, bvr_normalize, siamese_accuracy
from helpers import inspect_tensor, adjust_learning_rate, AverageMeter
from helpers import save_checkpoint, validate, train, fix_transfer
from helpers import transfer_weights

from parser2 import parser

best_prec1 = 0

def main():
  global best_prec1
  args = parser.parse_args()

  lg.info('Entered main. Args:\n  %s', str(vars(args)).replace(", '", "\n   '"))

  # model here
  fc_layers = [128, 1]
  model = resnet_siamese(
    torchvision.models.resnet18(), fc_layers).cuda()
  lg.info('Created model resnet18:\n  %s',
          str(model).replace('\n', '\n   '))

  # define loss function (criterion) and optimizer
  criterion = ContrastiveLoss(margin=args.spring_margin)

  optimizer = torch.optim.SGD(model.parameters(), args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)

  accuracy = siamese_accuracy(margin=args.spring_margin)

  lg.info('Defined loss function and optimizer.')

  # optionally resume from a checkpoint
  # if args.resume:
  #   if os.path.isfile(args.resume):
  #     lg.info("Loading checkpoint '{}'".format(args.resume))
  #     checkpoint = torch.load(args.resume)
  #     args.start_epoch = checkpoint['epoch']
  #     best_prec1 = checkpoint['best_prec1']
  #     model.load_state_dict(checkpoint['state_dict'])
  #     optimizer.load_state_dict(checkpoint['optimizer'])
  #     lg.info("Loaded checkpoint '{}' (epoch {})"
  #             .format(args.resume, checkpoint['epoch']))
  #   else:
  #     lg.warn("No checkpoint found at '{}'".format(args.resume))
  #     lg.warn('STARTING FROM SCRATCH!!! ')

  # elif args.tr :
  if args.tr :
    lg.info('Transfer Learning...')
    if os.path.isfile(args.tr) :
      lg.info('Loading pretrained weights: %s', args.tr)
      transfer_weights(model, args.tr)
      lg.info('Loaded pretrained weights: %s', args.tr)
    else :
      lg.warn('No checkpoint found at `%s\'', args.tr)
      lg.warn('STARTING FROM SCRATCH!!! ')

  cudnn.benchmark = True

  ## Data Transformations 
  T = transforms.Compose([
    transforms.Grayscale(3),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.Resize(224),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: 255 - x)
  ])

  # ## Data Loading and Testing
  # if args.evaluate:

  #   with open(args.data, 'r') as F :
  #     val_loader = yajl.load(F)['test']

  #   val_loader = DataLoader(
  #     ds_sketches(val_loader, transform=T, fac_pos=args.prob_similar),
  #     batch_size = args.batch_size,
  #     shuffle = False,
  #     num_workers = args.num_workers
  #   )

  #   validate(val_loader, model, criterion, accuracy, args)
  #   return

  ## Data Loading for Training 
  with open(args.combinations_json, 'r') as J :
    combinations_list = yajl.load(J)
  lg.info('Loaded combinations_list: size:%s',
          len(combinations_list))

  with open(args.images_list_json, 'r') as J :
    images_list = yajl.load(J)[args.images_list_key]
  lg.info('Loaded image_list: size:%s', len(images_list))

  train_loader = DataLoader(
    CD(combinations_list, images_list,
       transform=T),# fac_pos=args.prob_similar),
    batch_size = args.batch_size,
    shuffle = True,
    num_workers = args.num_workers
  )
  # val_loader = DataLoader(
  #   ds_sketches(data['val'], transform=T, fac_pos=args.prob_similar),
  #   batch_size = args.batch_size,
  #   shuffle = False,
  #   num_workers = args.num_workers
  # )

  ## Checkpoint Setup
  save_path = os.path.join(args.save_dir, args.save_filename)
  save_path_best = os.path.join(args.save_dir, 'model_best.pth.tar')
  lg.info('Saving training checkpoints at: %s', save_path)

  ## Training and Validation
  for epoch in range(args.start_epoch, args.epochs):
    # if args.distributed:
    #   train_sampler.set_epoch(epoch)
    adjust_learning_rate(optimizer, epoch, args.lr, args.lr_decay)

    # fix_transfer
    if args.tr and args.tr_fix > 0 :
      fix_transfer(model, epoch, args.tr_fix)

    # train for one epoch
    prec1 = train(train_loader, model, criterion, optimizer, accuracy, epoch, args)

    # evaluate on validation set
    # prec1 = validate(val_loader, model, criterion, accuracy, args)

    # remember best prec@1 and save checkpoint
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint(
      { 'epoch': epoch + 1,
        'arch': 'siamese(resnet18+FC%s)' % (tuple(fc_layers),),
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer' : optimizer.state_dict(),
      },
      is_best,
      filename=save_path,
      save_path_best=save_path_best)
    lg.info('Saving Checkpoint... Done')

if __name__ == '__main__':
  main()

Pytorch Identity Module

from torch.nn import Module
class Identity(Module) :
  def forward(self, inputs) :
    return inputs

lg.info(Identity())
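
A hedged usage sketch: dropping Identity in for resnet.fc so the backbone returns its raw 512-d features. Assumes torchvision is available; not part of the tangles.

import torch
import torchvision.models as models

resnet = models.resnet18()
resnet.fc = Identity()            # bypass the classifier head

x = torch.randn(1, 3, 224, 224)   # dummy input
print(resnet(x).size())           # torch.Size([1, 512])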
 

The Networks

class resnet_base(nn.Module) :
  def __init__(self, resnet, weights_file, fc=Identity) :
    super().__init__()
    self.resnet = resnet
    self.fc = fc
    self.weights_file = weights_file
    self.load_weights()
    self.assign_fc()

  def load_weights(self) :
    weights = torch.load(self.weights_file)
    pretrained = weights['state_dict']
    pretrained = {k: pretrained[k]
                  for k in pretrained
                  if 'fc' not in k}

    resnet = models.resnet18()
    resnet.load_state_dict(pretrained, strict=False)
    resnet.eval()
    resnet.cuda()

    self.resnet = resnet

  def assign_fc(self) :
    fc_out = self.fc
    fc_in = [512] + fc_out[:-1]
    self.resnet.fc = torch.nn.Sequential(*[
      torch.nn.Linear(in_size, out_size)
      for in_size, out_size in zip(fc_in, fc_out)
    ])

class resnet_feature_pair(resnet_base) :
  def __init__(self, *args, **kwargs) :
    super().__init__(*args, **kwargs)

  def forward(self, inputs) :
    x1, x2 = inputs
    return self.resnet(x1), self.resnet(x2)

class resnet_feature_triple(nn.Module) :
  def __init__(self, *args, **kwargs) :
    super().__init__()
    self.network_setup(*args, **kwargs)

  def network_setup(self, *args, **kwargs) :
    self.resnet = resnet_feature_pair(*args, **kwargs).resnet

  def forward(self, inputs) :
    x, x_plus, x_minus = inputs
    return (self.resnet(x),
            self.resnet(x_plus),
            self.resnet(x_minus))

class resnet_concat_pair(resnet_base) :
  def __init__(self, *args, **kwargs) :
    super().__init__(*args, **kwargs)

  def forward(self, inputs) :
    pass

class resnet_concat_triple(nn.Module) :
  def __init__(self, *args, **kwargs) :
    super().__init__()
    self.network_setup(*args, **kwargs)

  def network_setup(self, *args, **kwargs) :
    pass

  def forward(self, inputs) :
    pass

The Networks2

import torch
from torch import nn
class pair_feat(nn.Module) :
  def __init__(self, network, **kwargs) :
    super().__init__()
    self.network = network

  def forward(self, inputs) :
    x1, x2 = inputs
    return torch.stack((self.network(x1), self.network(x2)))

class triple_feat(nn.Module) :
  def __init__(self, network, **kwargs) :
    super().__init__()
    self.network = network

  def forward(self, inputs) :
    x, x_pos, x_neg = inputs
    return self.network(x), self.network(x_pos), self.network(x_neg)

class pair_concat(nn.Module) :
  def __init__(self, network, fc,
               feat_length=512) :
    super().__init__()
    self.network = network
    # feat_length is the input size of the first FC layer,
    # i.e. the length of the concatenated feature vector.
    fc_out = fc
    fc_in = [feat_length] + fc_out[:-1]
    self.fc = nn.Sequential(*[
      nn.Linear(in_size, out_size)
      for in_size, out_size in zip(fc_in, fc_out)
    ])

  def forward(self, inputs) :
    x1, x2 = inputs
    y1, y2 = self.network(x1), self.network(x2)
    new_x = torch.cat((y1, y2), dim=1)
    return self.fc(new_x)

class triple_concat(nn.Module) :
  def __init__(self, *args, **kwargs) :
    super().__init__()
    self.network = pair_concat(*args, **kwargs)

  def forward(self, inputs) :
    x, x_pos, x_neg = inputs
    y_pos_hat = self.network((x, x_pos))
    y_neg_hat = self.network((x, x_neg))

    return (y_pos_hat, y_neg_hat)
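
A minimal sketch of wiring up pair_feat, assuming a resnet18 backbone whose fc has been replaced with the Identity module above; names and shapes are illustrative.

import torch
import torchvision.models as models

backbone = models.resnet18()
backbone.fc = Identity()            # 512-d features per image

siamese = pair_feat(backbone)

x1 = torch.randn(2, 3, 224, 224)
x2 = torch.randn(2, 3, 224, 224)
y = siamese((x1, x2))               # stacked pair of feature batches
print(y.size())                     # torch.Size([2, 2, 512])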

The Resnet

import torch
import torchvision.models as models
from torch import nn
def pretrained_resnet(weights_file, fc=Identity()) :
  weights = torch.load(weights_file)
  pretrained = weights['state_dict']
  pretrained = {k: pretrained[k]
                  for k in pretrained
                  if 'fc' not in k}

  resnet = models.resnet18()
  resnet.load_state_dict(pretrained, strict=False)
  resnet.fc = fc
  resnet.eval()
  resnet.cuda()

  return resnet

def create_fc(layers=[128, 1]) :
  if layers is None :
    return Identity()

  fc_out = layers
  fc_in = [512] + fc_out[:-1]

  return nn.Sequential(*[
    nn.Linear(in_size, out_size)
    for (in_size, out_size) in zip(fc_in, fc_out)
  ])
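
A hedged sketch of combining the two helpers above; the checkpoint path is a placeholder and a CUDA device is assumed, as in pretrained_resnet itself.

# Hypothetical checkpoint path; the helper expects a dict with a 'state_dict' key.
weights_file = 'checkpoints/weights.pth.tar'

# resnet18 backbone with the pretrained (non-fc) weights and a small FC head.
resnet = pretrained_resnet(weights_file, fc=create_fc([128, 1]))

x = torch.randn(1, 3, 224, 224).cuda()
print(resnet(x).size())   # torch.Size([1, 1])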

The Losses

import logging as lg

import torch
import torch.nn.functional as F

class DistancePowerN(torch.nn.PairwiseDistance) :
  def __init__(self, power=2, **kwargs) :
    super().__init__(**kwargs)
    self.power = power
    # lg.info("DistancePowerN isinstance torch.nn.Module: %s",
    #         isinstance(self, torch.nn.Module))

  def forward(self, x1, x2) :
    d = super().forward(x1, x2)

    if self.power != 1 :
      d = torch.pow(d, self.power)

    return d

## Copied from this gist:
#  https://gist.github.com/harveyslash/725fcc68df112980328951b3426c0e0b

## Modified to include a distance measure

class ContrastiveLoss(torch.nn.Module):
  """
  Contrastive loss function.
  Based on:
  http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

  The distance measure is parametrised here. By default uses 
  squared euclidean distance. Another alternative may be 
  torch.nn.KLDivLoss - The KLDivergence as the distance metric.

  Labels are binary {0, 1}; 0 means similar, and 1 means dissimilar.

  """

  def __init__(self, margin=4.0,
               distance_fn=DistancePowerN(2)):
    super().__init__()

    # lg.info("ContrastiveLoss isinstance torch.nn.Module: %s", 
    #         isinstance(self, torch.nn.Module))

    self.distance_fn = distance_fn
    self.margin = margin

  def forward(self, outputs, label):
    # lg.info("len(outputs): %s", len(outputs))
    distance = self.distance_fn(*outputs)
    loss_contrastive = torch.mean(
      (1-label) * distance + 
      (label) * torch.clamp(self.margin - distance, min=0.0))

    return loss_contrastive

  def interpret(self, _Y) :
    d = self.distance_fn(*_Y)
    # lg.info("ContrastiveLoss.interpret: d.size(): %s", d.size())
    return (d > self.margin).float()

class TripletLoss(ContrastiveLoss) :
  '''
  A sum of Losses over each pair of positive and negative pairs

  The pair loss e.g. ContrastiveLoss
  '''

  def __init__(self, *args,
               label={
                 'pos': 0,
                 'neg': 1
               },
               **kwargs):
    super().__init__(*args, **kwargs)
    self.label=label

  def forward(self, output, label=None) :
    if label is None :
      label = (self.label['pos'], self.label['neg'])

    y_pos, y_neg = label
    _y, _y_pos, _y_neg = output

    return (super().forward((_y, _y_pos), y_pos)
            + super().forward((_y, _y_neg), y_neg))

def npy_var(x, volatile=False, cuda=True) :

  X = torch.from_numpy(x)
  if cuda :
    X = X.cuda()

  return torch.autograd.Variable(X, volatile=volatile)

class BCETripletLoss(torch.nn.BCELoss) :
  def __init__(self, *args,
               label=[
                 np.array([1, 0], dtype=np.float32),
                 np.array([0, 1], dtype=np.float32)
               ], **kwargs) :
    super().__init__(*args, **kwargs)
    self.label = [npy_var(l) for l in label]

  def forward(self, _Y, Y=None) :
    if Y is None:
      Y = self.label

    Y_pos, Y_neg = Y
    _Y_pos, _Y_neg = _Y

    return (super().forward(_Y_pos, Y_pos)
            + super().forward(_Y_neg, Y_neg))
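
A small, hedged sketch exercising ContrastiveLoss on random embedding pairs; batch size, dimensions, and labels are synthetic.

criterion = ContrastiveLoss(margin=4.0)

y1 = torch.randn(8, 128, requires_grad=True)
y2 = torch.randn(8, 128)
label = torch.randint(0, 2, (8,)).float()   # 0: similar, 1: dissimilar

loss = criterion((y1, y2), label)
loss.backward()
print(loss.item(), criterion.interpret((y1, y2)))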

The Data

Pairwise

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

import random
from graph_algo import uniq_edges
import numpy as np

class pairwise_dataset(Dataset) :
  '''Uses image_list and adjacency_list for similar pairs. For each
  image in similar_pair, randomly generates a dissimilar pair. (1:2)
  positive to negative samples.

  Labels may be initialized as an ordered pair: 0: similar, 1: dissimilar
  '''

  def __init__(self, adjacency, image_list,
               labels=[0, 1],
               transform = None,
               dissimilar_fn = None) :

    self.adjacency = adjacency
    self.image_list = image_list
    self.labels = torch.tensor(labels).float()

    self.transform = transform

    self.dissimilar = dissimilar_fn
    if self.dissimilar is None :
      self.dissimilar = self.find_dissimilar

    self.init_pairs()

  def init_pairs(self) :
    pairs = uniq_edges(self.adjacency) #gives me a numpy array (N, 2)
    flat_pairs = pairs.reshape([-1])
    undef = np.full_like(flat_pairs, -1)
    more_pairs = np.stack([flat_pairs, undef], axis=1)
    self.pairs = np.concatenate([pairs, more_pairs], axis=0)

  def __len__(self):
    return len(self.pairs)

  def __getitem__(self, index) :
    x1, x2 = self.pairs[index]
    y = int(x2 == -1)
    if y != 0 :
      x2 = self.dissimilar(x1)

    y = self.labels[y]

    # lg.info((x1, x2))
    x1 = self.load_image(x1)
    x2 = self.load_image(x2)

    if self.transform :
      x1 = self.transform(x1)
      x2 = self.transform(x2)

    return y, (x1, x2)

  def find_dissimilar(self, index) :
    indices = set((int(i) for i in self.adjacency.keys()))

    # lg.info("indices(%d): %s", len(list(indices)), sorted(list(indices)))
    # lg.info("index: %s", index)
    # lg.info("adjacency(%d): %s", index, self.adjacency[str(index)])

    indices = indices - set(self.adjacency[str(index)] + [int(index)])
    indices = list(indices)
    # lg.info("indices(%d): %s", len(indices), sorted(indices))

    return random.choice(indices)

  def load_image(self, image_index) :
    image_name = self.image_list[image_index]
    return Image.open(image_name)

Triplet

class triplet_dataset(pairwise_dataset) :
  '''Uses image_list and adjacency_list for similar pairs. For each
  image in similar_pair, randomly generates a dissimilar pair. (1:2)
  positive to negative samples.

  '''

  def __init__(self, *args, **kwargs) :
    super().__init__(*args, **kwargs)

  def init_pairs(self) :
    self.pairs = uniq_edges(self.adjacency) #gives me a numpy array (N, 2)

  def __len__(self):
    return 2 * self.pairs.shape[0]

  def __getitem__(self, index) :
    i = index // self.pairs.shape[0]
    index = index % self.pairs.shape[0]

    if i > 0:
      x_pos, x = self.pairs[index]
    else :
      x, x_pos = self.pairs[index]

    x_neg = self.dissimilar(x)

    x = self.load_image(x)
    x_pos = self.load_image(x_pos)
    x_neg = self.load_image(x_neg)

    if self.transform :
      x = self.transform(x)
      x_pos = self.transform(x_pos)
      x_neg = self.transform(x_neg)

    return self.labels, (x, x_pos, x_neg)


Test this (Include test for triplet)

<2018-05-27 Sun 13:12>

# To Test
# -----------------------------------
# combinations_dataset(similar_pairs, image_list,
#                      transform = None,
#                      dissimilar_fn = None)

# Logging:
# -----------------------------------
import logging as lg
lg.basicConfig(level=lg.INFO, format='%(levelname)-8s: %(message)s')

from graph_algo import edge_to_adjacency

# With transforms
# -----------------------------------
from torchvision.transforms import Compose, Grayscale, ToTensor
from torchvision.transforms import Resize, RandomCrop
T = Compose([Grayscale(), Resize(224), RandomCrop(224), ToTensor()])

## Json Loader
# -----------------------------------
import yajl

combinations_json = '/home/bvr/data/pytosine/k_nearest/20180526-153919/combinations.json'
with open(combinations_json, 'r') as J :
  similar_pairs = yajl.load(J)['combinations']
lg.info('Loaded similar pairs: size:%s', len(similar_pairs))

adjacency = edge_to_adjacency(similar_pairs)
# TODO: include edge_to_adjacency before tangle

image_list_json = '/home/bvr/data/pytosine/k_nearest/20180521-205730/image_list.json'
image_list_key = 'image_list'
with open(image_list_json, 'r') as J :
  image_list = yajl.load(J)[image_list_key]
lg.info('Loaded image_list: size:%s', len(image_list))

def test_dataset(dataset_name) :
  global adjacency, image_list, T

  dataset = dataset_name(
    adjacency, image_list,
    transform = T,
    labels=[np.array([1, 0]), np.array([0, 1])])

  dataloader = DataLoader(
    dataset, shuffle=True, batch_size = 64
  )

  for i, (y, x) in enumerate(dataloader) :
    lg.info('sizes: len(y), y[0].size, len(x), x[0].size: %s, %s, %s, %s',
            len(y), y[0].size(), len(x), x[0].size())


test_dataset(pairwise_dataset)
test_dataset(triplet_dataset)

from functools import reduce
import operator
def flatten(inp_list) :
  return reduce(operator.concat, inp_list)

Wrapper

import yajl
class Create(object) :
  def __init__(self, base_module) :
    self.base_module = base_module

  def __call__(self, adjacency, image_list, labels, transform):
    with open(adjacency, 'r') as J :
      adjacency = yajl.load(J)

    with open(image_list, 'r') as J :
      image_list = yajl.load(J)['image_list']

    transforms = {
      'sketch_transform': sketch_transform
    }
    transform = transforms.get(transform, sketch_transform)()

    return self.base_module(adjacency, image_list, labels, transform)
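
A hedged sketch of the Create wrapper around pairwise_dataset; the JSON paths are hypothetical placeholders.

make_pairwise = Create(pairwise_dataset)

# Hypothetical files: adjacency.json maps image index -> neighbour list,
# image_list.json holds {"image_list": [... image paths ...]}.
dataset = make_pairwise('adjacency.json', 'image_list.json',
                        labels=[0, 1], transform='sketch_transform')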

Pre Processing

from torchvision import transforms
def sketch_transform() :
  return transforms.Compose([
    transforms.Grayscale(3),
    transforms.Resize(224),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: 255 - x)
  ])
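
A hedged usage sketch of sketch_transform on a synthetic PIL image; a real sketch file would be used in practice.

from PIL import Image

T = sketch_transform()

# Blank 300x400 RGB image as a stand-in for a real sketch.
img = Image.new('RGB', (300, 400), color='white')
x = T(img)
print(x.size())   # torch.Size([3, 224, 224])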

The Trainer

import torch
from argparse import Namespace
from reporter import BvrReporter
from functions import BvrAccuracy, StopWatch, BvrSaver

stat_dtype = [('index', 'i4', (2,)),
              ('loss', 'f4'),
              ('accuracy', 'f4'),
              ('t_data', 'f4'),
              ('t_batch', 'f4')]

class Trainer :

  var = None
  reporter = None
  saver = None

  def __init__(self, data, network,
               loss, optimizer,
               reporter=BvrReporter(stat_dtype),
               accuracy=BvrAccuracy(), # TODO: Create nn.Module
               saver=BvrSaver(),
               lr_adjuster=None,
               net_adjuster=None,
               var=None,
               mode='train',
               options=Namespace(
                 num_epochs=1,
                 report_frequency=100,
                 save_frequency=1
               )) :
    '''
    To use a python dict for options, use options=Namespace(**py_dict)
    '''
    ## Mandatory
    self.data = data
    self.network = network
    self.loss = loss
    self.optimizer = optimizer

    ## Misc k:v pairs
    self.options = options

    ## Function / Value based Options
    self.accuracy = accuracy
    self.mode=mode
    self.var = var

    ## Class based Options 
    if reporter :
      self.reporter = reporter

    self.lr_modify = lr_adjuster
    self.net_modify = net_adjuster

    ## Stats
    self.stats = np.ndarray((self.options.report_frequency,),
                            dtype=stat_dtype)

    ## Saver
    if saver :
      self.saver = saver

  def is_eval_mode(self) :
    return self.mode == 'eval'

  def is_train_mode(self) :
    return self.mode == 'train'

  def to_var(self, X, vol=None) :
    if vol is None :
      vol = self.is_eval_mode()

    if isinstance(X, torch.Tensor) :
      return torch.autograd.Variable(
        X.cuda(non_blocking=True), volatile=vol
      )

    if isinstance(X, list) :
      return [self.to_var(x, vol) for x in X]

    raise TypeError("X is neither a Tensor nor a list of Tensors.")

  def stat_names(self) :
    return [stat[0] for stat in stat_dtype]

  def train(self) :
    opt = self.options

    for j in range(opt.num_epochs) :

      # adjust learning rate
      if self.lr_modify :
        self.lr_modify.step(j)

      # adjust layerwise training
      if self.net_modify :
        self.net_modify(j)

      stop_watch = StopWatch()
      stop_watch.start()

      for i, data in enumerate(self.data) :
        ii = self.train_1((j, i), data, stop_watch)

      if self.saver :
        i1 = self.reporter.cursor
        i0 = i1 - len(self.data)
        self.saver(self.network, self.reporter.stats, (i0, i1))

  def train_1(self, idx, data, stop_watch) :
    opt = self.options
    i_max = len(self.data)
    i = idx[-1]

    ## Create variables
    if self.var :
      Y, X = self.var(data, self.to_var, self.is_eval_mode())
    else :
      Y, X = data
      Y, X = self.to_var(Y, vol=True), self.to_var(X)

    ## Data Timer
    t_data = stop_watch.record()

    ## Forward Pass
    _Y = self.network(X)

    ## Loss and Accuracy
    loss = self.loss(_Y, Y)
    accuracy = self.accuracy(_Y, Y)

    ## Backward Pass
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    ## Batch Timer
    t_batch = t_data + stop_watch.record()

    ## Record Stats
    s_i = i % opt.report_frequency
    self.stats[s_i] = (
      idx,
      loss.item(),
      accuracy.item(),
      t_data,
      t_batch
    )

    ## Report Stats
    ii = 1 + i
    if ii % opt.report_frequency == 0 or ii == i_max :
      self.reporter.report(self.stats[:s_i + 1])
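
A hedged sketch of assembling the Trainer with toy stand-ins (a random tensor dataset and a tiny network); it assumes a CUDA device, as the Trainer itself does, and none of the values reflect the original experiments.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from argparse import Namespace

# Toy (label, input) pairs; the Trainer unpacks each batch as (Y, X).
Y = torch.randint(0, 2, (64, 1)).float()
X = torch.randn(64, 8)
loader = DataLoader(TensorDataset(Y, X), batch_size=16)

network = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid()).cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)

trainer = Trainer(loader, network, criterion, optimizer,
                  accuracy=BvrAccuracy(transform=lambda y: (y > 0.5).float()),
                  saver=None,
                  options=Namespace(num_epochs=1,
                                    report_frequency=10,
                                    save_frequency=1))
trainer.train()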

The Reporter

import logging as lg
import numpy as np

# stat_dtype = [('index', 'i4', (2,)),
#               ('loss', 'f4'),
#               ('accuracy', 'f4'),
#               ('t_data', 'f4'),
#               ('t_batch', 'f4')]

def is_int(n) :
  try:
    int(str(n))
  except ValueError:
    return False

  return True      

def log_average(id_range, stats) :

  # lg.info(stats.dtype.fields)
  # lg.info(stats.dtype.names)
  # lg.info(stats.dtype.itemsize)

  i0, i1 = id_range
  stats = stats[i0:i1]

  indices = [
    "%s:%s" % (stats[name][0], stats[name][i1-1 - i0])
    for name in stats.dtype.names
    if 'index' in name
  ]

  summary = [
    "%s:(%s %c %s)" % (name,
                       np.mean(stats[name]),
                       chr(177),
                       np.std(stats[name]))
    for name in stats.dtype.names
    if 'index' not in name
  ]

  lg.info("%s %s", ' '.join(indices), ' '.join(summary))


def grapher(id_range, stats) :
  pass

class BvrReporter(object) :
  stats = None
  chunk_size = 1024

  queue = []

  def __init__(self, stat_dtype,
               chunk_size=None,
               queue=(
                 log_average,
               )) :

    if is_int(chunk_size) :
      self.chunk_size = int(str(chunk_size))

    self.stats = np.ndarray((self.chunk_size,), dtype=stat_dtype)
    self.cursor = 0

    self.queue = queue

  def extend(self) :
    self.stats = np.concatenate((self.stats, np.empty_like(self.stats)))

  def report(self, stats) :
    i0, i1 = self.cursor, self.cursor + stats.shape[0]
    if i1 > self.stats.shape[0] :
      self.extend()

    self.stats[i0:i1] = stats
    self.cursor = i1

    for consume in self.queue :
      consume((i0, i1), self.stats)
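
A small, hedged sketch exercising BvrReporter with the same record layout as the Trainer's stat_dtype; all numbers are synthetic.

import numpy as np

stat_dtype = [('index', 'i4', (2,)),
              ('loss', 'f4'),
              ('accuracy', 'f4'),
              ('t_data', 'f4'),
              ('t_batch', 'f4')]

reporter = BvrReporter(stat_dtype)

# Ten synthetic batch records for epoch 0.
batch = np.zeros((10,), dtype=stat_dtype)
batch['index'] = [(0, i) for i in range(10)]
batch['loss'] = np.linspace(1.0, 0.5, 10)
batch['accuracy'] = np.linspace(0.5, 0.8, 10)

reporter.report(batch)   # logs the index range plus mean ± std of each field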

KL Divergence Doc from PyTorch

From a PyTorch Discussion

KL divergence is a useful distance measure for continuous
distributions and is often useful when performing direct regression
over the space of (discretely sampled) continuous output
distributions.

As with NLLLoss, the input given is expected to contain
log-probabilities, however unlike ClassNLLLoss, input is not
restricted to a 2D Tensor, because the criterion is applied
element-wise.

This criterion expects a target Tensor of the same size as the input
Tensor.
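
A hedged sketch of the above: the input to torch.nn.KLDivLoss must be log-probabilities, and the target is a probability tensor of the same size (reduction='batchmean' assumes a reasonably recent PyTorch).

import torch
import torch.nn.functional as F

kl = torch.nn.KLDivLoss(reduction='batchmean')

logits = torch.randn(4, 10, requires_grad=True)
target = torch.softmax(torch.randn(4, 10), dim=1)   # a proper distribution per row

loss = kl(F.log_softmax(logits, dim=1), target)     # input as log-probabilities
loss.backward()
print(loss.item())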

Accuracy

import logging as lg

import torch
from torch.nn import Module

class BvrAccuracy(Module) :
  def __init__(self, transform=None):
    super().__init__()
    self.transform = transform
    # lg.info("accuracy: transform: %s", self.transform)

  def forward(self, _Y, Y) :
    # lg.info("size: _Y: %s", _Y.size())
    # lg.info("size: Y: %s", Y.size())

    if self.transform :
      # lg.info("callable(self.transform): %s", 
      #   callable(self.transform))
      # lg.info("transforming _Y")
      _Y = self.transform(_Y)

    return torch.mean((_Y == Y).float())

Stop Watch

import time
class StopWatch(object) :
  def __init__(self):
    self.start()

  def start(self) :
    self.time = time.time()

  def record(self) :
    old = self.time
    self.time = time.time()
    return self.time - old
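
A tiny usage sketch of StopWatch; the sleeps are only to make the recorded intervals visible.

watch = StopWatch()
time.sleep(0.1)
print('%.3fs' % watch.record())   # ~0.1, time since start()
time.sleep(0.2)
print('%.3fs' % watch.record())   # ~0.2, time since the previous record()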

Saver

import os
import shutil
from datetime import datetime as Dt

class BvrSaver(object) :
  options = Namespace(
      save_frequency = 1,
      save_location = '.',
      saver_current = 'checkpoint.pth.tar',
      saver_best = 'model_best.pth.tar'
  )

  def __init__(self, options=dict()) :
    self.count = 0
    self.best_prec = 0.

    self.update_options(options)

  def update_options(self, options) :

    if isinstance(options, Namespace) :
      options = vars(options)
    self.options = vars(self.options)

    # lg.info(self.options)
    # lg.info(options)
    self.options.update(options)
    # lg.info(self.options)

    self.options = Namespace(**self.options)

    self.update_locations()

  def update_locations(self) :
    self.options.save_location = os.path.join(
      self.options.save_location,
      BvrSaver.timestamp()
    )

    self.options.saver_current = os.path.join(
      self.options.save_location,
      self.options.saver_current
    )

    self.options.saver_best = os.path.join(
      self.options.save_location,
      self.options.saver_best
    )

  @staticmethod
  def timestamp() :
    return Dt.now().strftime('%Y%m%d-%H%M%S')

  def __call__(self, model, stats, idx_range) :
    self.count += 1
    if self.count < self.options.save_frequency :
      return

    if not os.path.isdir(self.options.save_location) :
      os.makedirs(self.options.save_location)

    lg.info("Saving Model")
    torch.save(model, self.options.saver_current)

    i0, i1 = idx_range
    prec = np.mean(stats[i0:i1]['accuracy'])

    lg.info("idx_range, prec, best_prec: %s, %s, %s", idx_range, prec, self.best_prec)
    if prec > self.best_prec :
      self.count = 0
      self.best_prec = prec
      shutil.copyfile(self.options.saver_current,
                      self.options.saver_best)
      lg.info("Saving best model")
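
A hedged sketch of BvrSaver with a toy model and synthetic stats; files are written under a timestamped directory inside the given save_location, which is a placeholder here.

import numpy as np
from torch import nn

saver = BvrSaver({'save_location': '/tmp/bvr_checkpoints'})

model = nn.Linear(8, 1)
stats = np.zeros((10,), dtype=[('accuracy', 'f4')])
stats['accuracy'] = 0.75

saver(model, stats, (0, 10))   # saves, then copies to model_best (0.75 > 0.0)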