diff --git a/examples/research_projects/realfill/requirements.txt b/examples/research_projects/realfill/requirements.txt index bf14291f53..3827f0852a 100644 --- a/examples/research_projects/realfill/requirements.txt +++ b/examples/research_projects/realfill/requirements.txt @@ -3,7 +3,7 @@ accelerate==0.23.0 transformers==4.34.0 peft==0.5.0 torch==2.0.1 -torchvision==0.15.2 +torchvision>=0.16 ftfy==6.1.1 tensorboard==2.14.0 Jinja2==3.1.2 diff --git a/examples/research_projects/realfill/train_realfill.py b/examples/research_projects/realfill/train_realfill.py index 9d00a21b1a..1549d81305 100644 --- a/examples/research_projects/realfill/train_realfill.py +++ b/examples/research_projects/realfill/train_realfill.py @@ -450,10 +450,10 @@ class RealFillDataset(Dataset): self.transform = transforms_v2.Compose( [ + transforms_v2.ToImage(), transforms_v2.RandomResize(size, int(1.125 * size)), transforms_v2.RandomCrop(size), - transforms_v2.ToImageTensor(), - transforms_v2.ConvertImageDtype(), + transforms_v2.ToDtype(torch.float32, scale=True), transforms_v2.Normalize([0.5], [0.5]), ] )