diff --git a/gpu/convert_checkpoint.py b/gpu/convert_checkpoint.py
index 797ad1dbb..a96bff5f1 100755
--- a/gpu/convert_checkpoint.py
+++ b/gpu/convert_checkpoint.py
@@ -1,100 +1,108 @@
-import json
-import os
-import re
-import sys
-from pathlib import Path
-from typing import Optional
-from dataclasses import dataclass
-import torch
-from einops import rearrange
-from safetensors.torch import save_file
-import model
-from pack_weight import convert_weight_int8_to_int2
-
-@torch.inference_mode()
-def convert_ts_checkpoint(
-    *,
-    input_path: str = "",
-) -> None:
-
-    config = model.ModelArgs()
-    print(f"Model config {config.__dict__}")
-
-    def quant_weight_int8(weight):
-        s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
-        new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8)
-        new_scale = (1.0 / s).to(torch.bfloat16)
-        return new_weight, new_scale.reshape(1)
-
-    def quant_weight_fp16(weight):
-        s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
-        new_weight = (weight * s).round().clamp(-1, 1) / s
-        return new_weight
-
-    def convert_int8_to_int2(weight):
-        return convert_weight_int8_to_int2(weight)
-
-    merged_result = torch.load(input_path, map_location="cpu", mmap=True)
-    int2_result = {}
-    fp16_result = {}
-    zero = torch.zeros(1).to(torch.bfloat16)
-    for key, value in merged_result.items():
-        if 'wqkv' in key:
-            wq = value[:config.dim]
-            wk = value[config.dim:config.dim // config.n_heads * config.n_kv_heads + config.dim]
-            wv = value[config.dim // config.n_heads * config.n_kv_heads + config.dim:]
-            wq_weight, wa_scale = quant_weight_int8(wq)
-            wk_weight, wb_scale = quant_weight_int8(wk)
-            wv_weight, wc_scale = quant_weight_int8(wv)
-            wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
-            wqkv_scale = torch.cat([wa_scale, wb_scale, wc_scale, zero], dim=0)
-            int2_result[key] = convert_int8_to_int2(wqkv_weight)
-            int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale
-
-            wq_weight = quant_weight_fp16(wq)
-            wk_weight = quant_weight_fp16(wk)
-            wv_weight = quant_weight_fp16(wv)
-            wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
-            fp16_result[key] = wqkv_weight
-        elif 'w13' in key:
-            w1 = value[:config.ffn_dim]
-            w3 = value[config.ffn_dim:]
-            w1_weight, w1_scale = quant_weight_int8(w1)
-            w3_weight, w3_scale = quant_weight_int8(w3)
-            w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
-            w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0)
-            int2_result[key] = convert_int8_to_int2(w13_weight)
-            int2_result[key.replace('weight', 'weight_scale')] = w13_scale
-
-            w1_weight = quant_weight_fp16(w1)
-            w3_weight = quant_weight_fp16(w3)
-            w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
-            fp16_result[key] = w13_weight
-        elif 'w2' in key or 'wo' in key:
-            weight, scale = quant_weight_int8(value)
-            scale = torch.cat([scale, zero, zero, zero], dim=0)
-            int2_result[key] = convert_int8_to_int2(weight)
-            int2_result[key.replace('weight', 'weight_scale')] = scale
-
-            weight = quant_weight_fp16(value)
-            fp16_result[key] = weight
-        else:
-            int2_result[key] = value.clone()
-            fp16_result[key] = value.clone()
-
-    output_dir = os.path.dirname(input_path)
-    print(f"Saving checkpoint to {output_dir}/model_state_int2.pt")
-    torch.save(int2_result, f"{output_dir}/model_state_int2.pt")
-
-    print(f"Saving checkpoint to {output_dir}/model_state_fp16.pt")
-    torch.save(fp16_result, f"{output_dir}/model_state_fp16.pt")
-
-if __name__ == '__main__':
-    import argparse
-    parser = argparse.ArgumentParser(description='Convert TorchScale checkpoint.')
-    parser.add_argument('--input', type=str)
-
-    args = parser.parse_args()
-    convert_ts_checkpoint(
-        input_path=args.input,
-    )
+import json
+import os
+import re
+import sys
+from pathlib import Path
+from typing import Optional
+from dataclasses import dataclass
+import torch
+from einops import rearrange
+from safetensors.torch import save_file
+import model
+from pack_weight import convert_weight_int8_to_int2
+
+@torch.inference_mode()
+def convert_ts_checkpoint(
+    *,
+    input_path: str = "",
+) -> None:
+
+    config = model.ModelArgs()
+    print(f"Model config {config.__dict__}")
+
+    def quant_weight_int8(weight):
+        s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
+        new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8)
+        new_scale = (1.0 / s).to(torch.bfloat16)
+        return new_weight, new_scale.reshape(1)
+
+    def quant_weight_fp16(weight):
+        s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
+        new_weight = (weight * s).round().clamp(-1, 1) / s
+        return new_weight
+
+    def convert_int8_to_int2(weight):
+        return convert_weight_int8_to_int2(weight)
+
+    merged_result = torch.load(input_path, map_location="cpu", weights_only=True)
+    int2_result = {}
+    fp16_result = {}
+    zero = torch.zeros(1).to(torch.bfloat16)
+    for key, value in merged_result.items():
+        if 'wqkv' in key:
+            wq = value[:config.dim]
+            wk = value[config.dim:config.dim // config.n_heads * config.n_kv_heads + config.dim]
+            wv = value[config.dim // config.n_heads * config.n_kv_heads + config.dim:]
+            wq_weight, wa_scale = quant_weight_int8(wq)
+            wk_weight, wb_scale = quant_weight_int8(wk)
+            wv_weight, wc_scale = quant_weight_int8(wv)
+            wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
+            wqkv_scale = torch.cat([wa_scale, wb_scale, wc_scale, zero], dim=0)
+            int2_result[key] = convert_int8_to_int2(wqkv_weight)
+            int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale
+
+            wq_weight = quant_weight_fp16(wq)
+            wk_weight = quant_weight_fp16(wk)
+            wv_weight = quant_weight_fp16(wv)
+            wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
+            fp16_result[key] = wqkv_weight
+        elif 'w13' in key:
+            w1 = value[:config.ffn_dim]
+            w3 = value[config.ffn_dim:]
+            w1_weight, w1_scale = quant_weight_int8(w1)
+            w3_weight, w3_scale = quant_weight_int8(w3)
+            w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
+            w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0)
+            int2_result[key] = convert_int8_to_int2(w13_weight)
+            int2_result[key.replace('weight', 'weight_scale')] = w13_scale
+
+            w1_weight = quant_weight_fp16(w1)
+            w3_weight = quant_weight_fp16(w3)
+            w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
+            fp16_result[key] = w13_weight
+        elif 'w2' in key or 'wo' in key:
+            weight, scale = quant_weight_int8(value)
+            scale = torch.cat([scale, zero, zero, zero], dim=0)
+            int2_result[key] = convert_int8_to_int2(weight)
+            int2_result[key.replace('weight', 'weight_scale')] = scale
+
+            weight = quant_weight_fp16(value)
+            fp16_result[key] = weight
+        else:
+            int2_result[key] = value.clone()
+            fp16_result[key] = value.clone()
+
+    output_dir = os.path.dirname(os.path.abspath(input_path))
+
+    # Validate that the output directory is within an expected base directory
+    base_dir = os.path.abspath(".")
+    if not os.path.abspath(output_dir).startswith(base_dir):
+        raise ValueError(f"Output directory '{output_dir}' is outside the base directory '{base_dir}'")
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    print(f"Saving checkpoint to {output_dir}/model_state_int2.pt")
+    torch.save(int2_result, f"{output_dir}/model_state_int2.pt")
+
+    print(f"Saving checkpoint to {output_dir}/model_state_fp16.pt")
+    torch.save(fp16_result, f"{output_dir}/model_state_fp16.pt")
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser(description='Convert TorchScale checkpoint.')
+    parser.add_argument('--input', type=str)
+
+    args = parser.parse_args()
+    convert_ts_checkpoint(
+        input_path=args.input,
+    )
diff --git a/gpu/generate.py b/gpu/generate.py
index 638ed7b37..62b70aac6 100755
--- a/gpu/generate.py
+++ b/gpu/generate.py
@@ -1,359 +1,359 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-import json
-import os
-import readline  # type: ignore # noqa
-import sys
-import time
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Iterable, Optional, Tuple, Union
-
-import fire
-import model as fast
-import torch
-from stats import Stats
-from tokenizer import Tokenizer, ChatFormat
-import sample_utils
-from xformers.ops.fmha.attn_bias import (
-    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
-)
-
-
-@dataclass
-class GenArgs:
-    gen_length: int = 32
-    gen_bsz: int = 1
-    prompt_length: int = 64
-
-    use_sampling: bool = False
-    temperature: float = 0.8
-    top_p: float = 0.9
-
-
-class FastGen:
-    GRAPH_WARMUPS: int = 1
-    tokenizer: Tokenizer
-
-    @staticmethod
-    def build(
-        ckpt_dir: str,
-        gen_args: GenArgs,
-        device: Union[torch.device, str],
-        tokenizer_path: Optional[str] = None,
-        num_layers: int = 13,
-        use_full_vocab: bool = False,
-    ) -> "FastGen":
-        """
-        Load a Llama or Code Llama checkpoint and return a new
-        generator for this model.
-        """
-        start_time = time.time()
-
-        model_args_prefill = fast.ModelArgs(use_kernel=False)
-        model_args_decode = fast.ModelArgs(use_kernel=True)
-        tokenizer = Tokenizer("./tokenizer.model")
-
-        torch.set_default_device(device)
-        torch.set_default_dtype(torch.bfloat16)
-
-        prefill_model = fast.Transformer(model_args_prefill)
-        decode_model = fast.Transformer(model_args_decode)
-
-        fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
-        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
-        int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
-        int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
-        prefill_model.load_state_dict(fp16_checkpoint, strict=True)
-        decode_model.load_state_dict(int2_checkpoint, strict=True)
-
-        torch.cuda.synchronize()
-        print(f"loaded model in {time.time() - start_time:.2f} seconds")
-        start_time = time.time()
-
-        return FastGen(gen_args, model_args_prefill, prefill_model, decode_model, tokenizer)
-
-    def __init__(
-        self,
-        args: GenArgs,
-        model_args: fast.ModelArgs,
-        prefill_model: fast.Transformer,
-        decode_model: fast.Transformer,
-        tokenizer: Tokenizer,
-    ):
-        self.gen_args = args
-        self.max_seq_length = args.prompt_length + args.gen_length
-        self.model_args = model_args
-        # self.model = model
-        self.prefill_model = prefill_model
-        self.decode_model = decode_model
-        self.tokenizer = tokenizer
-        self._prefill_cuda_graph, self._prefill_compile_model, self._prefill_inputs, self._prefill_logits = None, None, None, None
-        self._generate_cuda_graph, self._generate_compile_model, self._generate_inputs, self._generate_logits = None, None, None, None
-        self._cache = None
-        start_time = time.time()
-        self._prefill_compile_model = self.compile_prefill()
-        self._generate_compile_model = self.compile_generate()
-        print(f"compiled model in {time.time() - start_time:.2f} seconds")
-
-    def compile_prefill(self):
-
-        if self._cache is None:
-            self._cache = fast.make_cache(
-                args=self.model_args,
-                length=self.gen_args.gen_bsz * self.max_seq_length,
-            )
-
-        seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
-
-        bias = AttnBias.from_seqlens(
-            q_seqlen=seq_lens,
-            kv_seqlen=seq_lens,
-            kv_padding=self.max_seq_length,
-        )
-        bias.q_seqinfo.to("cuda")
-        bias.k_seqinfo.to("cuda")
-
-        tokens = torch.IntTensor([1] * self.gen_args.gen_bsz * self.gen_args.prompt_length).cuda()
-        self._prefill_inputs = (tokens, bias)
-
-        s = torch.cuda.Stream()
-        s.wait_stream(torch.cuda.current_stream())
-
-        with torch.cuda.stream(s):
-            _ = self.prefill_model.forward_with_attn_bias(
-                token_values=self._prefill_inputs[0],
-                attn_bias=self._prefill_inputs[1],
-                cache=self._cache,
-            )
-        torch.cuda.current_stream().wait_stream(s)
-
-        self._prefill_cuda_graph = torch.cuda.CUDAGraph()
-        recording_kwargs = {}
-        if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
-            # In PyTorch 2.1+ and nightlies from late Aug 2023,
-            # we can do this to maybe avoid watchdog-related crashes
-            recording_kwargs["capture_error_mode"] = "thread_local"
-        with torch.cuda.graph(self._prefill_cuda_graph, **recording_kwargs):
-            self._prefill_logits = self.prefill_model.forward_with_attn_bias(
-                token_values=self._prefill_inputs[0],
-                attn_bias=self._prefill_inputs[1],
-                cache=self._cache,
-            )
-
-        def replay(tokens, seq_lens=None):
-            self._prefill_inputs[0].copy_(tokens)
-            if seq_lens is not None:
-                self._prefill_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
-
-            self._prefill_cuda_graph.replay()
-            torch.cuda.synchronize()
-
-            return self._prefill_logits
-
-        return replay
-
-    def compile_generate(self):
-
-        if self._cache is None:
-            self._cache = fast.make_cache(
-                args=self.model_args,
-                length=self.gen_args.gen_bsz * self.max_seq_length,
-            )
-
-        seq_lens = [1 for _ in range(self.gen_args.gen_bsz)]
-        kv_seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
-
-        bias = AttnBias.from_seqlens(
-            q_seqlen=seq_lens,
-            kv_seqlen=kv_seq_lens,
-            kv_padding=self.max_seq_length,
-        )
-        bias.q_seqinfo.to("cuda")
-        bias.k_seqinfo.to("cuda")
-
-        tokens = torch.IntTensor([1] * self.gen_args.gen_bsz).cuda()
-        self._generate_inputs = (tokens, bias)
-
-        s = torch.cuda.Stream()
-        s.wait_stream(torch.cuda.current_stream())
-
-        with torch.cuda.stream(s):
-            _ = self.decode_model.forward_with_attn_bias(
-                token_values=self._generate_inputs[0],
-                attn_bias=self._generate_inputs[1],
-                cache=self._cache,
-            )
-        torch.cuda.current_stream().wait_stream(s)
-
-        self._generate_cuda_graph = torch.cuda.CUDAGraph()
-        recording_kwargs = {}
-        if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
-            # In PyTorch 2.1+ and nightlies from late Aug 2023,
-            # we can do this to maybe avoid watchdog-related crashes
-            recording_kwargs["capture_error_mode"] = "thread_local"
-        with torch.cuda.graph(self._generate_cuda_graph, **recording_kwargs):
-            self._generate_logits = self.decode_model.forward_with_attn_bias(
-                token_values=self._generate_inputs[0],
-                attn_bias=self._generate_inputs[1],
-                cache=self._cache,
-            )
-
-        def replay(tokens, seq_lens):
-            self._generate_inputs[0].copy_(tokens)
-            self._generate_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
-
-            self._generate_cuda_graph.replay()
-
-            return self._generate_logits
-
-        return replay
-
-
-    @torch.inference_mode()
-    def generate_all(
-        self, prompts: list[list[int]], use_cuda_graphs: bool, use_sampling: bool
-    ) -> Tuple[Stats, list[list[int]]]:
-        bs = len(prompts)
-        prompt_lens = [len(p) for p in prompts]
-        padded_prompt_lens = [self.gen_args.prompt_length] * bs
-        max_prompt_length = max(prompt_lens)
-        gen_length = self.gen_args.gen_length
-        max_seq_length = max_prompt_length + gen_length
-        print(max_prompt_length, gen_length)
-
-        bias = AttnBias.from_seqlens(
-            q_seqlen=padded_prompt_lens,
-            kv_seqlen=prompt_lens,
-            kv_padding=max_seq_length,
-        )
-        bias.q_seqinfo.to("cuda")
-        bias.k_seqinfo.to("cuda")
-
-        # Input tensors to the cuda graph
-        kv_seqlen = bias.k_seqinfo.seqlen
-        prompts = [prompt + [1] * (self.gen_args.prompt_length - len(prompt)) for prompt in prompts]
-        tokens = torch.IntTensor(sum(prompts, [])).cuda()
-        out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)
-
-        stats = Stats()
-        torch.cuda.synchronize()
-        stats.phase("prefill" if use_cuda_graphs else "total")
-        # stats.phase("total")
-
-        output = self._prefill_compile_model(tokens, None)
-
-        logits = output[kv_seqlen - 1, :]
-        logits = logits.view(bs, self.model_args.vocab_size)
-
-        if use_sampling:
-            temp = 0.7
-            top_p = 0.95
-            probs = torch.softmax(logits / temp, dim=-1)
-            next_token = sample_utils.top_p(probs, top_p)
-        else:
-            next_token = torch.argmax(logits, dim=-1)
-
-        next_token = next_token.reshape(bs)
-        out_tokens[0, :] = next_token
-
-        torch.cuda.synchronize()
-        stats.phase("decode" if use_cuda_graphs else "total")
-
-        eos_id = self.tokenizer.eot_id
-        for niter in range(1, gen_length):
-            kv_seqlen.add_(kv_seqlen < max_seq_length)
-            output = self._generate_compile_model(next_token, kv_seqlen)
-
-            logits = output.view(bs, self.model_args.vocab_size)
-
-            if use_sampling:
-                temp = 0.7
-                top_p = 0.95
-                probs = torch.softmax(logits / temp, dim=-1)
-                next_token = sample_utils.top_p(probs, top_p)
-            else:
-                next_token = torch.argmax(logits, dim=-1)
-
-            next_token = next_token.reshape(bs)
-            out_tokens[niter, :] = next_token
-
-            if next_token.eq(eos_id).any():
-                break
-
-        torch.cuda.synchronize()
-        stats.end_phase(tokens=niter * bs)
-
-        def trim_answer(prompt_len, tokens):
-            # print(prompt, tokens)
-            """Trim the answer to end it on an eos token."""
-            tokens = tokens[: max_seq_length - prompt_len]
-            eos_id = self.tokenizer.eot_id
-            if eos_id in tokens:
-                return tokens[: tokens.index(eos_id) + 1]
-            else:
-                return tokens
-
-        answers = [
-            trim_answer(prompt_len, answer)
-            for prompt_len, answer in zip(prompt_lens, out_tokens.t().tolist())
-        ]
-        return stats, answers
-
-
-def get_prompts(interactive: bool) -> Iterable[list[str]]:
-    if interactive:
-        while True:
-            try:
-                prompts = input("enter prompt: ").split("\n")
-            except EOFError:
-                print("exiting")
-                sys.exit(0)
-            yield prompts
-    else:
-        yield [
-            "Hello, my name is",
-        ]
-
-
-def main(ckpt_dir: str, interactive: bool = False, chat_format: bool = False, sampling: bool = False):
-
-    local_rank = 0
-    device = f"cuda:{local_rank}"
-    torch.cuda.set_device(local_rank)
-
-    g = FastGen.build(ckpt_dir, GenArgs(), device)
-
-    if chat_format:
-        g.tokenizer = ChatFormat(g.tokenizer)
-
-    for prompts in get_prompts(interactive):
-        # prompts = [f"{prompt}\n" for prompt in prompts]
-        if chat_format:
-            # prompts = [f'<|begin_of_text|>User: {prompt}<|eot_id|>Assistant: ' for prompt in prompts]
-            tokens = [g.tokenizer.encode_dialog_prompt(dialog=[{"role": "user", "content": prompt}], completion=True) for prompt in prompts]
-        else:
-            tokens = [g.tokenizer.encode(x, bos=False, eos=False) for x in prompts]
-
-        print(tokens)
-        stats, out_tokens = g.generate_all(
-            tokens, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ, use_sampling=sampling,
-        )
-
-        for i, prompt in enumerate(prompts):
-            print(f"> {prompt}")
-            answer = g.tokenizer.decode(out_tokens[i])
-            print(answer)
-            print("---------------")
-
-        for phase_stats in stats.phases:
-            print(phase_stats.show())
-
-        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-
-
-if __name__ == "__main__":
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import readline  # type: ignore # noqa
+import sys
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterable, Optional, Tuple, Union
+
+import fire
+import model as fast
+import torch
+from stats import Stats
+from tokenizer import Tokenizer, ChatFormat
+import sample_utils
+from xformers.ops.fmha.attn_bias import (
+    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
+)
+
+
+@dataclass
+class GenArgs:
+    gen_length: int = 32
+    gen_bsz: int = 1
+    prompt_length: int = 64
+
+    use_sampling: bool = False
+    temperature: float = 0.8
+    top_p: float = 0.9
+
+
+class FastGen:
+    GRAPH_WARMUPS: int = 1
+    tokenizer: Tokenizer
+
+    @staticmethod
+    def build(
+        ckpt_dir: str,
+        gen_args: GenArgs,
+        device: Union[torch.device, str],
+        tokenizer_path: Optional[str] = None,
+        num_layers: int = 13,
+        use_full_vocab: bool = False,
+    ) -> "FastGen":
+        """
+        Load a Llama or Code Llama checkpoint and return a new
+        generator for this model.
+        """
+        start_time = time.time()
+
+        model_args_prefill = fast.ModelArgs(use_kernel=False)
+        model_args_decode = fast.ModelArgs(use_kernel=True)
+        tokenizer = Tokenizer("./tokenizer.model")
+
+        torch.set_default_device(device)
+        torch.set_default_dtype(torch.bfloat16)
+
+        prefill_model = fast.Transformer(model_args_prefill)
+        decode_model = fast.Transformer(model_args_decode)
+
+        fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
+        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu", weights_only=True)
+        int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
+        int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu", weights_only=True)
+        prefill_model.load_state_dict(fp16_checkpoint, strict=True)
+        decode_model.load_state_dict(int2_checkpoint, strict=True)
+
+        torch.cuda.synchronize()
+        print(f"loaded model in {time.time() - start_time:.2f} seconds")
+        start_time = time.time()
+
+        return FastGen(gen_args, model_args_prefill, prefill_model, decode_model, tokenizer)
+
+    def __init__(
+        self,
+        args: GenArgs,
+        model_args: fast.ModelArgs,
+        prefill_model: fast.Transformer,
+        decode_model: fast.Transformer,
+        tokenizer: Tokenizer,
+    ):
+        self.gen_args = args
+        self.max_seq_length = args.prompt_length + args.gen_length
+        self.model_args = model_args
+        # self.model = model
+        self.prefill_model = prefill_model
+        self.decode_model = decode_model
+        self.tokenizer = tokenizer
+        self._prefill_cuda_graph, self._prefill_compile_model, self._prefill_inputs, self._prefill_logits = None, None, None, None
+        self._generate_cuda_graph, self._generate_compile_model, self._generate_inputs, self._generate_logits = None, None, None, None
+        self._cache = None
+        start_time = time.time()
+        self._prefill_compile_model = self.compile_prefill()
+        self._generate_compile_model = self.compile_generate()
+        print(f"compiled model in {time.time() - start_time:.2f} seconds")
+
+    def compile_prefill(self):
+
+        if self._cache is None:
+            self._cache = fast.make_cache(
+                args=self.model_args,
+                length=self.gen_args.gen_bsz * self.max_seq_length,
+            )
+
+        seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
+
+        bias = AttnBias.from_seqlens(
+            q_seqlen=seq_lens,
+            kv_seqlen=seq_lens,
+            kv_padding=self.max_seq_length,
+        )
+        bias.q_seqinfo.to("cuda")
+        bias.k_seqinfo.to("cuda")
+
+        tokens = torch.IntTensor([1] * self.gen_args.gen_bsz * self.gen_args.prompt_length).cuda()
+        self._prefill_inputs = (tokens, bias)
+
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+
+        with torch.cuda.stream(s):
+            _ = self.prefill_model.forward_with_attn_bias(
+                token_values=self._prefill_inputs[0],
+                attn_bias=self._prefill_inputs[1],
+                cache=self._cache,
+            )
+        torch.cuda.current_stream().wait_stream(s)
+
+        self._prefill_cuda_graph = torch.cuda.CUDAGraph()
+        recording_kwargs = {}
+        if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
+            # In PyTorch 2.1+ and nightlies from late Aug 2023,
+            # we can do this to maybe avoid watchdog-related crashes
+            recording_kwargs["capture_error_mode"] = "thread_local"
+        with torch.cuda.graph(self._prefill_cuda_graph, **recording_kwargs):
+            self._prefill_logits = self.prefill_model.forward_with_attn_bias(
+                token_values=self._prefill_inputs[0],
+                attn_bias=self._prefill_inputs[1],
+                cache=self._cache,
+            )
+
+        def replay(tokens, seq_lens=None):
+            self._prefill_inputs[0].copy_(tokens)
+            if seq_lens is not None:
+                self._prefill_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
+
+            self._prefill_cuda_graph.replay()
+            torch.cuda.synchronize()
+
+            return self._prefill_logits
+
+        return replay
+
+    def compile_generate(self):
+
+        if self._cache is None:
+            self._cache = fast.make_cache(
+                args=self.model_args,
+                length=self.gen_args.gen_bsz * self.max_seq_length,
+            )
+
+        seq_lens = [1 for _ in range(self.gen_args.gen_bsz)]
+        kv_seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
+
+        bias = AttnBias.from_seqlens(
+            q_seqlen=seq_lens,
+            kv_seqlen=kv_seq_lens,
+            kv_padding=self.max_seq_length,
+        )
+        bias.q_seqinfo.to("cuda")
+        bias.k_seqinfo.to("cuda")
+
+        tokens = torch.IntTensor([1] * self.gen_args.gen_bsz).cuda()
+        self._generate_inputs = (tokens, bias)
+
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+
+        with torch.cuda.stream(s):
+            _ = self.decode_model.forward_with_attn_bias(
+                token_values=self._generate_inputs[0],
+                attn_bias=self._generate_inputs[1],
+                cache=self._cache,
+            )
+        torch.cuda.current_stream().wait_stream(s)
+
+        self._generate_cuda_graph = torch.cuda.CUDAGraph()
+        recording_kwargs = {}
+        if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
+            # In PyTorch 2.1+ and nightlies from late Aug 2023,
+            # we can do this to maybe avoid watchdog-related crashes
+            recording_kwargs["capture_error_mode"] = "thread_local"
+        with torch.cuda.graph(self._generate_cuda_graph, **recording_kwargs):
+            self._generate_logits = self.decode_model.forward_with_attn_bias(
+                token_values=self._generate_inputs[0],
+                attn_bias=self._generate_inputs[1],
+                cache=self._cache,
+            )
+
+        def replay(tokens, seq_lens):
+            self._generate_inputs[0].copy_(tokens)
+            self._generate_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
+
+            self._generate_cuda_graph.replay()
+
+            return self._generate_logits
+
+        return replay
+
+
+    @torch.inference_mode()
+    def generate_all(
+        self, prompts: list[list[int]], use_cuda_graphs: bool, use_sampling: bool
+    ) -> Tuple[Stats, list[list[int]]]:
+        bs = len(prompts)
+        prompt_lens = [len(p) for p in prompts]
+        padded_prompt_lens = [self.gen_args.prompt_length] * bs
+        max_prompt_length = max(prompt_lens)
+        gen_length = self.gen_args.gen_length
+        max_seq_length = max_prompt_length + gen_length
+        print(max_prompt_length, gen_length)
+
+        bias = AttnBias.from_seqlens(
+            q_seqlen=padded_prompt_lens,
+            kv_seqlen=prompt_lens,
+            kv_padding=max_seq_length,
+        )
+        bias.q_seqinfo.to("cuda")
+        bias.k_seqinfo.to("cuda")
+
+        # Input tensors to the cuda graph
+        kv_seqlen = bias.k_seqinfo.seqlen
+        prompts = [prompt + [1] * (self.gen_args.prompt_length - len(prompt)) for prompt in prompts]
+        tokens = torch.IntTensor(sum(prompts, [])).cuda()
+        out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)
+
+        stats = Stats()
+        torch.cuda.synchronize()
+        stats.phase("prefill" if use_cuda_graphs else "total")
+        # stats.phase("total")
+
+        output = self._prefill_compile_model(tokens, None)
+
+        logits = output[kv_seqlen - 1, :]
+        logits = logits.view(bs, self.model_args.vocab_size)
+
+        if use_sampling:
+            temp = 0.7
+            top_p = 0.95
+            probs = torch.softmax(logits / temp, dim=-1)
+            next_token = sample_utils.top_p(probs, top_p)
+        else:
+            next_token = torch.argmax(logits, dim=-1)
+
+        next_token = next_token.reshape(bs)
+        out_tokens[0, :] = next_token
+
+        torch.cuda.synchronize()
+        stats.phase("decode" if use_cuda_graphs else "total")
+
+        eos_id = self.tokenizer.eot_id
+        for niter in range(1, gen_length):
+            kv_seqlen.add_(kv_seqlen < max_seq_length)
+            output = self._generate_compile_model(next_token, kv_seqlen)
+
+            logits = output.view(bs, self.model_args.vocab_size)
+
+            if use_sampling:
+                temp = 0.7
+                top_p = 0.95
+                probs = torch.softmax(logits / temp, dim=-1)
+                next_token = sample_utils.top_p(probs, top_p)
+            else:
+                next_token = torch.argmax(logits, dim=-1)
+
+            next_token = next_token.reshape(bs)
+            out_tokens[niter, :] = next_token
+
+            if next_token.eq(eos_id).any():
+                break
+
+        torch.cuda.synchronize()
+        stats.end_phase(tokens=niter * bs)
+
+        def trim_answer(prompt_len, tokens):
+            # print(prompt, tokens)
+            """Trim the answer to end it on an eos token."""
+            tokens = tokens[: max_seq_length - prompt_len]
+            eos_id = self.tokenizer.eot_id
+            if eos_id in tokens:
+                return tokens[: tokens.index(eos_id) + 1]
+            else:
+                return tokens
+
+        answers = [
+            trim_answer(prompt_len, answer)
+            for prompt_len, answer in zip(prompt_lens, out_tokens.t().tolist())
+        ]
+        return stats, answers
+
+
+def get_prompts(interactive: bool) -> Iterable[list[str]]:
+    if interactive:
+        while True:
+            try:
+                prompts = input("enter prompt: ").split("\n")
+            except EOFError:
+                print("exiting")
+                sys.exit(0)
+            yield prompts
+    else:
+        yield [
+            "Hello, my name is",
+        ]
+
+
+def main(ckpt_dir: str, interactive: bool = False, chat_format: bool = False, sampling: bool = False):
+
+    local_rank = 0
+    device = f"cuda:{local_rank}"
+    torch.cuda.set_device(local_rank)
+
+    g = FastGen.build(ckpt_dir, GenArgs(), device)
+
+    if chat_format:
+        g.tokenizer = ChatFormat(g.tokenizer)
+
+    for prompts in get_prompts(interactive):
+        # prompts = [f"{prompt}\n" for prompt in prompts]
+        if chat_format:
+            # prompts = [f'<|begin_of_text|>User: {prompt}<|eot_id|>Assistant: ' for prompt in prompts]
+            tokens = [g.tokenizer.encode_dialog_prompt(dialog=[{"role": "user", "content": prompt}], completion=True) for prompt in prompts]
+        else:
+            tokens = [g.tokenizer.encode(x, bos=False, eos=False) for x in prompts]
+
+        print(tokens)
+        stats, out_tokens = g.generate_all(
+            tokens, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ, use_sampling=sampling,
+        )
+
+        for i, prompt in enumerate(prompts):
+            print(f"> {prompt}")
+            answer = g.tokenizer.decode(out_tokens[i])
+            print(answer)
+            print("---------------")
+
+        for phase_stats in stats.phases:
+            print(phase_stats.show())
+
+        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+
+
+if __name__ == "__main__":
     fire.Fire(main)
\ No newline at end of file
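
For reference, convert_checkpoint.py above stores each quantized matrix twice: an int2-packed ternary copy with per-(sub)matrix bfloat16 scales for the decode model, and an fp16 "fake-quantized" copy for the prefill model. The snippet below is a minimal, self-contained sketch of the absmean ternary quantization those helpers apply on a single random tensor; the shape and seed are illustrative assumptions, the math is reimplemented locally rather than imported from the repo's modules, and the int2 packing step (convert_weight_int8_to_int2) is not reproduced.

import torch

torch.manual_seed(0)
w = torch.randn(8, 16, dtype=torch.bfloat16)  # stand-in for one weight matrix

# absmean scaling, mirroring quant_weight_int8 / quant_weight_fp16 above
s = 1.0 / w.abs().mean().clamp(min=1e-5)
w_int8 = (w * s).round().clamp(-1, 1).to(torch.int8)  # ternary values in {-1, 0, +1}
scale = (1.0 / s).to(torch.bfloat16)                  # what gets stored as weight_scale
w_fake = (w * s).round().clamp(-1, 1) / s             # fp16 "fake-quantized" copy for prefill

# dequantizing the ternary copy with its scale should closely match the fake-quantized copy
max_err = (w_int8.to(torch.bfloat16) * scale - w_fake).abs().max().item()
print("ternary levels:", w_int8.unique().tolist())
print("max |dequant - fake-quant|:", max_err)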