Train
first_break_picking.train_eval.train module
- first_break_picking.train_eval.train.train(base_dir: str, batch_size: int, val_percentage: float, epochs: int, learning_rate: float, device: device, path_to_save: str, save_frequency: int, upsampled_size_row: int, upsampled_size_col: int, type_of_problem: str = 'fb', loss_fn_name: str = 'ce', model_name: str = 'unet_resnet', checkpoint_path: str | None = None, features: List = [16, 32, 64, 128], in_channels: int = 1, out_channels: int = 2, encoder_weight: str = 'imagenet', step_size_milestone: int = None, show: bool = False) None
This is the main function to call for training a network.
- Parameters:
base_dir (str) – Directory containing the data
batch_size (int) – Batch size
val_percentage (float) – Fraction of the data used for validation
epochs (int) – Number of epochs
learning_rate (float) – Learning rate
upsampled_size_row (int) – Number of rows of each subshot during training
upsampled_size_col (int) – Number of columns of each subshot during training
device (torch.device) – Device on which to run the training (e.g. CPU or CUDA)
path_to_save (str) – Path to save the checkpoints
save_frequency (int) – Frequency of saving checkpoints
type_of_problem (str, optional) – Type of problem, by default “fb”
loss_fn_name (str, optional) – Name of loss function, by default “ce”
model_name (str, optional) – Name of desired network, by default “unet_resnet”
checkpoint_path (Optional[str], optional) – Path of a checkpoint to load, by default None
features (List, optional) – Number of channels in each conv layer, by default [16, 32, 64, 128]
in_channels (int, optional) – Number of channels in the input shot, by default 1
out_channels (int, optional) – Number of output channels, by default 2
encoder_weight (str, optional) – Name of the weights used to initialize the network, by default “imagenet”
step_size_milestone (Optional[int], optional) – The learning rate (step size) is divided by 10 at every step_size_milestone, by default None, which keeps the learning rate constant
show (bool, optional) – Whether to show some samples after training, by default False
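The effect of learning_rate and step_size_milestone on the schedule can be illustrated with a minimal, standalone sketch. It assumes the milestone is counted in epochs and that the division by 10 is applied once each time a milestone is passed. The helper name learning_rate_at_epoch is hypothetical and is not part of first_break_picking. Under those assumptions the schedule matches what torch.optim.lr_scheduler.StepLR with gamma=0.1 produces when stepped once per epoch, although the source does not say which scheduler train uses.

```python
from typing import Optional


def learning_rate_at_epoch(
    learning_rate: float,
    epoch: int,
    step_size_milestone: Optional[int] = None,
) -> float:
    """Hypothetical helper: learning rate in effect at a 0-based epoch.

    Assumption: the rate is divided by 10 every ``step_size_milestone``
    epochs, and stays constant when ``step_size_milestone`` is None.
    """
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    if step_size_milestone is None:
        # No milestone given: constant learning rate for the whole run
        return learning_rate
    if step_size_milestone <= 0:
        raise ValueError("step_size_milestone must be a positive integer")
    # Number of milestones already passed at this epoch
    n_decays = epoch // step_size_milestone
    # Divide by 10 once per milestone passed
    return learning_rate / (10 ** n_decays)


if __name__ == "__main__":
    # Example: initial learning rate 1e-3, milestone every 10 epochs
    for epoch in (0, 9, 10, 19, 20, 29):
        print(epoch, learning_rate_at_epoch(1e-3, epoch, step_size_milestone=10))
```

With step_size_milestone=10, epochs 0 to 9 use 1e-3, epochs 10 to 19 use roughly 1e-4, and so on. Leaving it as None keeps 1e-3 throughout.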