osc.train

Main training methods.

Functions

batches_per_epoch_test

Compute number of batches in one test epoch.

batches_per_epoch_train

Compute number of batches in one training epoch.

batches_per_epoch_val

Compute number of batches in one validation epoch.

build_dataset_test

Build test dataset.

build_dataset_train

Build training dataset.

build_dataset_val

Build validation dataset.

build_dataset_vqa

Build VQA dataset for linear probing.

build_linear_probe

Return type: Module

build_loss_fn_global

Build global loss function.

build_loss_fn_objects

Build object loss function.

build_model

Build model.

build_optimizer

Build optimizer for training.

build_scheduler

Build learning rate scheduler for training.

extract_vqa_features

Return type: Tuple[Tensor, Tensor]

get_viz_batch

Prepare a batch of images for visualization.

log_env_info

log_model_parameters

Log model parameters as a table.

main

Return type: None

run_test_linear_probes

run_test_segmentation

run_train_epoch

Run one epoch of training.

run_train_val_viz_epochs

Run train, val, viz for a certain number of epochs.

run_val_epoch

Run one validation epoch.

run_viz

Run inference on a single batch of images and visualize the results.

update_cfg

Classes

ModelLoss

Function type signature for model loss functions.
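
This listing does not show the actual signature of ModelLoss. As a rough illustration of what a "function type signature" class can look like in Python, the sketch below uses `typing.Protocol` with a `__call__` method. The name `ModelLossSketch`, the argument names, and the plain-float dictionaries are all assumptions made for this example and are not taken from `osc.train`.

```python
from typing import Dict, Protocol


class ModelLossSketch(Protocol):
    """Hypothetical stand-in for a loss-function type; not the real ModelLoss."""

    def __call__(self, outputs: Dict[str, float], targets: Dict[str, float]) -> float:
        ...


def squared_error(outputs: Dict[str, float], targets: Dict[str, float]) -> float:
    # Sum of squared differences over the keys present in targets.
    return sum((outputs[k] - targets[k]) ** 2 for k in targets)


# Any callable with a matching signature satisfies the protocol structurally,
# so a static type checker accepts this assignment without inheritance.
loss_fn: ModelLossSketch = squared_error
```

A protocol of this kind lets builder functions declare that they return "some callable loss" while leaving the concrete implementation open.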