Instructions to use YiYiXu/quant-block with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use YiYiXu/quant-block with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("YiYiXu/quant-block", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| from typing import List, Optional | |
| from diffusers.modular_pipelines import ( | |
| InputParam, | |
| OutputParam, | |
| ModularPipelineBlocks, | |
| PipelineState, | |
| ) | |
class QuantizationConfigBlock(ModularPipelineBlocks):
    """Block to create BitsAndBytes quantization config for model loading.

    Reads the quantization options from the pipeline state and builds a
    ``diffusers.BitsAndBytesConfig`` for the selected ``quant_type``
    ("bnb_4bit" or "bnb_8bit"). The result is published as the intermediate
    output ``quantization_config``: a dict mapping the target component name
    (e.g. ``"transformer"``) to the config object.
    """

    # NOTE: the ModularPipelineBlocks base class declares ``description``,
    # ``inputs`` and ``intermediate_outputs`` as properties, and the framework
    # reads them as attributes (``self.inputs``), so the overrides must be
    # properties too.
    @property
    def description(self) -> str:
        return "Creates a BitsAndBytes quantization config for loading models with reduced precision"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            # Target component
            InputParam(
                "component",
                type_hint=str,
                default="transformer",
                description="Component name to apply quantization to",
                metadata={"mellon": "dropdown"},
            ),
            # Bits selection
            InputParam(
                "quant_type",
                type_hint=str,
                default="bnb_4bit",
                description="Quantization backend Type",
                metadata={"mellon": "dropdown"},  # "options": ["bnb_4bit", "bnb_8bit"]
            ),
            # ===== 4-bit options =====
            InputParam(
                "bnb_4bit_quant_type",
                type_hint=str,
                default="nf4",
                description="4-bit quantization type",
                metadata={"mellon": "dropdown"},  # "options": ["nf4", "fp4"]
            ),
            InputParam(
                "bnb_4bit_compute_dtype",
                type_hint=Optional[str],
                description="Compute dtype for 4-bit quantization",
                # "" (or unset) leaves the compute dtype to the library default.
                metadata={"mellon": "dropdown"},  # "options": ["", "float32", "float16", "bfloat16"]
            ),
            InputParam(
                "bnb_4bit_use_double_quant",
                type_hint=bool,
                default=False,
                description="Use nested quantization (quantize the quantization constants)",
                metadata={"mellon": "checkbox"},
            ),
            # ===== 8-bit options =====
            InputParam(
                "llm_int8_threshold",
                type_hint=float,
                default=6.0,
                description="Outlier threshold for 8-bit quantization (values above this use fp16)",
                metadata={"mellon": "slider"},
            ),
            InputParam(
                "llm_int8_has_fp16_weight",
                type_hint=bool,
                default=False,
                description="Keep weights in fp16 for 8-bit (useful for fine-tuning)",
                metadata={"mellon": "checkbox"},
            ),
            # Forwarded to BitsAndBytesConfig for both the 4-bit and 8-bit paths.
            InputParam(
                "llm_int8_skip_modules",
                type_hint=Optional[List[str]],
                description="Names of modules to leave unquantized",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "quantization_config",
                type_hint=dict,
                description="Quantization config dict for load_components",
            ),
        ]

    @staticmethod
    def _str_to_dtype(dtype_str):
        """Convert a dtype name coming from the UI into a ``torch.dtype``.

        Args:
            dtype_str: ``None``/``""`` (meaning "not set"), a ``torch.dtype``
                (returned unchanged), or one of the names in the table below.

        Returns:
            The matching ``torch.dtype``, or ``None`` when no dtype was chosen.

        Raises:
            ValueError: if ``dtype_str`` is a name that is not recognised.
                Previously such typos silently became ``None``.
        """
        import torch

        if dtype_str is None or dtype_str == "":
            return None
        if isinstance(dtype_str, torch.dtype):
            return dtype_str

        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "uint8": torch.uint8,
            "int8": torch.int8,
            "float64": torch.float64,
        }
        try:
            return dtype_map[dtype_str]
        except KeyError:
            raise ValueError(
                f"Unsupported dtype {dtype_str!r}; expected one of {sorted(dtype_map)} or ''"
            ) from None

    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        """Build the BitsAndBytes config and store it in the pipeline state.

        Args:
            pipeline: Passed through unchanged; returned as the first element.
            state: Pipeline state holding the inputs declared in ``inputs``.

        Returns:
            The ``(pipeline, state)`` tuple, with ``quantization_config`` set on
            the block state to ``{component: BitsAndBytesConfig(...)}``.

        Raises:
            ValueError: if ``quant_type`` is not "bnb_4bit" or "bnb_8bit", or
                if ``bnb_4bit_compute_dtype`` names an unknown dtype.
        """
        # Imported lazily so the block definition can be loaded without
        # pulling in the quantization backend.
        from diffusers import BitsAndBytesConfig

        block_state = self.get_block_state(state)
        quant_type = block_state.quant_type

        if quant_type == "bnb_4bit":
            config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=block_state.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=self._str_to_dtype(block_state.bnb_4bit_compute_dtype),
                bnb_4bit_use_double_quant=block_state.bnb_4bit_use_double_quant,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        elif quant_type == "bnb_8bit":
            config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=block_state.llm_int8_threshold,
                llm_int8_has_fp16_weight=block_state.llm_int8_has_fp16_weight,
                llm_int8_skip_modules=block_state.llm_int8_skip_modules,
            )
        else:
            # Without this branch `config` would be unbound below and the user
            # would see an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported quant_type {quant_type!r}; expected 'bnb_4bit' or 'bnb_8bit'"
            )

        # Output as dict, e.g. {"transformer": config}
        block_state.quantization_config = {block_state.component: config}
        self.set_block_state(state, block_state)
        return pipeline, state