salad.solver package
Model optimization by stochastic gradient descent and variants. The general structure:
- Each experiment configuration is a subclass of Solver or one of its derived classes
- Solvers only specify how data and models are used to generate the losses
- Similarities between deep learning experiments (checkpointing, logging, …) are implemented in the Solver class.
In general, for many experiments it makes sense to set up a solver as a subclass of another, more specific solver; e.g., when the general problem is classification, a CrossEntropySolver would be a natural choice.
Classes were designed with the possibility of re-use in mind. The goal is to exploit the particular structure that most deep learning experiments share.
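As a usage sketch (hypothetical toy data and model; BaseClassSolver and its arguments are documented further below), configuring an experiment amounts to constructing a solver with a model and a DataLoader and calling optimize():

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from salad.solver.classification import BaseClassSolver

    # toy classification data; any DataLoader yielding (input, label) tuples works
    X, y = torch.randn(256, 10), torch.randint(0, 3, (256,))
    loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
    model = nn.Linear(10, 3)

    solver = BaseClassSolver(model, loader, multiclass=True,
                             n_epochs=2, savedir='./log', gpu=None)
    solver.optimize()  # trains the model, logs losses, writes checkpoints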
Submodules

salad.solver.base module
Base classes for solvers. This module contains abstract base classes for the solvers used in salad.
class salad.solver.base.EventBasedSolver
Bases: object

Event handling for solvers.

All solvers derived from EventBasedSolver are extended with event handlers, currently for the following events: start_epoch, start_batch, finish_batch, finish_epoch.

- finish_batch(*args, **kwargs)
- finish_epoch(*args, **kwargs)
- start_batch(*args, **kwargs)
- start_epoch(*args, **kwargs)
class salad.solver.base.Solver(dataset, n_epochs=1, savedir='./log', gpu=None, dryrun=False, **kwargs)
Bases: salad.solver.base.EventBasedSolver, salad.solver.base.StructuredInit

General gradient descent solver for deep learning experiments.
This is a helper class for training PyTorch models that makes very few assumptions about the structure of a deep learning experiment. Solvers are generally constructed to take one or several models (torch.nn.Module) and one DataLoader instance that provides a (possibly nested) tuple of examples for training.
The Solver implements the following features:
- Logging of losses and model checkpoints
While offering this functionality in the background, the class aims to stay flexible when it comes to designing any kind of deep learning experiment.
When defining your own solver class, you should first
- register models [register_model]
- register loss functions [register_loss]
- register optimizers [register_optimizer]
as shown in the sketch after the Notes below.
The abstraction goes as follows:
- An experiment is fully characterized by its Solver class
- An experiment can have multiple models
- Parameters of the models are processed by optimizers
- Each optimizer has a function that derives its losses
In the optimization process, the following algorithm is used:

    for opt in optimizers:
        losses = L(opt)
        grad_losses = grad(losses)
        opt.step(grad_losses)
Parameters:
- dataset (Dataset) – Dataset used for training
- n_epochs (int) – Number of epochs (defined as full passes through the DataLoader)
- savedir (str) – Log directory for saving model checkpoints and the loss history
- gpu (int) – Index of the GPU to be used. If None, use CPU training instead
- dryrun (bool) – Train only on the first batch. Useful for testing a new solver
Notes

After initializing all internal dictionaries, the constructor calls the _init_models, _init_optims and _init_losses functions. If these functions should make use of any additional keyword arguments you passed to your class, make sure that you initialize them prior to calling super().__init__ in your constructor.
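To make the registration checklist concrete, here is a minimal sketch of a derived solver (hypothetical class, names and loss-argument format; only the register_* methods and _init_* hooks documented in this module are assumed):

    import torch
    from torch import nn

    from salad.solver.base import Solver

    class MyClassificationSolver(Solver):
        """Hypothetical example solver, not part of the library."""

        def __init__(self, model, dataset, learningrate=1e-3, **kwargs):
            # attributes used by the _init_* hooks must exist before super().__init__
            self.model = model
            self.learningrate = learningrate
            super().__init__(dataset, **kwargs)

        def _init_models(self, **kwargs):
            # moves the model to the configured device and enables checkpointing
            self.register_model(self.model, name='classifier')

        def _init_losses(self, **kwargs):
            # register loss functions under names used by the loss derivation below
            self.register_loss(nn.CrossEntropyLoss(), weight=1.0, name='ce')

        def _init_optims(self, **kwargs):
            opt = torch.optim.Adam(self.model.parameters(), lr=self.learningrate)
            # the loss derivation maps a batch to {loss name: arguments}
            derive = lambda batch: {'ce': (self.model(batch[0]), batch[1])}
            self.register_optimizer(opt, derive, name='adam')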
- compute_loss_dict(loss_args)
- cuda(obj)
  Move nested iterables between CUDA and CPU.
- format_summary_report(losses)
- format_train_report(losses)
- optimize()
  Start the optimization process.

  Notes: Prior to the optimization process, all models are set to training mode by a call to model.train(True).
- register_loss(func, weight=1.0, name=None, display=True, override=False, **kwargs)
  Register a new loss function.

  Parameters:
  - func (callable) – The loss function to register
  - weight (float) – Weight applied to this loss term
- register_model(model, name='')
  Add a model to the solver.

  This method will also move the model directly to the device you specified when constructing the solver.

  Parameters:
  - model (torch.nn.Module) – The model to be optimized. Should return a non-empty iterable when the parameters() method is called
  - name (str, optional) – Name for the model. Useful for checkpoints when multiple models are optimized.
- register_optimizer(optimizer, loss_func, retain_graph=False, name='', n_steps=1)
  Add an optimizer to the solver.

  Parameters:
  - optimizer (Optimizer) – The optimizer used for updating model weights during training
  - loss_func (LossFunction) – A function (or callable object) that, given the current batch passed by the Solver's data loader, returns either a dictionary mapping loss function names to arguments, or a dictionary mapping loss function names to the computed loss.
  - retain_graph (bool) – True if the computational graph should be retained after calling the loss function associated with the optimizer. This is usually not needed.
  - name (str) – Optimizer name. Useful for logging
  - n_steps (int) – Number of consecutive steps the optimizer is executed. Usually set to 1.
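The loss-derivation classes documented below (AssociationLoss, CGANLoss, …) play this role. A hypothetical sketch of such a callable, here using the second return format (names mapped directly to computed losses); the class name and batch layout are illustrative:

    import torch.nn.functional as F

    class MyLossDerivation:
        """Hypothetical loss derivation returning already-computed losses."""

        def __init__(self, model):
            self.model = model

        def __call__(self, batch):
            x, y = batch
            logits = self.model(x)
            # map registered loss names directly to loss tensors
            return {'ce': F.cross_entropy(logits, y)}

An instance of this class would be passed as loss_func to register_optimizer, together with an optimizer over the relevant model parameters.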
- remove_loss(name)
class salad.solver.base.StructuredInit(**kwargs)
Bases: object

Structured initialization of solvers.

Initializes the components of a solver and passes arguments. Initialization is done in the following order:
- _init_models
- _init_losses
- _init_optims

Parameters:
- kwargs (keyword arguments) – Arguments for all initialization functions. Keyword arguments are passed through the functions in the order specified above. Unused keyword arguments are printed afterwards. In general, solvers should be designed in a way that ensures that all keyword arguments are used.

Note: Don't instantiate or subclass this class directly.
salad.solver.da.base module
Solver classes for domain adaptation experiments
class salad.solver.da.base.BaselineDASolver(*args, **kwargs)
Bases: salad.solver.da.base.DABaseSolver

A domain adaptation solver that does not run any adaptation algorithm.

This is useful for establishing baseline results for the case of no adaptation, i.e. for measuring the domain shift between datasets.
class salad.solver.da.base.DABaseSolver(*args, **kwargs)
Bases: salad.solver.classification.BaseClassSolver

Base class for unsupervised domain adaptation approaches.

Unsupervised DA assumes the presence of a single source domain \(\mathcal S\) along with a target domain \(\mathcal T\) known at training time. Given a labeled sample of points drawn from \(\mathcal S\), \(\{x^s_i, y^s_i\}_{i}^{N_s}\), and an unlabeled sample drawn from \(\mathcal T\), \(\{x^t_i\}_{i}^{N_t}\), unsupervised adaptation aims at minimizing the objective

\[\min_\theta \mathcal{R}^l_{\mathcal S} (\theta) + \lambda \mathcal{R}^u_{\mathcal {S \times T}} (\theta),\]

leveraging an unsupervised risk term \(\mathcal{R}^u_{\mathcal {S \times T}} (\theta)\) that depends on feature representations \(f_\theta(x^s,s)\) and \(f_\theta(x^t,t)\), classifier outputs \(h_\theta(x^s,s), h_\theta(x^t,t)\) as well as source labels \(y^s\). The full model \(h = g \circ f\) is a composition of a feature extractor \(f\) and classifier \(g\), both of which can possibly depend on the domain label \(s\) or \(t\) for domain-specific computations.
Notes

This solver adds two accuracies with keys acc_s and acc_t for the source and target domain, respectively. Make sure to include the derivation of these accuracies in your loss computation.
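A hypothetical sketch of a loss derivation for such a solver that also reports the two accuracy keys (the batch layout, the key names other than acc_s/acc_t, and the argument format are assumptions; target labels are used for monitoring accuracy only, never for training):

    class ExampleDALoss:
        """Hypothetical loss derivation for an unsupervised DA solver."""

        def __init__(self, model):
            self.model = model

        def __call__(self, batch):
            (xs, ys), (xt, yt) = batch      # target labels only for evaluation
            ps = self.model(xs)
            pt = self.model(xt)
            return {
                'ce'   : (ps, ys),          # supervised source risk
                'adapt': (ps, pt),          # placeholder for an unsupervised alignment term
                'acc_s': (ps, ys),          # source accuracy, logged by the solver
                'acc_t': (pt, yt),          # target accuracy, logged by the solver
            }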
class salad.solver.da.base.DABaselineLoss(solver)
Bases: object

class salad.solver.da.base.DATeacher(model, teacher, dataset, *args, **kwargs)
Bases: salad.solver.base.Solver

Base class for unsupervised domain adaptation approaches using a teacher model.
class salad.solver.da.base.DGBaseSolver(*args, **kwargs)
Bases: salad.solver.classification.BaseClassSolver

Base class for domain generalization approaches.

Domain generalization assumes the presence of multiple source domains alongside a target domain unknown at training time. Following [Shankar et al., 2018], this setting requires a dataset of training examples \(\{x_i, y_i, d_i\}_{i}^{N}\) with class and domain labels. Importantly, the domains present at training time should reflect the kind of variability that can be expected during inference. The ERM problem is then approached as

\[\min_\theta \sum_d \mathcal{R}^l_{\mathcal S_d} (\theta) = \sum_d \lambda_d \mathbb{E}_{x,y \sim \mathcal S_d }[\ell ( f_\theta(x), h_\theta(x), y, d) ].\]

In contrast to the unsupervised setting, samples are now presented in a single batch comprised of inputs \(x\), labels \(y\) and domains \(d\). In addition to a feature extractor \(f_\theta\) and classifier \(g_\theta\), models should also provide a feature extractor \(f^d_\theta\) to derive domain features, along with a domain classifier \(g^d_\theta\), with possibly shared parameters.

In contrast to unsupervised DA, this training setting leverages information from multiple labeled source domains with the goal of generalizing well to data from a previously unseen domain at test time.
salad.solver.da.advdrop module
class salad.solver.da.advdrop.AdversarialDropoutLoss(model, step=1)
Bases: object

Loss derivation for Adversarial Dropout Regularization.

See also: salad.solver.AdversarialDropoutSolver
- step1(batch)
- step2(batch)
- step3(batch)
class salad.solver.da.advdrop.AdversarialDropoutSolver(model, dataset, **kwargs)
Bases: salad.solver.da.base.DABaseSolver

Implementation of "Adversarial Dropout Regularization".

Adversarial Dropout Regularization [1] estimates uncertainties about the classification process by sampling different models using dropout. On the source domain, a standard cross entropy loss is employed. On the target domain, two predictions are sampled from the model.

Both network parts are jointly trained on the source domain using the standard cross entropy loss,

\[\min_{C, G} - \sum_k p^s_k \log y^s_k.\]

The classifier part of the network is trained to maximize the symmetric KL distance between the two predictions. This distance is one option for measuring uncertainty in a network; in other words, the classifier aims at maximizing uncertainty given two noisy estimates of the current feature vector,

\[\min_{C} - \sum_k p^s_k \log y^s_k - \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}.\]

In contrast, the feature extractor aims at minimizing the uncertainty between two noisy samples given a fixed target classifier,

\[\min_{G} \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}.\]

References

[1] Adversarial Dropout Regularization, Saito et al., ICLR 2018
class salad.solver.da.advdrop.SymmetricKL
Bases: torch.nn.modules.module.Module

- forward(x, y)
  Defines the computation performed at every call.

  Should be overridden by all subclasses.

  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
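A minimal sketch of a symmetric KL divergence between two predicted distributions, matching the \(\frac{p_k - q_k}{2}\log\frac{p_k}{q_k}\) term above (an illustrative reimplementation, not necessarily the library's exact code; inputs are assumed to be probability vectors along the last dimension):

    import torch
    from torch import nn

    class SymmetricKLSketch(nn.Module):
        def forward(self, p, q):
            eps = 1e-8  # numerical stability for the logarithms
            kl_pq = (p * (torch.log(p + eps) - torch.log(q + eps))).sum(dim=-1)
            kl_qp = (q * (torch.log(q + eps) - torch.log(p + eps))).sum(dim=-1)
            return 0.5 * (kl_pq + kl_qp).mean()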
salad.solver.da.advdrop.pack(*args)

salad.solver.da.advdrop.unpack(arg, n_tensors)
salad.solver.da.association module

Associative Domain Adaptation, Häusser et al., CVPR 2017
class salad.solver.da.association.AssociationLoss(model)
Bases: object

Loss function for associative domain adaptation.

Given a model, derive a function that computes arguments for the association loss.
class salad.solver.da.association.AssociativeSolver(model, dataset, learningrate, walker_weight=1.0, visit_weight=0.1, *args, **kwargs)
Bases: salad.solver.da.base.DABaseSolver

Implementation of "Associative Domain Adaptation".

Associative Domain Adaptation [1] leverages a random walk based on feature similarity as a distance between source and target feature correlations. The algorithm is based on two loss functions that are added to the standard cross entropy loss on the source domain.

Given features for the source and target domain, a kernel function is used to measure similarity between both domains. The original implementation uses the scalar product between feature vectors, scaled by an exponential,

\[K_{ij} = k(x^s_i, x^t_j) = \exp(\langle x^s_i, x^t_j \rangle).\]

This kernel is then used to compute the transition probabilities

\[p(x^t_j | x^s_i) = \frac{K_{ij}}{\sum_{l} K_{il}}\]

and

\[p(x^s_k | x^t_j) = \frac{K_{kj}}{\sum_{l} K_{lj}},\]

which are combined into the roundtrip probability

\[p(x^s_k | x^s_i) = \sum_{j} p(x^s_k | x^t_j) p(x^t_j | x^s_i).\]

It is then required that

- WalkerLoss: the roundtrip ends at a sample with the same class label, i.e. \(y^s_i = y^s_k\)
- VisitLoss: each target sample is visited with a certain probability

(a sketch of these quantities follows below). As one possible modification, different kernel functions could be used to measure similarity between the domains. With this solver, it is advised to use large sample sizes for the target domain and to ensure that a sufficient number of source samples is available in each batch.

TODO: Possibly implement, in the solver class, a functionality to aggregate batches to avoid memory issues.
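An illustrative sketch of the quantities defined above (not necessarily the library's exact implementation; walker_weight and visit_weight then weight the two resulting loss terms):

    import torch

    def association_probabilities(feats_s, feats_t):
        """Similarity kernel, transition probabilities and roundtrip probability."""
        sim = feats_s @ feats_t.t()                            # <x^s_i, x^t_j>
        p_target_given_source = torch.softmax(sim, dim=1)      # p(x^t_j | x^s_i)
        p_source_given_target = torch.softmax(sim.t(), dim=1)  # p(x^s_k | x^t_j)
        # roundtrip p(x^s_k | x^s_i): source -> target -> source
        p_roundtrip = p_target_given_source @ p_source_given_target
        # average probability of visiting each target sample (used by VisitLoss)
        p_visit = p_target_given_source.mean(dim=0)
        return p_roundtrip, p_visit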
Parameters:
- model (torch.nn.Module) – A PyTorch model to be trained by association
- dataset (StackedDataset) – A dataset suitable for an unsupervised solver
- learningrate (int) – TODO

References

[1] Associative Domain Adaptation, Häusser et al., CVPR 2017, https://arxiv.org/abs/1708.00938
salad.solver.da.coral module

Losses for Correlation Alignment.

Deep CORAL: Correlation Alignment for Deep Domain Adaptation. Paper: https://arxiv.org/pdf/1607.01719.pdf

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation. Paper: https://openreview.net/pdf?id=rJWechg0Z
class salad.solver.da.coral.CentroidDistanceLossSolver(model, dataset, *args, **kwargs)
Bases: salad.solver.da.coral.CorrelationDistanceSolver

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation. Paper: https://openreview.net/pdf?id=rJWechg0Z and https://arxiv.org/pdf/1705.08180.pdf

class salad.solver.da.coral.CentroidLoss(model)
Bases: object

class salad.solver.da.coral.CorrDistanceSolver(model, dataset, *args, **kwargs)
Bases: salad.solver.da.coral.CorrelationDistanceSolver

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation. Paper: https://openreview.net/pdf?id=rJWechg0Z

class salad.solver.da.coral.CorrelationDistanceLoss(model, n_steps_recompute=10, nullspace=False)
Bases: object

class salad.solver.da.coral.CorrelationDistanceSolver(model, dataset, *args, **kwargs)
class salad.solver.da.coral.DeepCoralSolver(model, dataset, *args, **kwargs)
Bases: salad.solver.da.coral.CorrelationDistanceSolver

Deep CORAL: Correlation Alignment for Deep Domain Adaptation. Paper: https://arxiv.org/pdf/1607.01719.pdf

Loss function:

\[\mathcal{L}(x^s, x^t) = \frac{1}{4d^2} \| C_s - C_t \|_F^2\]
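A minimal sketch of this loss (an illustrative reimplementation, assuming feature matrices of shape (batch, d) and the squared Frobenius norm used in the paper):

    import torch

    def coral_loss(feats_s, feats_t):
        """Squared Frobenius distance between source and target feature covariances."""
        d = feats_s.size(1)
        c_s = torch.cov(feats_s.t())    # d x d covariance of source features
        c_t = torch.cov(feats_t.t())    # d x d covariance of target features
        return ((c_s - c_t) ** 2).sum() / (4 * d * d)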
class salad.solver.da.coral.DeepLogCoralSolver(model, dataset, *args, **kwargs)
Bases: salad.solver.da.coral.CorrelationDistanceSolver

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation. Paper: https://openreview.net/pdf?id=rJWechg0Z
salad.solver.da.crossgrad module

Cross Gradient Training, ICLR 2018
class salad.solver.da.crossgrad.CrossGradLoss(solver)
Bases: object

Cross Gradient Training, http://arxiv.org/abs/1804.10745

- pertub(x, loss, eps=1e-05)
class salad.solver.da.crossgrad.CrossGradSolver(model, *args, **kwargs)
Bases: salad.solver.da.base.DGBaseSolver

Cross Gradient Optimizer.

A domain generalization solver based on Cross Gradient Training [1]. The label distribution is marginalized over domains,

\[p(y | x) = \int_d p(y | x, d)\, p(d | x)\, \mathrm{d}d,\]

and the inputs for the domain and label classifiers are perturbed along the gradient of the respective other loss,

\[x_d = x + \epsilon \nabla_x L(y), \qquad x_y = x + \epsilon \nabla_x L(d).\]
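An illustrative sketch of these perturbations (hypothetical label and domain classifiers and loss choice; the library's pertub method above operates on a generic loss instead):

    import torch
    import torch.nn.functional as F

    def cross_grad_perturb(x, y, d, label_model, domain_model, eps=1e-5):
        """Perturb inputs along the gradient of the respective other loss."""
        x = x.clone().requires_grad_(True)
        loss_y = F.cross_entropy(label_model(x), y)    # label loss L(y)
        loss_d = F.cross_entropy(domain_model(x), d)   # domain loss L(d)
        grad_y, = torch.autograd.grad(loss_y, x)
        grad_d, = torch.autograd.grad(loss_d, x)
        x_d = (x + eps * grad_y).detach()   # perturbed input for the domain classifier
        x_y = (x + eps * grad_d).detach()   # perturbed input for the label classifier
        return x_y, x_d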
References

[1] Shankar et al., Generalizing Across Domains via Cross-Gradient Training, ICLR 2018
class salad.solver.da.crossgrad.Model(n_classes, n_domains)
Bases: torch.nn.modules.module.Module

- forward(x)
  Defines the computation performed at every call.

  Should be overridden by all subclasses.

  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

- forward_domain(x)
- parameters_classifier()
- parameters_domain()
class salad.solver.da.crossgrad.MultiDomainModule(n_domains)
Bases: torch.nn.modules.module.Module

- forward(x)
  Defines the computation performed at every call.

  Should be overridden by all subclasses.

  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

- forward_domain(x)
- parameters_classifier()
- parameters_domain()
salad.solver.da.crossgrad.concat(x, z)
  Concatenate a 4D tensor with an expanded 2D tensor.

salad.solver.da.crossgrad.conv2d(m, n, k, act=True)

salad.solver.da.crossgrad.features(inp)

salad.solver.da.crossgrad.get_dataset(noisemodels, batch_size, shuffle=True, num_workers=0, which='train')
salad.solver.da.dann module

class salad.solver.da.dann.AdversarialLoss(G, D, train_G=True)
Bases: object
class salad.solver.da.dann.DANNSolver(model, discriminator, dataset, learningrate, *args, **kwargs)
Bases: salad.solver.da.base.DABaseSolver

Domain Adversarial Neural Networks solver.

This builds upon the normal classification solver that uses cross entropy or binary cross entropy for optimizing neural networks.
salad.solver.da.dirtt module

class salad.solver.da.dirtt.DIRTT(model, teacher)
Bases: object

class salad.solver.da.dirtt.DIRTTSolver(model, teacher, dataset, *args, **kwargs)
Bases: salad.solver.base.Solver

Virtual Adversarial Domain Adaptation

class salad.solver.da.dirtt.VADA(G, D, train_G=True)

class salad.solver.da.dirtt.VADASolver(model, discriminator, dataset, *args, **kwargs)
Bases: salad.solver.da.dann.DANNSolver

Virtual Adversarial Domain Adaptation
salad.solver.da.dirtt_re module

Self Ensembling for Visual Domain Adaptation

class salad.solver.da.dirtt_re.DIRTT(model, teacher)
Bases: object

class salad.solver.da.dirtt_re.DIRTTSolver(model, teacher, dataset, learningrate, *args, **kwargs)
salad.solver.da.djdot module

class salad.solver.da.djdot.DJDOTSolver(model, dataset, *args, **kwargs)
Bases: salad.solver.classification.BaseClassSolver

Deep Joint Optimal Transport solver

TODO

- derive_losses(batch)
salad.solver.da.ensembling module

Self Ensembling for Visual Domain Adaptation

class salad.solver.da.ensembling.EnsemblingLoss(model, teacher)
Bases: object

class salad.solver.da.ensembling.SelfEnsemblingSolver(model, teacher, dataset, learningrate, *args, **kwargs)
salad.solver.classification module

class salad.solver.classification.BCESolver(*args, **kwargs)
Bases: salad.solver.classification.BaseClassSolver

Solver for a classification experiment.
class salad.solver.classification.BaseClassSolver(model, dataset, multiclass=True, *args, **kwargs)
Bases: salad.solver.base.Solver

Base solver for classification experiments.

Parameters:
- model (nn.Module) – A model to train on a classification target
- dataset (torch.utils.data.Dataset) – The dataset providing training samples
- multiclass (bool) – If True, CrossEntropyLoss is used, BCEWithLogitsLoss otherwise.
class salad.solver.classification.ClassificationLoss(solver)
Bases: object
class salad.solver.classification.FinetuneSolver(*args, **kwargs)
Bases: salad.solver.classification.BaseClassSolver

Finetune a pre-trained deep learning model.

Given a model with a separable feature extractor and classifier, use different learning rates and regularization settings for the two parts. Useful for fine-tuning pre-trained ImageNet models or saved model checkpoints.

Parameters:
- model (nn.Module) – Module with two separate parts
- dataset (Dataset) – The dataset used for training
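The underlying idea can be illustrated with plain PyTorch parameter groups (hypothetical model layout; the solver presumably sets up something equivalent through its optimizer registration):

    import torch
    from torch import nn

    # hypothetical model with a separable feature extractor and classifier
    model = nn.ModuleDict({
        'features'  : nn.Sequential(nn.Linear(10, 32), nn.ReLU()),
        'classifier': nn.Linear(32, 3),
    })

    optimizer = torch.optim.SGD([
        # small learning rate and weight decay for the pre-trained feature extractor
        {'params': model['features'].parameters(), 'lr': 1e-4, 'weight_decay': 1e-5},
        # larger learning rate for the freshly initialized classifier head
        {'params': model['classifier'].parameters(), 'lr': 1e-2},
    ], momentum=0.9)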
class salad.solver.classification.MultiDomainClassificationLoss(solver, domain)
Bases: object

class salad.solver.classification.MultidomainBCESolver(model, dataset, learningrate, multiclass=True, loss_weights=None, *args, **kwargs)
Bases: salad.solver.base.Solver
salad.solver.gan module

Tools for training Generative Adversarial Networks (GANs).

The class is primarily used to train conditional networks (CGANs).

Notes

Contributions for extensions wanted!
class salad.solver.gan.CGANLoss(solver, G, Ds, train_G)
Bases: object

Loss derivation for a conditional GAN.

- derive_losses(batch)
class salad.solver.gan.ConditionalGANSolver(model, dataset, learningrate=0.0002, *args, **kwargs)
Bases: salad.solver.base.Solver

- format_train_report(losses)
- n_classes = [1, 11, 3]
- names = ['D_GAN', 'D_CL', 'D_CON']
class salad.solver.gan.GANSolver(dataset, n_epochs=1, savedir='./log', gpu=None, dryrun=False, **kwargs)
Bases: salad.solver.base.Solver