salad.solver package

Model optimization by stochastic gradient descent and variants. The general structure:

  • Each experiment configuration is a subclass of Solver or of one of its derived classes, together with its loss functions
  • Solvers only specify how data and models are used to generate the losses
  • Similarities between deep learning experiments (checkpointing, logging, …) are implemented in the Solver class.

In general, for many experiments it makes sense to set up a solver as a subclass of a more specific solver; e.g., when the general problem is a classification task, a CrossEntropySolver would be a natural choice.

Classes were designed with re-use in mind. The goal is to exploit the structure that most deep learning experiments share.

Submodules

salad.solver.base module

Base classes for solvers

This module contains abstract base classes for the solvers used in salad.

class salad.solver.base.EventBasedSolver

Bases: object

Event handling for solvers

All solvers derived from the EventBasedSolver are extended with event handlers, currently for the following events:

  • start_epoch
  • start_batch
  • finish_batch
  • finish_epoch
finish_batch(*args, **kwargs)
finish_epoch(*args, **kwargs)
start_batch(*args, **kwargs)
start_epoch(*args, **kwargs)
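The following is a minimal, self-contained sketch of how such event hooks can wrap a training loop. It does not import salad: the classes `EventLoopSketch` and `RecordingSolver`, the `run` and `step` methods, and the exact call order are assumptions for illustration; only the four handler names are taken from the method list above.

```python
class EventLoopSketch:
    """Hypothetical stand-in for an event-based solver; all hooks are no-ops by default."""

    def start_epoch(self, *args, **kwargs):
        pass

    def finish_epoch(self, *args, **kwargs):
        pass

    def start_batch(self, *args, **kwargs):
        pass

    def finish_batch(self, *args, **kwargs):
        pass

    def step(self, batch):
        """Placeholder for the actual optimization step on one batch."""

    def run(self, batches, n_epochs=1):
        # Assumed call order: epoch events wrap the batch events.
        for epoch in range(n_epochs):
            self.start_epoch(epoch=epoch)
            for index, batch in enumerate(batches):
                self.start_batch(epoch=epoch, index=index)
                self.step(batch)
                self.finish_batch(epoch=epoch, index=index)
            self.finish_epoch(epoch=epoch)


class RecordingSolver(EventLoopSketch):
    """Overrides the hooks to record the order in which events fire."""

    def __init__(self):
        self.events = []

    def start_epoch(self, epoch, **kwargs):
        self.events.append(("start_epoch", epoch))

    def finish_epoch(self, epoch, **kwargs):
        self.events.append(("finish_epoch", epoch))

    def start_batch(self, epoch, index, **kwargs):
        self.events.append(("start_batch", index))

    def finish_batch(self, epoch, index, **kwargs):
        self.events.append(("finish_batch", index))


solver = RecordingSolver()
solver.run(batches=["b0", "b1"])
print(solver.events)
```

Subclasses only override the hooks they need; the defaults keep the loop itself unchanged.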
class salad.solver.base.Solver(dataset, n_epochs=1, savedir='./log', gpu=None, dryrun=False, **kwargs)

Bases: salad.solver.base.EventBasedSolver, salad.solver.base.StructuredInit

General gradient descent solver for deep learning experiments

This is a helper class for training of PyTorch models that makes very few assumptions about the structure of a deep learning experiment. Solvers are generally constructed to take one or several models (torch.nn.Module) and one DataLoader instance that provides a (possibly nested) tuple of examples for training.

The Solver implements the following features:

  • Logging of losses and model checkpoints

While offering this functionality in the background, the class aims to stay flexible enough to support any kind of deep learning experiment.

When defining your own solver class, you should first

  • register models [register_model]
  • register loss functions [register_loss]
  • register optimizers [register_optimizer]

The abstraction goes as follows:

  • An experiment is fully characterized by its Solver class
  • An experiment can have multiple models
  • Parameters of the models are processed by optimizers
  • Each optimizer has a function to derive its losses

In the optimization process, the following algorithm is used:

    for opt in optimizers:
        losses = L(opt)
        grad_losses = grad(losses)
        opt.step(grad_losses)
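A minimal, framework-free sketch of this loop follows. Each optimizer is paired with a loss function and a step count, mirroring the arguments of register_optimizer; the toy `ScalarSGD` class, `optimize_batch` and the hand-derived gradient are hypothetical stand-ins for PyTorch optimizers and autograd.

```python
class ScalarSGD:
    """Hypothetical optimizer: plain gradient descent on a dict of float parameters."""

    def __init__(self, params, lr):
        self.params = params  # shared, mutable dict: name -> float
        self.lr = lr

    def step(self, grads):
        for name, g in grads.items():
            self.params[name] -= self.lr * g


def optimize_batch(batch, optimizers):
    """Run every (optimizer, loss_func, n_steps) triple on one batch."""
    losses = {}
    for opt_name, (opt, loss_func, n_steps) in optimizers.items():
        for _ in range(n_steps):
            loss, grads = loss_func(batch)  # losses = L(opt); grad_losses = grad(losses)
            opt.step(grads)                 # opt.step(grad_losses)
        losses[opt_name] = loss
    return losses


params = {"w": 0.0}
xs, ys = [1.0, 2.0, 3.0], [3.0, 6.0, 9.0]  # toy data following y = 3x


def squared_error(batch):
    """Mean squared error of w * x and its analytic gradient with respect to w."""
    x, y = batch
    n = len(x)
    loss = sum((params["w"] * xi - yi) ** 2 for xi, yi in zip(x, y)) / n
    grad_w = sum(2 * xi * (params["w"] * xi - yi) for xi, yi in zip(x, y)) / n
    return loss, {"w": grad_w}


optimizers = {"sgd": (ScalarSGD(params, lr=0.05), squared_error, 1)}
for _ in range(100):
    optimize_batch((xs, ys), optimizers)
print(params["w"])
```

With several entries in `optimizers` (for instance a generator and a discriminator), each one computes its own losses and updates only its own parameters, in turn.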
Parameters:
  • dataset (Dataset) – Dataset used for training
  • n_epochs (int) – Number of epochs (defined as full passes through the DataLoader)
  • savedir (str) – log directory for saving model checkpoints and the loss history
  • gpu (int) – Index of the GPU to be used. If None, CPU training is used instead
  • dryrun (bool) – Train only for the first batch. Useful for testing a new solver

Notes

After initializing all internal dictionaries, the constructor makes calls to the _init_models, _init_optims and _init_losses functions. If these functions should make use of any additional keyword arguments you passed in your class, make sure that you initialize them prior to calling super().__init__ in your constructor.
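The pitfall described in this note can be reproduced with a hypothetical base class whose constructor calls `_init_models`, as the note states for Solver. This is a sketch, not salad's code: `BaseSketch`, `GoodSolver`, `BadSolver` and `hidden_size` are invented names.

```python
class BaseSketch:
    """Hypothetical base: the constructor triggers the init hook."""

    def __init__(self, **kwargs):
        self.models = {}
        self._init_models(**kwargs)


class GoodSolver(BaseSketch):
    def __init__(self, hidden_size=64, **kwargs):
        self.hidden_size = hidden_size  # set BEFORE super().__init__
        super().__init__(**kwargs)

    def _init_models(self, **kwargs):
        self.models["net"] = {"hidden": self.hidden_size}


class BadSolver(BaseSketch):
    def __init__(self, hidden_size=64, **kwargs):
        super().__init__(**kwargs)      # _init_models already runs here ...
        self.hidden_size = hidden_size  # ... but the attribute is only set now

    def _init_models(self, **kwargs):
        self.models["net"] = {"hidden": self.hidden_size}


print(GoodSolver(hidden_size=128).models)
try:
    BadSolver()
except AttributeError as err:
    print("BadSolver failed:", err)
```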

compute_loss_dict(loss_args)
cuda(obj)

Move nested iterables to the CUDA device or to the CPU
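A recursive helper of this kind can be sketched as follows. The name `map_nested` is hypothetical; in a PyTorch setting `fn` would be something like `lambda t: t.cuda(gpu)` or `lambda t: t.cpu()`, which this dependency-free version replaces with an arbitrary callable applied to the leaves.

```python
def map_nested(obj, fn):
    """Apply fn to every leaf of a nested structure of dicts, tuples and lists."""
    if isinstance(obj, dict):
        return {key: map_nested(value, fn) for key, value in obj.items()}
    if isinstance(obj, tuple):
        return tuple(map_nested(item, fn) for item in obj)
    if isinstance(obj, list):
        return [map_nested(item, fn) for item in obj]
    return fn(obj)  # leaf, e.g. a tensor to be moved to a device


batch = ((1, 2), [3, {"y": 4}])
print(map_nested(batch, lambda leaf: leaf * 10))
```

The container structure is preserved, so code that unpacks the batch does not need to know on which device its tensors live.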

format_summary_report(losses)
format_train_report(losses)
optimize()

Start the optimization process

Notes

Prior to the optimization process, all models will be set to training mode by a call to model.train(True)

register_loss(func, weight=1.0, name=None, display=True, override=False, **kwargs)

Register a new loss function

Parameters:
  • func (callable) – The loss function to register
  • weight (float) – Weighting factor applied to this loss
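How registered losses and their weights could be combined is sketched below, assuming that each entry returned by an optimizer's loss function either holds arguments for the registered function or an already computed loss value (the two variants described for register_optimizer). `LossRegistry` and the rule "a tuple means arguments" are assumptions, not salad's implementation of compute_loss_dict; the `display` option is omitted.

```python
class LossRegistry:
    """Hypothetical registry: name -> (loss function, weight)."""

    def __init__(self):
        self.losses = {}

    def register_loss(self, func, weight=1.0, name=None, override=False):
        name = name or func.__name__
        if name in self.losses and not override:
            raise ValueError(f"loss '{name}' is already registered")
        self.losses[name] = (func, weight)

    def remove_loss(self, name):
        del self.losses[name]

    def compute_loss_dict(self, loss_args):
        """Evaluate each entry; return the individual losses and their weighted sum."""
        values = {}
        for name, args in loss_args.items():
            func, _ = self.losses[name]
            # Assumption: a tuple holds arguments, anything else is a finished loss value.
            values[name] = func(*args) if isinstance(args, tuple) else args
        total = sum(self.losses[name][1] * value for name, value in values.items())
        return values, total


def l1(pred, target):
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def l2(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


reg = LossRegistry()
reg.register_loss(l1, weight=0.5)
reg.register_loss(l2, weight=2.0)
values, total = reg.compute_loss_dict({"l1": ([1.0, 2.0], [0.0, 0.0]), "l2": 0.25})
print(values, total)
```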
register_model(model, name='')

Add a model to the solver

This method will also move the model directly to the device you specified when constructing the solver.

Parameters:
  • model (torch.nn.Module) – The model to be optimized. Should return a non-empty iterable when the parameters() method is called
  • name (str, optional) – Name for the model. Useful for checkpoints when multiple models are optimized.
register_optimizer(optimizer, loss_func, retain_graph=False, name='', n_steps=1)

Add an optimizer to the solver

Parameters:
  • optimizer (Optimizer) – The optimizer used for updating model weights during training
  • loss_func (LossFunction) – A function (or callable object) that, given the current batch passed by the Solver’s data loader, returns a dictionary mapping loss function names either to the arguments for that loss function or to the already computed loss.
  • retain_graph (bool) – True if the computational graph should be retained after calling the loss function associated to the optimizer. This is usually not needed.
  • name (str) – Optimizer name. Useful for logging
  • n_steps (int) – Number of consecutive steps the optimizer is executed. Usually set to 1.
remove_loss(name)
class salad.solver.base.StructuredInit(**kwargs)

Bases: object

Structured Initialization of Solvers

Initializes the components of a solver and passes arguments. Initialization is done in the following order:

  • _init_models
  • _init_losses
  • _init_optims
Parameters:
  • kwargs (Keyword arguments) – Arguments for all initialization functions. Keyword arguments are passed through the functions in the order specified above. Unused keyword arguments will be printed afterwards. In general, solvers should be designed in a way that ensures that all keyword arguments are used.

Note: Don’t instantiate or subclass this class directly.
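The keyword-argument flow can be sketched as follows: each `_init_*` hook consumes the arguments it understands, in the order listed above, and whatever is left over is reported. `StructuredInitSketch`, `ToySolver` and the convention that each hook returns its unconsumed arguments are assumptions for illustration.

```python
class StructuredInitSketch:
    """Hypothetical structured initializer: hooks run in a fixed order."""

    def __init__(self, **kwargs):
        remaining = dict(kwargs)
        for hook in (self._init_models, self._init_losses, self._init_optims):
            remaining = hook(**remaining)  # each hook returns what it did not consume
        self.unused_kwargs = remaining
        if remaining:
            print("Unused keyword arguments:", sorted(remaining))

    def _init_models(self, **kwargs):
        return kwargs

    def _init_losses(self, **kwargs):
        return kwargs

    def _init_optims(self, **kwargs):
        return kwargs


class ToySolver(StructuredInitSketch):
    def _init_models(self, n_hidden=32, **kwargs):
        self.n_hidden = n_hidden
        return kwargs

    def _init_optims(self, learningrate=1e-3, **kwargs):
        self.learningrate = learningrate
        return kwargs


solver = ToySolver(n_hidden=16, learningrate=0.01, momentum=0.9)
```

Here `momentum` is consumed by no hook, so it is reported; in a well-designed solver that list stays empty.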

salad.solver.da.base module

Solver classes for domain adaptation experiments

class salad.solver.da.base.BaselineDASolver(*args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

A domain adaptation solver that actually does not run any adaptation algorithm

This is useful for establishing baseline results without adaptation, e.g., to measure the domain shift between datasets.

class salad.solver.da.base.DABaseSolver(*args, **kwargs)

Bases: salad.solver.classification.BaseClassSolver

Base Class for Unsupervised Domain Adaptation Approaches

Unsupervised DA assumes the presence of a single source domain \(\mathcal S\) along with a target domain \(\mathcal T\) known at training time. Given a labeled sample of points drawn from \(\mathcal S\), \(\{x^s_i, y^s_i\}_{i}^{N_s}\), and an unlabeled sample drawn from \(\mathcal T\), \(\{x^t_i\}_{i}^{N_t}\), unsupervised adaptation aims at minimizing

\[\min_\theta \mathcal{R}^l_{\mathcal S} (\theta) + \lambda \mathcal{R}^u_{\mathcal {S \times T}} (\theta),\]

leveraging an unsupervised risk term \(\mathcal{R}^u_{\mathcal {S \times T}} (\theta)\) that depends on feature representations \(f_\theta(x^s,s)\) and \(f_\theta(x^t,t)\), classifier labels \(h_\theta(x^s,s), h_\theta(x^t,t)\) as well as source labels \(y^s\). The full model \(h = g \circ f\) is a composition of a feature extractor \(f\) and classifier \(g\), both of which can possibly depend on the domain label \(s\) or \(t\) for domain-specific computations.
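To make the objective concrete, the sketch below evaluates it for plain lists of logits. The supervised term is the mean source cross-entropy; for \(\mathcal{R}^u\) it uses the mean prediction entropy on target samples, which is only one possible choice (it ignores the source features that a general \(\mathcal{R}^u_{\mathcal{S \times T}}\) may use). All function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def cross_entropy(logits, label):
    return -log_softmax(logits)[label]


def entropy(logits):
    return -sum(math.exp(lp) * lp for lp in log_softmax(logits))


def da_objective(src_logits, src_labels, tgt_logits, lam):
    """R^l_S (mean source cross-entropy) + lam * R^u (here: mean target entropy)."""
    supervised = sum(cross_entropy(z, y) for z, y in zip(src_logits, src_labels)) / len(src_labels)
    unsupervised = sum(entropy(z) for z in tgt_logits) / len(tgt_logits)
    return supervised + lam * unsupervised


print(da_objective([[2.0, 0.0]], [0], [[0.0, 0.0]], lam=0.5))
```

The trade-off parameter `lam` plays the role of \(\lambda\): with `lam=0` the objective reduces to plain source-only training, as in the baseline solver.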

Notes

This solver adds two accuracies with keys acc_s and acc_t for the source and target domain, respectively. Make sure that your loss computation derives both accuracies.

class salad.solver.da.base.DABaselineLoss(solver)

Bases: object

class salad.solver.da.base.DATeacher(model, teacher, dataset, *args, **kwargs)

Bases: salad.solver.base.Solver

Base Class for Unsupervised Domain Adaptation Approaches using a teacher model

class salad.solver.da.base.DGBaseSolver(*args, **kwargs)

Bases: salad.solver.classification.BaseClassSolver

Base Class for Domain Generalization Approaches

Domain generalization assumes the presence of multiple source domains alongside a target domain unknown at training time. Following Shankar et al. (2018), this setting requires a dataset of training examples \(\{x_i, y_i, d_i\}_{i}^{N}\) with class and domain labels. Importantly, the domains present at training time should reflect the kind of variability that can be expected during inference. The ERM problem is then approached as

\[\min_\theta \sum_d \lambda_d \mathcal{R}^l_{\mathcal S_d} (\theta), \quad \mathcal{R}^l_{\mathcal S_d} (\theta) = \mathbb{E}_{x,y \sim \mathcal S_d }[\ell ( f_\theta(x), h_\theta(x), y, d) ].\]

In contrast to the unsupervised setting, samples are now presented in a single batch comprised of inputs \(x\), labels \(y\) and domains \(d\). In addition to a feature extractor \(f_\theta\) and classifier \(g_\theta\), models should also provide a feature extractor \(f^d_\theta\) to derive domain features along with a domain classifier \(g^d_\theta\), with possibly shared parameters.
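Because all domains arrive in one batch, the per-domain risks can be estimated by grouping per-sample losses by their domain label \(d\) and combining the domain means with weights \(\lambda_d\). A minimal sketch, assuming the per-sample losses \(\ell\) are already computed; the helper name is hypothetical.

```python
from collections import defaultdict


def weighted_domain_risk(sample_losses, domains, weights):
    """Sum over domains d of lambda_d times the mean loss of the samples from d."""
    per_domain = defaultdict(list)
    for loss, d in zip(sample_losses, domains):
        per_domain[d].append(loss)
    return sum(weights[d] * sum(ls) / len(ls) for d, ls in per_domain.items())


losses = [1.0, 3.0, 2.0, 4.0, 6.0]
domains = [0, 0, 1, 1, 1]
print(weighted_domain_risk(losses, domains, weights={0: 1.0, 1: 0.5}))
```

Averaging within each domain first means that, with equal weights, a domain contributes the same amount regardless of how many of its samples happen to be in the batch.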

In contrast to unsupervised DA, this training setting leverages information from multiple labeled source domains with the goal of generalizing well on data from a previously unseen domain during test time.

salad.solver.da.advdrop module

class salad.solver.da.advdrop.AdversarialDropoutLoss(model, step=1)

Bases: object

Loss Derivation for Adversarial Dropout Regularization

See also

salad.solver.AdversarialDropoutSolver

step1(batch)
step2(batch)
step3(batch)
class salad.solver.da.advdrop.AdversarialDropoutSolver(model, dataset, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

Implementation of “Adversarial Dropout Regularization”

Adversarial Dropout Regularization [1] estimates uncertainties about the classification process by sampling different models using dropout. On the source domain, a standard cross entropy loss is employed. On the target domain, two predictions are sampled from the model.

Both network parts, the feature extractor \(G\) and the classifier \(C\), are jointly trained on the source domain using the standard cross entropy loss,

\[\min_{C, G} - \sum_k p^s_k \log y^s_k\]

The classifier part of the network is trained to maximize the symmetric KL divergence between two predictions \(p^t\) and \(q^t\) sampled for a target example. This divergence is one option for measuring uncertainty in a network. In other words, the classifier aims at maximizing uncertainty given two noisy estimates of the current feature vector:

\[\min_{C} - \sum_k p^s_k \log y^s_k - \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}\]

In contrast, the feature extractor aims at minimizing the discrepancy between the two noisy predictions given a fixed target classifier:

\[\min_{G} \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}\]
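The discrepancy \(\sum_k \frac{p_k - q_k}{2} \log \frac{p_k}{q_k}\) equals the average of \(\mathrm{KL}(p \| q)\) and \(\mathrm{KL}(q \| p)\). The stdlib sketch below checks this numerically for two probability vectors; it is not the SymmetricKL module documented below, whose reduction over a batch is not specified here.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)


def symmetric_kl(p, q):
    """sum_k (p_k - q_k) / 2 * log(p_k / q_k); assumes strictly positive entries."""
    return sum((pk - qk) / 2 * math.log(pk / qk) for pk, qk in zip(p, q))


p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]
print(symmetric_kl(p, q), (kl(p, q) + kl(q, p)) / 2)
```

Unlike plain KL, this quantity does not depend on which of the two dropout samples is called \(p\), which suits two interchangeable noisy predictions.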

References

[1] Adversarial Dropout Regularization, Saito et al., ICLR 2018

class salad.solver.da.advdrop.SymmetricKL

Bases: torch.nn.modules.module.Module

forward(x, y)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

salad.solver.da.advdrop.pack(*args)
salad.solver.da.advdrop.unpack(arg, n_tensors)

salad.solver.da.association module

Associative Domain Adaptation

Häusser et al., CVPR 2017

class salad.solver.da.association.AssociationLoss(model)

Bases: object

Loss function for associative domain adaptation

Given a model, derive a function that computes arguments for the association loss.

class salad.solver.da.association.AssociativeSolver(model, dataset, learningrate, walker_weight=1.0, visit_weight=0.1, *args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

Implementation of “Associative Domain Adaptation”

Associative Domain Adaptation [1] leverages a random walk over the feature similarities between source and target samples as a measure of the discrepancy between both domains. The algorithm is based on two loss functions that are added to the standard cross entropy loss on the source domain.

Given features for source and target domain, a kernel function is used to measure similarity between both domains. The original implementation uses the scalar product between feature vectors, passed through an exponential,

\[K_{ij} = k(x^s_i, x^t_j) = \exp(\langle x^s_i, x^t_j \rangle)\]

This kernel is then used to compute transition probabilities from source to target,

\[p(x^t_j | x^s_i) = \frac{K_{ij}}{\sum_{l} K_{il}}\]

and from target back to source,

\[p(x^s_k | x^t_j) = \frac{K_{kj}}{\sum_{l} K_{lj}}\]

to compute the roundtrip

\[p(x^s_k | x^s_i) = \sum_{j} p(x^s_k | x^t_j) p(x^t_j | x^s_i)\]

It is then required that

  1. WalkerLoss: The roundtrip ends at a sample with the same class label, i.e., \(y^s_i = y^s_k\)
  2. VisitLoss: All target samples are visited with roughly equal probability
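The quantities above can be computed for small batches with the stdlib sketch below. It follows the description (kernel \(\exp\langle x^s_i, x^t_j\rangle\), normalized transitions, roundtrip) and, as an assumption taken from the original paper rather than from this page, uses a uniform distribution over same-class source samples as the walker target and a uniform distribution over target samples as the visit target. It is not salad's AssociationLoss; `association_losses` is a hypothetical name.

```python
import math


def _softmax_rows(matrix):
    """Row-wise softmax: exp(M_ij) / sum_l exp(M_il)."""
    result = []
    for row in matrix:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


def association_losses(src, tgt, src_labels, eps=1e-8):
    """Return (walker_loss, visit_loss) for lists of feature vectors."""
    n_s, n_t = len(src), len(tgt)
    # M[i][j] = <x^s_i, x^t_j>, so that K = exp(M)
    M = [[sum(a * b for a, b in zip(s, t)) for t in tgt] for s in src]
    p_st = _softmax_rows(M)                           # p_st[i][j] = p(x^t_j | x^s_i)
    p_ts = _softmax_rows([list(c) for c in zip(*M)])  # p_ts[j][k] = p(x^s_k | x^t_j)
    # Roundtrip p(x^s_k | x^s_i) = sum_j p(x^s_k | x^t_j) p(x^t_j | x^s_i)
    p_sts = [[sum(p_st[i][j] * p_ts[j][k] for j in range(n_t)) for k in range(n_s)]
             for i in range(n_s)]
    # Walker loss: cross-entropy to a uniform distribution over same-class samples
    walker = 0.0
    for i in range(n_s):
        same = [k for k in range(n_s) if src_labels[k] == src_labels[i]]
        walker -= sum(math.log(p_sts[i][k] + eps) for k in same) / len(same)
    walker /= n_s
    # Visit loss: cross-entropy between a uniform distribution and the visit probabilities
    visit_prob = [sum(p_st[i][j] for i in range(n_s)) / n_s for j in range(n_t)]
    visit = -sum(math.log(v + eps) for v in visit_prob) / n_t
    return walker, visit


print(association_losses([[5.0, 0.0], [0.0, 5.0]], [[5.0, 0.0], [0.0, 5.0]], [0, 1]))
```

The nested loops scale with \(N_s^2 N_t\), so this is only meant to make the definitions concrete; a tensor implementation would express the same steps as matrix products.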

As one possible modification, different kernel functions could be used to measure similarity between the domains. With this solver, it is advised to use large sample sizes for the target domain and ensure that a sufficient number of source samples is available for each batch.

TODO: Possibly in the solver class, implement a functionality to aggregate batches to avoid memory issues.

Parameters:
  • model (torch.nn.Module) – A pytorch model to be trained by association
  • dataset (StackedDataset) – A dataset suitable for an unsupervised solver
  • learningrate (float) – Learning rate

References

[1] Associative Domain Adaptation, Häusser et al., CVPR 2017, https://arxiv.org/abs/1708.00938

salad.solver.da.coral module

Losses for Correlation Alignment

Deep CORAL: Correlation Alignment for Deep Domain Adaptation Paper: https://arxiv.org/pdf/1607.01719.pdf

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation Paper: https://openreview.net/pdf?id=rJWechg0Z

class salad.solver.da.coral.CentroidDistanceLossSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.da.coral.CorrelationDistanceSolver

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation Paper: https://openreview.net/pdf?id=rJWechg0Z and: https://arxiv.org/pdf/1705.08180.pdf

class salad.solver.da.coral.CentroidLoss(model)

Bases: object

class salad.solver.da.coral.CorrDistanceSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.da.coral.CorrelationDistanceSolver

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation Paper: https://openreview.net/pdf?id=rJWechg0Z

class salad.solver.da.coral.CorrelationDistanceLoss(model, n_steps_recompute=10, nullspace=False)

Bases: object

class salad.solver.da.coral.CorrelationDistanceSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

class salad.solver.da.coral.DeepCoralSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.da.coral.CorrelationDistanceSolver

Deep CORAL: Correlation Alignment for Deep Domain Adaptation Paper: https://arxiv.org/pdf/1607.01719.pdf

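Loss function:

\[\mathcal{L}(x^s, x^t) = \frac{1}{4d^2} \| C_s - C_t \|_F^2\]

where \(C_s\) and \(C_t\) are the covariance matrices of the source and target features and \(d\) is the feature dimension.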
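The loss can be sketched with the standard library as follows, with covariances estimated using the unbiased \(1/(n-1)\) normalization of the Deep CORAL paper. The function names are hypothetical and the inputs are plain lists of feature vectors rather than tensors.

```python
def covariance(X):
    """Unbiased covariance matrix of a list of d-dimensional feature vectors."""
    n, d = len(X), len(X[0])
    mean = [sum(row[k] for row in X) / n for k in range(d)]
    centered = [[row[k] - mean[k] for k in range(d)] for row in X]
    return [[sum(r[a] * r[b] for r in centered) / (n - 1) for b in range(d)]
            for a in range(d)]


def coral_loss(Xs, Xt):
    """Squared Frobenius distance of the covariances, scaled by 1 / (4 d^2)."""
    d = len(Xs[0])
    Cs, Ct = covariance(Xs), covariance(Xt)
    sq = sum((Cs[a][b] - Ct[a][b]) ** 2 for a in range(d) for b in range(d))
    return sq / (4 * d * d)


print(coral_loss([[0.0], [2.0]], [[0.0], [0.0]]))
```

Since only second-order statistics enter, a constant shift of all target features leaves the loss unchanged; CORAL aligns correlations, not means.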
class salad.solver.da.coral.DeepLogCoralSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.da.coral.CorrelationDistanceSolver

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation Paper: https://openreview.net/pdf?id=rJWechg0Z

salad.solver.da.crossgrad module

Cross Gradient Training

ICLR 2018

class salad.solver.da.crossgrad.CrossGradLoss(solver)

Bases: object

Cross Gradient Training

http://arxiv.org/abs/1804.10745

pertub(x, loss, eps=1e-05)
class salad.solver.da.crossgrad.CrossGradSolver(model, *args, **kwargs)

Bases: salad.solver.da.base.DGBaseSolver

Cross Gradient Optimizer

A domain generalization solver based on Cross Gradient Training [1].

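\[p(y \mid x) = \int_d p(y \mid x, d)\, p(d \mid x)\, \mathrm{d}d\]

\[x_d = x + \epsilon \nabla_x L(y), \qquad x_y = x + \epsilon \nabla_x L(d)\]

where \(L(y)\) and \(L(d)\) denote the label and domain classification losses.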
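The perturbation step can be illustrated without autograd by approximating \(\nabla_x\) with central finite differences. This is only a stand-in for the automatic differentiation a PyTorch implementation would use; `numerical_grad`, `perturb` and the toy `label_loss` are hypothetical (salad's own method is `pertub(x, loss, eps)`, whose internals are not documented here). In Cross Gradient Training, \(x_d\) is used to train the domain classifier and \(x_y\) to train the label classifier.

```python
def numerical_grad(f, x, h=1e-5):
    """Central finite-difference estimate of the gradient of f at x (a list of floats)."""
    grad = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        grad.append((f(xp) - f(xm)) / (2 * h))
    return grad


def perturb(x, loss_fn, eps):
    """Return x + eps * grad_x loss_fn(x)."""
    return [xi + eps * gi for xi, gi in zip(x, numerical_grad(loss_fn, x))]


def label_loss(v):
    """Toy stand-in for L(y) as a function of the input; its gradient is 2v."""
    return sum(vi ** 2 for vi in v)


x = [1.0, 2.0]
x_d = perturb(x, label_loss, eps=0.1)  # input for training the domain classifier
print(x_d)
```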

References

[1] Shankar et al., Generalizing Across Domains via Cross-Gradient Training, ICLR 2018
class salad.solver.da.crossgrad.Model(n_classes, n_domains)

Bases: torch.nn.modules.module.Module

forward(x)

forward_domain(x)
parameters_classifier()
parameters_domain()
class salad.solver.da.crossgrad.MultiDomainModule(n_domains)

Bases: torch.nn.modules.module.Module

forward(x)

forward_domain(x)
parameters_classifier()
parameters_domain()
salad.solver.da.crossgrad.concat(x, z)

Concat 4D tensor with expanded 2D tensor

salad.solver.da.crossgrad.conv2d(m, n, k, act=True)
salad.solver.da.crossgrad.features(inp)
salad.solver.da.crossgrad.get_dataset(noisemodels, batch_size, shuffle=True, num_workers=0, which='train')

salad.solver.da.dann module

class salad.solver.da.dann.AdversarialLoss(G, D, train_G=True)

Bases: object

class salad.solver.da.dann.DANNSolver(model, discriminator, dataset, learningrate, *args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

Domain Adversarial Neural Networks Solver

This builds upon the normal classification solver that uses CrossEntropy or BinaryCrossEntropy for optimizing neural networks.

salad.solver.da.dirtt module

class salad.solver.da.dirtt.DIRTT(model, teacher)

Bases: object

class salad.solver.da.dirtt.DIRTTSolver(model, teacher, dataset, *args, **kwargs)

Bases: salad.solver.base.Solver

Decision-boundary Iterative Refinement Training with a Teacher (DIRT-T)

class salad.solver.da.dirtt.VADA(G, D, train_G=True)

Bases: salad.solver.da.dann.AdversarialLoss

class salad.solver.da.dirtt.VADASolver(model, discriminator, dataset, *args, **kwargs)

Bases: salad.solver.da.dann.DANNSolver

Virtual Adversarial Domain Adaptation

salad.solver.da.dirtt_re module

Self Ensembling for Visual Domain Adaptation

class salad.solver.da.dirtt_re.DIRTT(model, teacher)

Bases: object

class salad.solver.da.dirtt_re.DIRTTSolver(model, teacher, dataset, learningrate, *args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

salad.solver.da.djdot module

class salad.solver.da.djdot.DJDOTSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.classification.BaseClassSolver

Deep Joint Optimal Transport solver

TODO

derive_losses(batch)

salad.solver.da.ensembling module

Self Ensembling for Visual Domain Adaptation

class salad.solver.da.ensembling.EnsemblingLoss(model, teacher)

Bases: object

class salad.solver.da.ensembling.SelfEnsemblingSolver(model, teacher, dataset, learningrate, *args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

salad.solver.classification module

class salad.solver.classification.BCESolver(*args, **kwargs)

Bases: salad.solver.classification.BaseClassSolver

Solver for a classification experiment

class salad.solver.classification.BaseClassSolver(model, dataset, multiclass=True, *args, **kwargs)

Bases: salad.solver.base.Solver

Base Solver for classification experiments

Parameters:
  • model (nn.Module) – A model to train on a classification target
  • dataset (torch.utils.data.Dataset) – The dataset providing training samples
  • multiclass (bool) – If True, CrossEntropyLoss is used, BCEWithLogitsLoss otherwise.
class salad.solver.classification.ClassificationLoss(solver)

Bases: object

class salad.solver.classification.FinetuneSolver(*args, **kwargs)

Bases: salad.solver.classification.BaseClassSolver

Fine-tune pre-trained deep learning models

Given a model with separable feature extractor and classifier, use different learning rates and regularization settings for both parts. Useful for fine-tuning pre-trained ImageNet models or saved model checkpoints.

Parameters:
  • model (nn.Module) – Module with two separate parts (feature extractor and classifier)
  • dataset (Dataset) – The dataset used for training
class salad.solver.classification.MultiDomainClassificationLoss(solver, domain)

Bases: object

class salad.solver.classification.MultidomainBCESolver(model, dataset, learningrate, multiclass=True, loss_weights=None, *args, **kwargs)

Bases: salad.solver.base.Solver

salad.solver.gan module

Tools for training Generative Adversarial Networks (GANs)

The module is primarily used to train conditional GANs (CGANs).

Notes

Contributions and extensions are welcome!

class salad.solver.gan.CGANLoss(solver, G, Ds, train_G)

Bases: object

Loss Derivation for a Conditional GAN

derive_losses(batch)
class salad.solver.gan.ConditionalGANSolver(model, dataset, learningrate=0.0002, *args, **kwargs)

Bases: salad.solver.base.Solver

format_train_report(losses)
n_classes = [1, 11, 3]
names = ['D_GAN', 'D_CL', 'D_CON']
class salad.solver.gan.GANSolver(dataset, n_epochs=1, savedir='./log', gpu=None, dryrun=False, **kwargs)

Bases: salad.solver.base.Solver

salad.solver.openset module

Routines for open set classification

class salad.solver.openset.BaseOpensetSolver(model, dataset, multiclass=True, *args, **kwargs)

Bases: salad.solver.classification.BaseClassSolver

class salad.solver.openset.VADAOpenset(G, D, train_G=True)

Bases: salad.solver.da.dann.AdversarialLoss