From 46904d5a2cdf3579ae09d499a2324db316388148 Mon Sep 17 00:00:00 2001 From: jahatef Date: Wed, 19 Jun 2024 21:15:05 +0000 Subject: [PATCH] setup --- configs/local_setup.yml | 4 ++++ megatron/model/rwkv/v6/rwkv.py | 1 + 2 files changed, 5 insertions(+) diff --git a/configs/local_setup.yml b/configs/local_setup.yml index d031a2ad8..3bf17ca3d 100644 --- a/configs/local_setup.yml +++ b/configs/local_setup.yml @@ -22,6 +22,10 @@ "load": "checkpoints", "checkpoint_validation_with_forward_pass": False, + + # "launcher": "openmpi", + #"deepspeed_mpi": true, + "tensorboard_dir": "tensorboard", "log_dir": "logs", "use_wandb": True, diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index eaeec4ad7..b2a261842 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -223,6 +223,7 @@ def forward(self, x): H = self.neox_args.num_attention_heads r, k, v, g, w = self.jit_func(x) + print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n") x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa) return self.jit_func_2(x, g)