def run_safety_checker(image):
    """Run the NSFW safety checker over a batch of PIL images.

    Returns the checker's (possibly modified) images plus a per-image
    list of booleans flagging NSFW content.
    """
    # CLIP features for the checker, computed on the GPU in fp16.
    checker_inputs = feature_extractor(image, return_tensors="pt").to(
        "cuda"
    )
    as_arrays = [np.array(img) for img in image]
    checked_images, has_nsfw_concept = safety_checker(
        images=as_arrays,
        clip_input=checker_inputs.pixel_values.to(torch.float16),
    )
    return checked_images, has_nsfw_concept
def predict(
    prompt: str = "A superhero smiling",
    negative_prompt: str = "worst quality, low quality",
    width: int = 1024,
    height: int = 1024,
    num_outputs: int = 1,
    scheduler: str = "K_EULER",
    num_inference_steps: int = 4,
    guidance_scale: float = 0,
    seed: int = None,
    disable_safety_checker: bool = False,
):
    """Run a single prediction on the model.

    Generates ``num_outputs`` images with the global ``pipe``, optionally
    filters NSFW results via ``run_safety_checker``, and returns the
    saved PNG paths.

    Raises:
        Exception: if every generated image was filtered out as NSFW
            (or the pipeline produced no images at all).
    """
    global pipe
    if seed is None:
        # Draw a random 32-bit seed so each run is logged and reproducible.
        seed = int.from_bytes(os.urandom(4), "big")
    print(f"Using seed: {seed}")
    generator = torch.Generator("cuda").manual_seed(seed)

    # OOMs can leave vae in bad state — force it back to fp16.
    if pipe.vae.dtype == torch.float32:
        pipe.vae.to(dtype=torch.float16)

    sdxl_kwargs = {}
    print(f"Prompt: {prompt}")
    sdxl_kwargs["width"] = width
    sdxl_kwargs["height"] = height

    # "trailing" timestep spacing is required for few-step inference.
    pipe.scheduler = SCHEDULERS[scheduler].from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )

    common_args = {
        "prompt": [prompt] * num_outputs,
        "negative_prompt": [negative_prompt] * num_outputs,
        "guidance_scale": guidance_scale,
        "generator": generator,
        "num_inference_steps": num_inference_steps,
    }

    output = pipe(**common_args, **sdxl_kwargs)

    # Default to "no NSFW" so the filter loop below is always well-defined;
    # the original left has_nsfw_content unbound when the checker was off.
    has_nsfw_content = [False] * len(output.images)
    if not disable_safety_checker:
        _, has_nsfw_content = run_safety_checker(output.images)

    output_paths = []
    for i, image in enumerate(output.images):
        if has_nsfw_content[i]:
            print(f"NSFW content detected in image {i}")
            continue
        output_path = f"/tmp/out-{i}.png"
        image.save(output_path)
        output_paths.append(Path(output_path))

    if len(output_paths) == 0:
        raise Exception(
            "NSFW content detected. Try running it again, or try a different prompt."
        )
    return output_paths