<< Back to posts

Python script to remove image backgrounds for free

Posted on June 15, 2024 • Tags: python huggingface transformers images background removal

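Here is a quick script to remove the background of all JPEG and PNG images in a folder for free. It uses the open-source briaai/RMBG-1.4 model from Hugging Face and runs on a CUDA GPU, an Apple Silicon GPU (MPS), or the CPU, whichever is available, so it works on both Mac and Linux computers. Check the model card's license terms before using it commercially.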

Install

pip install torch torchvision pillow numpy scikit-image huggingface_hub "transformers>=4.39.1" tqdm

Usage

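If your input images are in the folder PATH_TO_INPUT_DIR and you want to save the background-removed versions to PATH_TO_OUTPUT_DIR, run the command below. The --path_to_output_dir flag is optional; without it, the results are saved to a new folder named PATH_TO_INPUT_DIR_bg_removed.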

python bg_remover.py PATH_TO_INPUT_DIR --path_to_output_dir PATH_TO_OUTPUT_DIR

Script

Save this in a file called bg_remover.py:

from transformers import pipeline
import argparse
import torch
import os
import string
from tqdm import tqdm

def parse_args(argv=None):
  """Parse the command-line arguments of the background remover.

  Args:
    argv: Optional list of argument strings. When None (the default),
      argparse reads sys.argv[1:], so calling parse_args() from the
      command line behaves as before.

  Returns:
    argparse.Namespace with ``path_to_input_dir`` (str) and
    ``path_to_output_dir`` (str, or None when the flag is omitted).
  """
  parser = argparse.ArgumentParser(description="Remove background of images in folder")
  parser.add_argument("path_to_input_dir", type=str, help="Path to folder containing images")
  parser.add_argument(
    "--path_to_output_dir",
    type=str,
    default=None,
    # The caller substitutes "<input dir>_bg_removed" when this is None.
    help="Path to folder where the background-removed images are saved "
         "(default: PATH_TO_INPUT_DIR_bg_removed)",
  )
  return parser.parse_args(argv)

if __name__ == '__main__':
  args = parse_args()
  path_to_input_dir: str = args.path_to_input_dir
  path_to_output_dir: str = args.path_to_output_dir if args.path_to_output_dir else f"{path_to_input_dir}_bg_removed"
  os.makedirs(path_to_output_dir, exist_ok=True)

  # Load the model
  device: str = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
  pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True, device=device)
  printable = set(string.printable)
  
  print("--------")
  print("Using device:", device)
  print("Loading images from:", path_to_input_dir)
  print("Saving images to:", path_to_output_dir)
  print("--------")

  for file in tqdm(os.listdir(path_to_input_dir), desc='Looping through input directory', total=len(os.listdir(path_to_input_dir))):
    if not file.endswith(('.jpg', '.jpeg', '.png')):
      # Skip non-image files
      continue
    path_to_img: str = os.path.join(path_to_input_dir, file)
    pillow_image = pipe(path_to_img) # applies mask on input and returns a pillow image
    file_name: str = ''.join(filter(lambda x: x in printable, file))
    pillow_image.save(os.path.join(path_to_output_dir, file_name))

  print("Done!")

References