|
|
|
|
|
from transformers import PretrainedConfig |
|
|
|
|
|
class ModelConfig(PretrainedConfig):
    """Hugging Face configuration object holding the SongFormer hyperparameters.

    Each named constructor argument is stored unchanged as an instance
    attribute of the same name. No validation or derived values are computed
    here. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__`` before the hyperparameters are stored.

    The grouping below is inferred from argument names only. The precise
    meaning of each value is defined by the model code that consumes this
    config, which is not visible here. NOTE(review): confirm against the model.

    * Feature / encoder sizes: ``input_dim``, ``input_dim_raw``,
      ``transformer_encoder_input_dim``, ``transformer_input_dim``,
      ``num_transformer_layers``, ``transformer_nhead``,
      ``transformer_dropout``.
    * Classification heads: ``num_classes``, ``num_dataset_classes``.
    * Down-sampling convolution: ``down_sample_conv_kernel_size``,
      ``down_sample_conv_stride``, ``down_sample_conv_dropout``,
      ``down_sample_conv_padding``.
    * Boundary TV loss: ``boundary_tv_loss_beta``, ``boundary_tv_loss_lambda``,
      ``boundary_tv_loss_boundary_threshold``,
      ``boundary_tv_loss_reduction_weight``, ``boundary_tvloss_weight``.
    * Label focal loss: ``label_focal_loss_alpha``, ``label_focal_loss_gamma``,
      ``label_focal_loss_weight``.
    * Loss mixing: ``loss_weight_section``, ``loss_weight_function``.
    * Task toggles: ``learn_label``, ``learn_segment``.
    * Post-processing: ``local_maxima_filter_size``, ``frame_rates``
      (presumably feature frames per second; verify with the feature
      extractor).
    """

    # Identifier Hugging Face uses to associate serialized configs with this class.
    model_type = "SongFormer"

    def __init__(
        self,
        input_dim=2048,
        input_dim_raw=4096,
        transformer_encoder_input_dim=1024,
        transformer_input_dim=512,
        num_transformer_layers=4,
        transformer_nhead=8,
        transformer_dropout=0.1,
        num_classes=128,
        num_dataset_classes=64,
        down_sample_conv_kernel_size=3,
        down_sample_conv_stride=3,
        down_sample_conv_dropout=0.1,
        down_sample_conv_padding=0,
        boundary_tv_loss_beta=0.6,
        boundary_tv_loss_lambda=0.4,
        boundary_tv_loss_boundary_threshold=0.01,
        boundary_tv_loss_reduction_weight=0.1,
        boundary_tvloss_weight=0.05,
        label_focal_loss_alpha=0.25,
        label_focal_loss_gamma=2.0,
        label_focal_loss_weight=0.2,
        loss_weight_section=0.2,
        loss_weight_function=0.8,
        learn_label=True,
        learn_segment=True,
        local_maxima_filter_size=3,
        frame_rates=8.333,
        **kwargs
    ):
        """Store every hyperparameter on the instance.

        Generic ``PretrainedConfig`` options travel through ``kwargs`` and
        are handled by the base class first.
        """
        super().__init__(**kwargs)

        # Attribute name -> value. Dict insertion order keeps the assignment
        # order identical to the argument order above.
        hyperparameters = {
            "input_dim": input_dim,
            "input_dim_raw": input_dim_raw,
            "transformer_encoder_input_dim": transformer_encoder_input_dim,
            "transformer_input_dim": transformer_input_dim,
            "num_transformer_layers": num_transformer_layers,
            "transformer_nhead": transformer_nhead,
            "transformer_dropout": transformer_dropout,
            "num_classes": num_classes,
            "num_dataset_classes": num_dataset_classes,
            "down_sample_conv_kernel_size": down_sample_conv_kernel_size,
            "down_sample_conv_stride": down_sample_conv_stride,
            "down_sample_conv_dropout": down_sample_conv_dropout,
            "down_sample_conv_padding": down_sample_conv_padding,
            "boundary_tv_loss_beta": boundary_tv_loss_beta,
            "boundary_tv_loss_lambda": boundary_tv_loss_lambda,
            "boundary_tv_loss_boundary_threshold": boundary_tv_loss_boundary_threshold,
            "boundary_tv_loss_reduction_weight": boundary_tv_loss_reduction_weight,
            "boundary_tvloss_weight": boundary_tvloss_weight,
            "label_focal_loss_alpha": label_focal_loss_alpha,
            "label_focal_loss_gamma": label_focal_loss_gamma,
            "label_focal_loss_weight": label_focal_loss_weight,
            "loss_weight_section": loss_weight_section,
            "loss_weight_function": loss_weight_function,
            "learn_label": learn_label,
            "learn_segment": learn_segment,
            "local_maxima_filter_size": local_maxima_filter_size,
            "frame_rates": frame_rates,
        }
        # setattr goes through the same __setattr__ as `self.x = ...`.
        for attr_name, attr_value in hyperparameters.items():
            setattr(self, attr_name, attr_value)