from fastai.text.all import *
from fastai.vision.all import *
# !pip install albumentations
from albumentations import ShiftScaleRotate
source = untar_data(URLs.PETS)
items = get_image_files(source/"images")
class AlbumentationsTransform(Transform):
    "Apply a single Albumentations augmentation to a `PILImage`"
    def __init__(self, aug): self.aug = aug
    def encodes(self, img: PILImage):
        aug_img = self.aug(image=np.array(img))['image']
        return PILImage.create(aug_img)
tfm = AlbumentationsTransform(ShiftScaleRotate(p=1))
img = PILImage.create(items[0])  # grab one pet image to test the transform on
a,b = tfm((img, 'dog'))
show_image(a, title=b);
cv_source = untar_data(URLs.CAMVID_TINY)
cv_items = get_image_files(cv_source/'images')
cv_splitter = RandomSplitter(seed=42)
cv_split = cv_splitter(cv_items)
cv_label = lambda o: cv_source/'labels'/f'{o.stem}_P{o.suffix}'
class ImageResizer(Transform):
    order=1
    "Resize image to `size` using `resample`"
    def __init__(self, size, resample=BILINEAR):
        if not is_listy(size): size=(size,size)
        self.size,self.resample = (size[1],size[0]),resample

    def encodes(self, o:PILImage): return o.resize(size=self.size, resample=self.resample)
    def encodes(self, o:PILMask):  return o.resize(size=self.size, resample=NEAREST)
tfms = [[PILImage.create], [cv_label, PILMask.create]]
cv_dsets = Datasets(cv_items, tfms, splits=cv_split)
dls = cv_dsets.dataloaders(bs=64, after_item=[ImageResizer(128), ToTensor(), IntToFloatTensor()])
# split_idx=0 restricts the transform to the training set (split_idx=1 would target the validation set); see the check after show_batch below:
class SegmentationAlbumentationsTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img,mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
cv_dsets = Datasets(cv_items, tfms, splits=cv_split)
dls = cv_dsets.dataloaders(bs=64, after_item=[ImageResizer(128), ToTensor(), IntToFloatTensor(),
                                              SegmentationAlbumentationsTransform(ShiftScaleRotate(p=1))])
dls.show_batch(max_n=4)
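Because split_idx is 0, the ShiftScaleRotate only fires on the training DataLoader. A quick, optional sanity check (an assumption added here, not part of the original) is to show a batch from each split and compare:
dls.train.show_batch(max_n=4)  # augmented images and masks
dls.valid.show_batch(max_n=4)  # validation items come through with only the resize applied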
class AlbumentationsTransform(RandTransform):
    "A transform handler for multiple `Albumentations` transforms"
    split_idx,order=None,2
    def __init__(self, train_aug, valid_aug): store_attr()

    def before_call(self, b, split_idx):
        self.idx = split_idx

    def encodes(self, img: PILImage):
        if self.idx == 0:
            aug_img = self.train_aug(image=np.array(img))['image']
        else:
            aug_img = self.valid_aug(image=np.array(img))['image']
        return PILImage.create(aug_img)
We changed our split_idx to None, which lets us decide at call time which split the transform is being applied to. We also inherit from RandTransform, whose before_call hands us the current split_idx, so we store it and use it in encodes to pick train_aug for the training set and valid_aug for the validation set.
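As a minimal usage sketch (the augmentation choices, crop size, and filename regex below are illustrative assumptions, not fixed choices), the two pipelines can be plugged into the PETS data like this; because order=2, the transform runs after fastai's Resize item transform:
import albumentations as A

def get_train_aug(): return A.Compose([A.ShiftScaleRotate(p=0.5), A.HorizontalFlip(p=0.5)])
def get_valid_aug(): return A.Compose([A.CenterCrop(224, 224, p=1.)])

item_tfms = [Resize(256), AlbumentationsTransform(get_train_aug(), get_valid_aug())]
dls = ImageDataLoaders.from_name_re(source, items, pat=r'(.+)_\d+.jpg$',
                                    item_tfms=item_tfms, bs=64)
dls.train.show_batch(max_n=4)  # train_aug applied
dls.valid.show_batch(max_n=4)  # valid_aug applied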