frameworks.sam2.services.trainer
Classes:

| Name | Description |
|---|---|
| Sam2Trainer | Trainer class for managing the full fine-tuning process of a SAM2 model using a COCO dataset. |
Functions:

| Name | Description |
|---|---|
| prepare_directories | Creates required directories for image and annotation data. |
| load_coco_annotations | Loads COCO-format annotations from a JSON file. |
| generate_mask | Generates a PNG mask from COCO-style polygon annotations. |
| convert_coco_to_png_masks | Converts COCO annotations to PNG masks and organizes them in folders. |
| normalize_filenames | Normalizes filenames in a list of directories to avoid naming conflicts. |
| parse_and_log_sam2_output | Parses SAM2 training output and logs metrics into the Picsellia experiment. |
Sam2Trainer(model, dataset_collection, context, sam2_repo_path)
Trainer class for managing the full fine-tuning process of a SAM2 model using a COCO dataset. This class handles data preparation, training launch, and checkpoint saving.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Model | Picsellia model instance containing paths and metadata. | required |
| dataset_collection | DatasetCollection[CocoDataset] | Dataset collection containing the training data. | required |
| context | PicselliaTrainingContext \| LocalTrainingContext | Training context containing hyperparameters and working directory. | required |
| sam2_repo_path | str | Path to the local SAM2 repository. | required |
Methods:

| Name | Description |
|---|---|
| prepare_data | Prepares the training data by converting COCO annotations to PNG masks. |
| launch_training | Launches the SAM2 training process. |
| save_checkpoint | Saves the final model checkpoint as an artifact in the Picsellia experiment. |
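Attributes:

| Name | Type | Description |
|---|---|---|
| model | Model | Picsellia model instance passed to the constructor. |
| dataset_collection | DatasetCollection[CocoDataset] | Dataset collection passed to the constructor. |
| context | PicselliaTrainingContext \| LocalTrainingContext | Training context passed to the constructor. |
| sam2_repo_path | str | Path to the local SAM2 repository. |
| img_root | str | Destination directory for training images: sam2_repo_path/data/JPEGImages. |
| ann_root | str | Destination directory for annotation masks: sam2_repo_path/data/Annotations. |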
model = model (instance-attribute)

dataset_collection = dataset_collection (instance-attribute)

context = context (instance-attribute)

sam2_repo_path = sam2_repo_path (instance-attribute)

img_root = os.path.join(sam2_repo_path, 'data', 'JPEGImages') (instance-attribute)

ann_root = os.path.join(sam2_repo_path, 'data', 'Annotations') (instance-attribute)
prepare_data()

Prepares the training data by converting COCO annotations to PNG masks.

Returns:

| Type | Description |
|---|---|
| str | The filename of the pretrained weights. |
prepare_directories(img_root, ann_root)

Creates required directories for image and annotation data.
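A minimal sketch of what this helper could look like, assuming it only needs to create missing folders (whether the real implementation also clears data from a previous run is not documented here). It uses os.makedirs with exist_ok=True so repeated calls are harmless; the usage mirrors the img_root and ann_root attributes inside a temporary directory.

```python
import os
import tempfile


def prepare_directories(img_root: str, ann_root: str) -> None:
    """Create the image and annotation roots (and any missing parents)."""
    for path in (img_root, ann_root):
        # exist_ok=True makes the call idempotent: existing folders are left alone.
        os.makedirs(path, exist_ok=True)


# Usage: mirror the img_root / ann_root attributes inside a throwaway "repo".
with tempfile.TemporaryDirectory() as repo:
    img_root = os.path.join(repo, "data", "JPEGImages")
    ann_root = os.path.join(repo, "data", "Annotations")
    prepare_directories(img_root, ann_root)
    prepare_directories(img_root, ann_root)  # second call is a no-op
    print(os.path.isdir(img_root), os.path.isdir(ann_root))  # True True
```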
load_coco_annotations(coco_path)

Loads COCO-format annotations from a JSON file.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| coco_path | str | Path to the COCO annotations file. | required |

Returns:

| Type | Description |
|---|---|
| dict[str, Any] | Parsed JSON dictionary. |
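Loading a COCO file amounts to parsing JSON. The sketch below assumes a UTF-8 file and uses only json.load; the tiny sample document (one image, one category, no annotations) is illustrative and shows the three top-level keys the later helpers rely on.

```python
import json
import os
import tempfile
from typing import Any


def load_coco_annotations(coco_path: str) -> dict[str, Any]:
    """Read a COCO annotation file and return the parsed dictionary."""
    with open(coco_path, encoding="utf-8") as f:
        return json.load(f)


# Usage: round-trip a tiny COCO document through a temporary file.
sample = {
    "images": [{"id": 1, "file_name": "img_001.jpg", "width": 4, "height": 3}],
    "annotations": [],
    "categories": [{"id": 1, "name": "object"}],
}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "annotations.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(sample, f)
    coco = load_coco_annotations(path)

print(sorted(coco))  # ['annotations', 'categories', 'images']
```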
generate_mask(width, height, annotations)

Generates a PNG mask from COCO-style polygon annotations.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| width | int | Width of the mask. | required |
| height | int | Height of the mask. | required |
| annotations | list | List of annotation objects. | required |

Returns:

| Type | Description |
|---|---|
| Image.Image | The generated mask image. |
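To illustrate what rasterizing COCO polygons involves, here is a stdlib-only sketch that returns a nested list instead of a PIL image. It assumes each annotation's "segmentation" is a list of flat [x1, y1, x2, y2, ...] polygons and, hypothetically, writes the annotation's 1-based index as the pixel value; the real function may instead use a category id or a single foreground value, and likely draws with Pillow (for example ImageDraw.Draw(img).polygon(points, fill=value)), whose edge handling differs slightly from the pixel-centre sampling used here. The helper names rasterize_coco_polygons and _point_in_polygon are illustrative only.

```python
from typing import Any


def _point_in_polygon(px: float, py: float, xs: list[float], ys: list[float]) -> bool:
    """Even-odd rule: count polygon edges crossed by a ray going right from (px, py)."""
    inside = False
    j = len(xs) - 1
    for i in range(len(xs)):
        # Only edges that straddle the horizontal line y = py can be crossed.
        if (ys[i] > py) != (ys[j] > py):
            x_cross = xs[i] + (py - ys[i]) * (xs[j] - xs[i]) / (ys[j] - ys[i])
            if px < x_cross:
                inside = not inside
        j = i
    return inside


def rasterize_coco_polygons(
    width: int, height: int, annotations: list[dict[str, Any]]
) -> list[list[int]]:
    """Return a height x width grid: 0 = background, k = k-th annotation (1-based)."""
    mask = [[0] * width for _ in range(height)]
    for instance_id, ann in enumerate(annotations, start=1):
        segmentation = ann.get("segmentation")
        if not isinstance(segmentation, list):
            continue  # RLE segmentations (dicts) are out of scope for this sketch
        for flat in segmentation:
            xs, ys = flat[0::2], flat[1::2]  # COCO polygons are [x1, y1, x2, y2, ...]
            if len(xs) < 3:
                continue  # fewer than 3 vertices cannot enclose an area
            for y in range(height):
                for x in range(width):
                    # Sample each pixel at its centre; later annotations overwrite earlier ones.
                    if _point_in_polygon(x + 0.5, y + 0.5, xs, ys):
                        mask[y][x] = instance_id
    return mask


square = {"segmentation": [[0, 0, 2, 0, 2, 2, 0, 2]]}
for row in rasterize_coco_polygons(4, 3, [square]):
    print(row)
# [1, 1, 0, 0]
# [1, 1, 0, 0]
# [0, 0, 0, 0]
```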
convert_coco_to_png_masks(coco, source_images, img_root, ann_root)

Converts COCO annotations to PNG masks and organizes them in folders.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| coco | dict | Loaded COCO annotations. | required |
| source_images | str | Directory containing original image files. | required |
| img_root | str | Destination directory for images. | required |
| ann_root | str | Destination directory for annotations. | required |
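The "organizes them in folders" step can be sketched as a planning pass that groups annotations by image_id and computes destination paths. The JPEGImages/Annotations names appear to follow a DAVIS-style video layout; the one-folder-per-image scheme with a single frame named 00000 is a hypothetical assumption, as is the helper name plan_mask_layout. Actually copying images and writing masks (for example with a rasterizer like the one sketched for generate_mask) is left out.

```python
import os
from collections import defaultdict
from typing import Any


def plan_mask_layout(
    coco: dict[str, Any], source_images: str, img_root: str, ann_root: str
) -> list[dict[str, Any]]:
    """Map every COCO image to where its copy and its PNG mask would be written."""
    # Group annotations by image id so each mask can be generated in a single pass.
    anns_by_image: dict[int, list[dict[str, Any]]] = defaultdict(list)
    for ann in coco.get("annotations", []):
        anns_by_image[ann["image_id"]].append(ann)

    plan = []
    for image in coco.get("images", []):
        stem = os.path.splitext(image["file_name"])[0]
        plan.append({
            "source": os.path.join(source_images, image["file_name"]),
            # Assumed layout: one folder per image, holding a single frame "00000".
            "image_dst": os.path.join(img_root, stem, "00000.jpg"),
            "mask_dst": os.path.join(ann_root, stem, "00000.png"),
            "size": (image["width"], image["height"]),
            # .get() does not insert into the defaultdict; images without labels get [].
            "annotations": anns_by_image.get(image["id"], []),
        })
    return plan


coco = {
    "images": [
        {"id": 1, "file_name": "cat.jpg", "width": 4, "height": 3},
        {"id": 2, "file_name": "dog.jpg", "width": 4, "height": 3},
    ],
    "annotations": [{"id": 10, "image_id": 1, "segmentation": [[0, 0, 2, 0, 2, 2, 0, 2]]}],
}
plan = plan_mask_layout(
    coco, "src", os.path.join("data", "JPEGImages"), os.path.join("data", "Annotations")
)
for entry in plan:
    print(entry["mask_dst"], len(entry["annotations"]))
```

Separating planning from file I/O keeps the path logic easy to check before anything is written into the SAM2 repository.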
normalize_filenames(root_dirs)

Normalizes filenames in a list of directories to avoid naming conflicts.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| root_dirs | list[str] | List of directory paths. | required |
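What "normalize" means exactly is not specified on this page. The sketch below assumes it means replacing characters outside [A-Za-z0-9_-] in each entry's stem with "_" and adding a numeric suffix when two names would collide. Names that are already clean are reserved first, so a sanitized name can never overwrite an existing file; the helper _clean and the returned {old_path: new_path} mapping are illustrative additions.

```python
import os
import re
import tempfile


def _clean(name: str) -> str:
    """Replace characters outside [A-Za-z0-9_-] in the stem with '_' (extension kept)."""
    stem, ext = os.path.splitext(name)
    return re.sub(r"[^A-Za-z0-9_-]", "_", stem) + ext


def normalize_filenames(root_dirs: list[str]) -> dict[str, str]:
    """Sanitize entry names in each directory without overwriting anything.

    Returns a mapping {old_path: new_path} for every renamed entry.
    """
    renamed: dict[str, str] = {}
    for root in root_dirs:
        names = sorted(os.listdir(root))
        # Already-clean names keep their name and are reserved up front.
        taken = {n for n in names if _clean(n) == n}
        for name in names:
            if name in taken:
                continue
            stem, ext = os.path.splitext(_clean(name))
            candidate, counter = stem + ext, 1
            while candidate in taken:  # resolve collisions with a numeric suffix
                candidate = f"{stem}_{counter}{ext}"
                counter += 1
            taken.add(candidate)
            old, new = os.path.join(root, name), os.path.join(root, candidate)
            os.rename(old, new)
            renamed[old] = new
    return renamed


with tempfile.TemporaryDirectory() as d:
    for n in ("a b.jpg", "a_b.jpg", "ok.png"):
        open(os.path.join(d, n), "w").close()
    normalize_filenames([d])
    print(sorted(os.listdir(d)))  # ['a_b.jpg', 'a_b_1.jpg', 'ok.png']
```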
parse_and_log_sam2_output(process, context, log_file_path)

Parses SAM2 training output and logs metrics into the Picsellia experiment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| process | Popen[str] | Subprocess running the training script. | required |
| context | PicselliaTrainingContext \| LocalTrainingContext | Picsellia pipeline context used for logging. | required |
| log_file_path | str | File to store raw stdout logs from the training process. | required |
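The streaming pattern behind this function can be sketched with the standard library: iterate over process.stdout line by line, write each raw line to the log file, and forward any metrics found. Because the Picsellia context's logging API is not documented on this page, the sketch swaps in a plain log_metric callback, and the "name: number" regex is a hypothetical stand-in for the real SAM2 log format. The function name parse_and_log_training_output is illustrative. Reading line by line lets metrics be logged while training is still running rather than after it finishes.

```python
import os
import re
import subprocess
import sys
import tempfile
from typing import Callable

# Hypothetical metric pattern: "name: number" pairs such as "loss: 0.42".
# The actual SAM2 log format is not documented here; adapt as needed.
METRIC_PATTERN = re.compile(r"([A-Za-z_][\w/]*)\s*:\s*(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")


def parse_and_log_training_output(
    process: "subprocess.Popen[str]",
    log_metric: Callable[[str, float], None],
    log_file_path: str,
) -> int:
    """Stream stdout line by line, tee it to a file, and forward parsed metrics."""
    if process.stdout is None:
        raise ValueError("start the process with stdout=subprocess.PIPE and text=True")
    with open(log_file_path, "w", encoding="utf-8") as log_file:
        for line in process.stdout:  # yields lines as the child prints them
            log_file.write(line)  # keep the raw output for later inspection
            for name, value in METRIC_PATTERN.findall(line):
                log_metric(name, float(value))
    return process.wait()  # exit code of the training process


# Usage with a stand-in "training script" that prints one metrics line.
collected: list[tuple[str, float]] = []
child = subprocess.Popen(
    [sys.executable, "-c", "print('epoch: 1 loss: 0.5'); print('done')"],
    stdout=subprocess.PIPE,
    text=True,
)
with tempfile.TemporaryDirectory() as tmp:
    log_path = os.path.join(tmp, "train_stdout.log")
    exit_code = parse_and_log_training_output(
        child, lambda n, v: collected.append((n, v)), log_path
    )
    with open(log_path, encoding="utf-8") as f:
        raw_log = f.read()

print(exit_code, collected)  # 0 [('epoch', 1.0), ('loss', 0.5)]
```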