frameworks.clip.services.clip_utils

clip_utils

Training CLIP-like dual-encoder models using text and vision encoders from the library.

The script can be used to train CLIP-like models for languages other than English by using a text encoder pre-trained in the desired language. It currently supports the following vision and text models:

- Vision models: ViT (https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
- Text models: BERT, RoBERTa (https://huggingface.co/models?filter=fill-mask)
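
A common way to assemble the kind of dual-encoder checkpoint this script trains is the transformers VisionTextDualEncoder API, which pairs a supported vision encoder with a supported text encoder. A minimal sketch, assuming that API; the checkpoint names are only examples:

```python
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    VisionTextDualEncoderModel,
    VisionTextDualEncoderProcessor,
)

# Pair a ViT vision encoder with an Italian BERT text encoder
# (example checkpoints; any ViT/CLIP vision model and BERT/RoBERTa
# text model listed above should work).
model = VisionTextDualEncoderModel.from_vision_text_pretrained(
    "google/vit-base-patch16-224", "dbmdz/bert-base-italian-xxl-cased"
)

tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-italian-xxl-cased")
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
processor = VisionTextDualEncoderProcessor(feature_extractor, tokenizer)

# Save the hybrid model so a script like this one can fine-tune it.
model.save_pretrained("clip-italian")
processor.save_pretrained("clip-italian")
```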

Classes:

- ModelArguments: Arguments pertaining to which model, config, and tokenizer we are going to fine-tune or train from scratch.
- DataTrainingArguments: Arguments pertaining to what data we are going to input to our model for training and evaluation.
Functions:

- collate_fn
- parse_args
- set_logging
- detect_last_checkpoint
- load_and_prepare_dataset
- initialize_model_components
- resolve_column_names
- build_transforms
- prepare_split_dataset
- preprocess_datasets
- train_and_evaluate
- main

Attributes:

- dataset_name_mapping (dict): default (image column, caption column) names keyed by dataset loading-script name.

dataset_name_mapping = {'image_caption_dataset.py': ('image_path', 'caption')} module-attribute
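
This mapping supplies the default image and caption column names when a known dataset loading script is used and the user does not pass explicit column names. A short illustration of the lookup (the surrounding control flow is a hypothetical sketch, not the module's exact code):

```python
dataset_name_mapping = {"image_caption_dataset.py": ("image_path", "caption")}

# Fall back to the mapped defaults when no columns were given explicitly.
dataset_columns = dataset_name_mapping.get("image_caption_dataset.py")
if dataset_columns is not None:
    image_column, caption_column = dataset_columns  # ("image_path", "caption")
```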

ModelArguments(model_name_or_path, config_name=None, tokenizer_name=None, feature_extractor_name=None, cache_dir=None, model_revision='main', use_fast_tokenizer=True, use_auth_token=False, freeze_vision_model=False, freeze_text_model=False) dataclass

Arguments pertaining to which model, config, and tokenizer we are going to fine-tune or train from scratch.

Attributes:

- model_name_or_path (str)
- config_name (str | None)
- tokenizer_name (str | None)
- feature_extractor_name (str | None)
- cache_dir (str | None)
- model_revision (str)
- use_fast_tokenizer (bool)
- use_auth_token (bool)
- freeze_vision_model (bool)
- freeze_text_model (bool)

model_name_or_path = field(metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}) class-attribute instance-attribute

config_name = field(default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}) class-attribute instance-attribute

tokenizer_name = field(default=None, metadata={'help': 'Pretrained tokenizer name or path if not the same as model_name'}) class-attribute instance-attribute

feature_extractor_name = field(default=None, metadata={'help': 'Name or path of preprocessor config.'}) class-attribute instance-attribute

cache_dir = field(default=None, metadata={'help': 'Where do you want to store the pretrained models downloaded from s3'}) class-attribute instance-attribute

model_revision = field(default='main', metadata={'help': 'The specific model version to use (can be a branch name, tag name or commit id).'}) class-attribute instance-attribute

use_fast_tokenizer = field(default=True, metadata={'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'}) class-attribute instance-attribute

use_auth_token = field(default=False, metadata={'help': 'Will use the token generated when running `huggingface-cli login` (necessary to use this script with private models).'}) class-attribute instance-attribute

freeze_vision_model = field(default=False, metadata={'help': 'Whether to freeze the vision model parameters or not.'}) class-attribute instance-attribute

freeze_text_model = field(default=False, metadata={'help': 'Whether to freeze the text model parameters or not.'}) class-attribute instance-attribute
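
The two freeze_* flags typically translate into disabling gradients on the corresponding sub-module before training. A minimal sketch of that pattern; the helper name and the vision_model/text_model attributes are assumptions, not necessarily this module's internals:

```python
def freeze_params(module):
    """Disable gradient updates for every parameter in `module`."""
    for param in module.parameters():
        param.requires_grad = False

# Hypothetical usage: `model` is a dual encoder exposing
# `vision_model` and `text_model` sub-modules.
if model_args.freeze_vision_model:
    freeze_params(model.vision_model)
if model_args.freeze_text_model:
    freeze_params(model.text_model)
```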

DataTrainingArguments(dataset_name=None, dataset_config_name=None, data_dir=None, image_column='image_path', caption_column='caption', train_file=None, validation_file=None, test_file=None, max_seq_length=128, max_train_samples=None, max_eval_samples=None, overwrite_cache=False, preprocessing_num_workers=None) dataclass

Arguments pertaining to what data we are going to input to our model for training and evaluation.

Attributes:

- dataset_name (str | None)
- dataset_config_name (str | None)
- data_dir (str | None)
- image_column (str | None)
- caption_column (str | None)
- train_file (str | None)
- validation_file (str | None)
- test_file (str | None)
- max_seq_length (int | None)
- max_train_samples (int | None)
- max_eval_samples (int | None)
- overwrite_cache (bool)
- preprocessing_num_workers (int | None)

dataset_name = field(default=None, metadata={'help': 'The name of the dataset to use (via the datasets library).'}) class-attribute instance-attribute

dataset_config_name = field(default=None, metadata={'help': 'The configuration name of the dataset to use (via the datasets library).'}) class-attribute instance-attribute

data_dir = field(default=None, metadata={'help': 'The data directory containing input files.'}) class-attribute instance-attribute

image_column = field(default='image_path', metadata={'help': 'The name of the column in the datasets containing the full image file paths.'}) class-attribute instance-attribute

caption_column = field(default='caption', metadata={'help': 'The name of the column in the datasets containing the image captions.'}) class-attribute instance-attribute

train_file = field(default=None, metadata={'help': 'The input training data file (a jsonlines file).'}) class-attribute instance-attribute

validation_file = field(default=None, metadata={'help': 'An optional input evaluation data file (a jsonlines file).'}) class-attribute instance-attribute

test_file = field(default=None, metadata={'help': 'An optional input test data file (a jsonlines file).'}) class-attribute instance-attribute

max_seq_length = field(default=128, metadata={'help': 'The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.'}) class-attribute instance-attribute

max_train_samples = field(default=None, metadata={'help': 'For debugging purposes or quicker training, truncate the number of training examples to this value if set.'}) class-attribute instance-attribute

max_eval_samples = field(default=None, metadata={'help': 'For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set.'}) class-attribute instance-attribute

overwrite_cache = field(default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}) class-attribute instance-attribute

preprocessing_num_workers = field(default=None, metadata={'help': 'The number of processes to use for the preprocessing.'}) class-attribute instance-attribute

collate_fn(examples)
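
A collate function for contrastive image-text training usually stacks the pre-transformed image tensors and the tokenized captions into one batch and asks the model to return its loss. A plausible sketch, following the usual CLIP-training convention (the actual implementation may differ):

```python
import torch

def collate_fn(examples):
    # Stack per-example image tensors into (batch, channels, height, width).
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)
    attention_mask = torch.tensor([example["attention_mask"] for example in examples], dtype=torch.long)
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "return_loss": True,  # ask the dual encoder to compute the contrastive loss
    }
```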

parse_args()
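
Scripts structured around these dataclasses conventionally parse them, together with transformers.TrainingArguments, through HfArgumentParser. A sketch of what parse_args most likely does, not the verified implementation:

```python
from transformers import HfArgumentParser, TrainingArguments

def parse_args():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    return model_args, data_args, training_args
```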

set_logging(training_args)
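
Logging setup in Trainer-based scripts usually wires Python logging to stdout and propagates the per-process log level to the datasets and transformers libraries. A hedged sketch of that standard pattern:

```python
import logging
import sys

import datasets
import transformers

logger = logging.getLogger(__name__)

def set_logging(training_args):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # TrainingArguments exposes the log level appropriate for this process.
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
```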

detect_last_checkpoint(training_args)
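
Checkpoint detection in Trainer-based scripts conventionally delegates to transformers.trainer_utils.get_last_checkpoint, returning a path only when the output directory already exists and is not being overwritten. A sketch under that assumption:

```python
import os

from transformers.trainer_utils import get_last_checkpoint

def detect_last_checkpoint(training_args):
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    return last_checkpoint
```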

load_and_prepare_dataset(model_args, data_args, training_args)
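
Given that DataTrainingArguments accepts either a hub dataset name or local jsonlines files, the loader presumably branches between datasets.load_dataset on the dataset name and the generic "json" builder. A plausible sketch:

```python
from datasets import load_dataset

def load_and_prepare_dataset(model_args, data_args, training_args):
    if data_args.dataset_name is not None:
        # Download a hub dataset (or load a local dataset loading script).
        return load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            data_dir=data_args.data_dir,
        )
    # Otherwise assemble splits from the user-supplied jsonlines files.
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
    if data_args.validation_file is not None:
        data_files["validation"] = data_args.validation_file
    if data_args.test_file is not None:
        data_files["test"] = data_args.test_file
    return load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
```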

initialize_model_components(model_args)
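
Model initialization for this kind of script typically resolves the config, tokenizer, feature extractor, and model from ModelArguments via the transformers Auto classes. A sketch of that pattern; the return shape is an assumption:

```python
from transformers import AutoConfig, AutoFeatureExtractor, AutoModel, AutoTokenizer

def initialize_model_components(model_args):
    config = AutoConfig.from_pretrained(
        model_args.config_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
    )
    model = AutoModel.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
    )
    return config, tokenizer, feature_extractor, model
```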

resolve_column_names(data_args, column_names)

build_transforms(config, feature_extractor)
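
Image transforms for CLIP-style training are usually derived from the model config and feature extractor (crop size and normalization statistics). A sketch using torchvision, assuming the config exposes vision_config.image_size and the feature extractor exposes image_mean/image_std:

```python
import torch
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode

def build_transforms(config, feature_extractor):
    image_size = config.vision_config.image_size
    # torchvision v1 transforms are nn.Modules, so they compose with Sequential.
    return torch.nn.Sequential(
        Resize([image_size], interpolation=InterpolationMode.BICUBIC),
        CenterCrop(image_size),
        ConvertImageDtype(torch.float),
        Normalize(feature_extractor.image_mean, feature_extractor.image_std),
    )
```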

prepare_split_dataset(dataset, split, args, tokenizer, transform, image_column, caption_column, column_names)
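
Preparing a split typically means truncating to max_train_samples/max_eval_samples, tokenizing the caption column, and attaching the image transform. A plausible sketch of the caption-tokenization half; the helper and its wiring are assumptions that mirror the signature above:

```python
def tokenize_captions(examples, tokenizer, caption_column, max_seq_length):
    text_inputs = tokenizer(
        list(examples[caption_column]),
        max_length=max_seq_length,
        padding="max_length",
        truncation=True,
    )
    examples["input_ids"] = text_inputs.input_ids
    examples["attention_mask"] = text_inputs.attention_mask
    return examples

# Hypothetical usage inside prepare_split_dataset:
# dataset = dataset.map(
#     tokenize_captions,
#     batched=True,
#     fn_kwargs={
#         "tokenizer": tokenizer,
#         "caption_column": caption_column,
#         "max_seq_length": args.max_seq_length,
#     },
#     remove_columns=[col for col in column_names if col != image_column],
# )
```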

preprocess_datasets(dataset, data_args, training_args, tokenizer, feature_extractor, image_column, caption_column, column_names, config)

train_and_evaluate(trainer, training_args, last_checkpoint, model_args, data_args)
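
The training/evaluation driver in Trainer-based scripts usually resolves the resume checkpoint, runs trainer.train, saves the model and metrics, and then optionally evaluates. A hedged sketch of that flow:

```python
def train_and_evaluate(trainer, training_args, last_checkpoint, model_args, data_args):
    if training_args.do_train:
        checkpoint = training_args.resume_from_checkpoint or last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    if training_args.do_eval:
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
```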

main()
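
Reading the signatures above together, main() plausibly chains the helpers in order. The composition below is an assumption drawn only from those signatures, not the verified implementation:

```python
from transformers import Trainer

def main():
    model_args, data_args, training_args = parse_args()
    set_logging(training_args)
    last_checkpoint = detect_last_checkpoint(training_args)

    dataset = load_and_prepare_dataset(model_args, data_args, training_args)
    # Assumed return shape; see the sketch under initialize_model_components.
    config, tokenizer, feature_extractor, model = initialize_model_components(model_args)

    column_names = dataset["train"].column_names
    image_column, caption_column = resolve_column_names(data_args, column_names)
    train_dataset, eval_dataset = preprocess_datasets(
        dataset, data_args, training_args, tokenizer, feature_extractor,
        image_column, caption_column, column_names, config,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collate_fn,
    )
    train_and_evaluate(trainer, training_args, last_checkpoint, model_args, data_args)
```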