<< Back to posts
Python script to remove image backgrounds for free
Here is a quick script to remove the background of all images in a folder for free. This works on both Mac and Linux computers.
Install
pip install torch torchvision pillow numpy scikit-image huggingface_hub "transformers>=4.39.1" tqdm
Usage
If your input images are in a folder at the path PATH_TO_INPUT_DIR and you want to save background-removed versions of them to PATH_TO_OUTPUT_DIR, run the command below (--path_to_output_dir is optional; if omitted, it defaults to PATH_TO_INPUT_DIR_bg_removed):
python bg_remover.py PATH_TO_INPUT_DIR --path_to_output_dir PATH_TO_OUTPUT_DIR
Script
Save this in a file called bg_remover.py:
from transformers import pipeline
import argparse
import torch
import os
import string
from tqdm import tqdm
def parse_args():
    """Parse the command-line arguments for the background remover.

    Returns:
        argparse.Namespace with ``path_to_input_dir`` (required positional)
        and ``path_to_output_dir`` (``None`` when ``--path_to_output_dir``
        is not given; the caller then derives a default from the input dir).
    """
    parser = argparse.ArgumentParser(description="Remove background of images in folder")
    parser.add_argument("path_to_input_dir", type=str, help="Path to folder containing images")
    parser.add_argument(
        "--path_to_output_dir",
        type=str,
        default=None,
        help="Path to folder to save the background-removed images to "
        "(default: <path_to_input_dir>_bg_removed)",
    )
    return parser.parse_args()
# File extensions (compared case-insensitively) that the script processes.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
# Characters allowed in output file names; anything else is stripped.
_PRINTABLE = frozenset(string.printable)


def has_image_extension(file: str) -> bool:
    """Return True if *file* ends in one of IMAGE_EXTENSIONS, ignoring case."""
    return file.lower().endswith(IMAGE_EXTENSIONS)


def output_file_name(file: str, keep_ext: bool = False) -> str:
    """Return the PNG file name under which the result for *file* is saved.

    The output is always PNG. The pipeline applies the segmentation mask to
    the input image (as an alpha channel, per briaai/RMBG-1.4's pipeline
    code). JPEG cannot store alpha, and Pillow raises OSError when saving
    RGBA/LA images as JPEG.

    Non-printable characters are stripped from the stem. If that leaves
    nothing (e.g. an all non-ASCII name), the original stem is kept so that
    distinct inputs don't all collapse onto the same ".png" name.

    Args:
        file: input file name without a directory part.
        keep_ext: if True, fold the original extension into the stem
            ("photo.jpg" -> "photo_jpg.png"), to disambiguate inputs that
            differ only by extension.
    """
    stem, ext = os.path.splitext(file)
    if keep_ext:
        stem = f"{stem}_{ext.lstrip('.')}"
    cleaned = ''.join(ch for ch in stem if ch in _PRINTABLE)
    return f"{cleaned or stem}.png"


if __name__ == '__main__':
    args = parse_args()
    path_to_input_dir: str = args.path_to_input_dir
    # abspath drops a trailing separator, so "imgs/" defaults to the sibling
    # folder "imgs_bg_removed" rather than "imgs/_bg_removed" inside the input.
    path_to_output_dir: str = (
        args.path_to_output_dir
        if args.path_to_output_dir
        else f"{os.path.abspath(path_to_input_dir)}_bg_removed"
    )
    os.makedirs(path_to_output_dir, exist_ok=True)

    # Load the model: prefer CUDA, then Apple MPS, then CPU.
    device: str = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
    pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True, device=device)

    print("--------")
    print("Using device:", device)
    print("Loading images from:", path_to_input_dir)
    print("Saving images to:", path_to_output_dir)
    print("--------")

    # List the directory once. Sorting makes the processing order (and thus
    # collision handling) deterministic across runs.
    image_files = [
        name for name in sorted(os.listdir(path_to_input_dir))
        if has_image_extension(name) and os.path.isfile(os.path.join(path_to_input_dir, name))
    ]
    written = set()  # output names already produced in this run
    for file in tqdm(image_files, desc='Looping through input directory'):
        path_to_img: str = os.path.join(path_to_input_dir, file)
        file_name = output_file_name(file)
        if file_name in written:
            # e.g. "photo.jpg" and "photo.png" would both map to "photo.png".
            file_name = output_file_name(file, keep_ext=True)
        try:
            pillow_image = pipe(path_to_img)  # applies mask on input and returns a pillow image
            pillow_image.save(os.path.join(path_to_output_dir, file_name))
        except OSError as err:
            # I/O-level failure (unreadable input, failed save): report it and
            # keep processing the remaining images instead of aborting the batch.
            tqdm.write(f"Skipping {file}: {err}")
            continue
        written.add(file_name)
    print("Done!")