RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float8_e4m3fn
import torch
from diffusers.utils import load_image
# Use the local files for now
from pipeline_flux_controlnet import FluxControlNetPipeline
from controlnet_flux import FluxControlNetModel
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union_fp8 = 'ABDALLALSWAITI/FLUX.1-dev-ControlNet-Union-Pro-2.0-fp8'
# Load using the FP8 data type
controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union_fp8, torch_dtype=torch.float8_e4m3fn)
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=[controlnet], torch_dtype=torch.bfloat16) # use [] to enable multi-CNs
pipe.enable_model_cpu_offload()
# Replace with other conditions
control_image = load_image("./conds/canny.png")
width, height = control_image.size
prompt = "A young girl stands gracefully at the edge of a serene beach, her long, flowing hair gently tousled by the sea breeze. She wears a soft, pastel-colored dress that complements the tranquil blues and greens of the coastal scenery. The golden hues of the setting sun cast a warm glow on her face, highlighting her serene expression. The background features a vast, azure ocean with gentle waves lapping at the shore, surrounded by distant cliffs and a clear, cloudless sky. The composition emphasizes the girl's serene presence amidst the natural beauty, with a balanced blend of warm and cool tones."
image = pipe(
prompt,
control_image=[control_image, control_image], # try with different conds such as canny&depth, pose&depth
width=width,
height=height,
controlnet_conditioning_scale=[0.35, 0.35],
control_guidance_end=[0.8, 0.8],
num_inference_steps=30,
guidance_scale=3.5,
generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]
When I run the code above, the following error occurs. How can I solve it?
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float8_e4m3fn
The torch version is 2.6.0+cu124.