karpathy commented on June 12, 2024

Actually now I'm not sure what caused it. I thought maybe it's our change from a few minutes ago to use cuBLAS tf32, but that can't be it?

Btw I think it is non-deterministic. When I re-make / re-run it works, then it doesn't. My favorite kinds of issues with C.

from llm.c.

karpathy commented on June 12, 2024

@msharmavikram the positional encoder code has been there for a very long time

Looking at the diff, most of it should be harmless (e.g. the zero_grad and backward functions that I accidentally added are actually not called at all).

The core issue is I think the addition of softmax_forward_kernel5, and how that interacts with us doing inference, when we truncate the time dimension in an effort to be faster. We only check kernels on a fixed B, T, and here I am meddling with T and changing it dynamically, and I think that's messing up the code. Not 100% sure how it goes wrong, potentially some memory corruption.

For example when, in the kernel, we try to do:

const float* x = inp + idx * T;

I don't think this actually gets the row of x like we desire, because T here is now truncated to e.g. 4 (out of 1024). So I think we end up reading all wrong elements. Anyway I'm still looking but I'm very suspicious of this dynamic resizing now that I look at it, and I think we can't afford to do it.
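To make the concern concrete: in a row-major buffer, row idx starts at element idx * stride. If a buffer was written with one row length and read with another, every row after the first starts at the wrong element. A minimal C sketch of that arithmetic (the helper names are hypothetical, not from llm.c):

```c
#include <assert.h>
#include <stddef.h>

// Hypothetical helper: element offset of row `idx` in a row-major buffer
// whose rows are `stride` elements long, i.e. the arithmetic behind
// `const float* x = inp + idx * T;`.
size_t row_offset(int idx, int stride) {
    assert(idx >= 0 && stride > 0);
    return (size_t)idx * (size_t)stride;
}

// How far (in elements) a reader lands from the true start of row `idx`
// when the buffer was written with rows of `stride_written` elements but
// is read assuming rows of `stride_read` elements. 0 means they agree.
long row_misalignment(int idx, int stride_written, int stride_read) {
    return (long)row_offset(idx, stride_written) - (long)row_offset(idx, stride_read);
}
```

With rows laid out for T = 1024 but read with T = 4, row 1 would be read starting at element 4 rather than element 1024. If every kernel in the pass consistently uses the same truncated T, the offsets agree and this particular problem does not arise.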

karpathy commented on June 12, 2024

How strange - I can reproduce. Must have been the last commit? I thought I checked it before push

whao commented on June 12, 2024

> How strange - I can reproduce. Must have been the last commit? I thought I checked it before push

Yes, I tested c02bae2ebc684a2e068c0dc59be00ff43167b44d

whao commented on June 12, 2024

> Actually now I'm not sure what caused it. I thought maybe it's our change from a few minutes ago to use cuBLAS tf32, but that can't be it?
>
> Btw I think it is non-deterministic. When I re-make / re-run it works, then it doesn't. My favorite kinds of issues with C.

It reproduces on my device every time :P. I am looking into it, but I'm not an expert on CUDA.

whao commented on June 12, 2024

> Actually now I'm not sure what caused it. I thought maybe it's our change from a few minutes ago to use cuBLAS tf32, but that can't be it?
>
> Btw I think it is non-deterministic. When I re-make / re-run it works, then it doesn't. My favorite kinds of issues with C.

I wondered whether it is caused by running out of VRAM when doing inference during training, but I switched to a smaller batch size of 2, which uses only 5 GB of VRAM, and the issue persisted.

whao commented on June 12, 2024

If I set enable_tf32 = 0, I get:

[System]
Device 0: NVIDIA GeForce RTX 3060
enable_tf32: 0
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
train dataset num_batches: 149
val dataset num_batches: 16
batch size: 2
sequence length: 1024
val_num_batches: 10
num_activations: 1228318720
val loss 4.531809
step 0: train loss 4.270288 (took 101.887508 ms)
step 1: train loss 4.465426 (took 102.043178 ms)
step 2: train loss 4.410182 (took 101.841108 ms)
step 3: train loss 4.402787 (took 101.596036 ms)
step 4: train loss 4.618854 (took 101.319191 ms)
step 5: train loss 4.350821 (took 101.289429 ms)
step 6: train loss 4.460876 (took 101.300518 ms)
step 7: train loss 4.229775 (took 101.248242 ms)
step 8: train loss 4.071756 (took 101.816512 ms)
step 9: train loss 4.014830 (took 101.224944 ms)
val loss 4.531809
step 10: train loss 4.232022 (took 101.445115 ms)
step 11: train loss 4.226580 (took 101.695865 ms)
step 12: train loss 4.243248 (took 101.441674 ms)
step 13: train loss 4.106736 (took 101.794824 ms)
step 14: train loss 4.258279 (took 102.083281 ms)
step 15: train loss 4.156654 (took 101.854741 ms)
step 16: train loss 4.180723 (took 101.771757 ms)
step 17: train loss 4.073589 (took 101.580366 ms)
step 18: train loss 4.234052 (took 101.577380 ms)
step 19: train loss 4.207106 (took 101.533642 ms)
val loss 4.531809
[cuBLAS ERROR]: 13 train_gpt2.cu 482

It is the same error when I tried earlier today.
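For reference, 13 is the numeric value of CUBLAS_STATUS_EXECUTION_FAILED in cuBLAS's cublasStatus_t enum, meaning the GPU program failed to execute. Because kernel launches are asynchronous, a status like this can surface at a cuBLAS call even when the fault came from an earlier kernel; these logs alone do not say which kernel faulted. A small hand-written lookup for decoding bare numeric codes like the ones in this thread (a subset of the enum; the function name is hypothetical, and the values should be checked against cublas_api.h for your CUDA version):

```c
#include <assert.h>
#include <string.h>

// Hand-written subset of cuBLAS's cublasStatus_t values (see cublas_api.h),
// used to decode a bare "[cuBLAS ERROR]: 13" without the CUDA headers.
const char* cublas_status_name(int status) {
    switch (status) {
        case 0:  return "CUBLAS_STATUS_SUCCESS";
        case 1:  return "CUBLAS_STATUS_NOT_INITIALIZED";
        case 3:  return "CUBLAS_STATUS_ALLOC_FAILED";
        case 7:  return "CUBLAS_STATUS_INVALID_VALUE";
        case 8:  return "CUBLAS_STATUS_ARCH_MISMATCH";
        case 11: return "CUBLAS_STATUS_MAPPING_ERROR";
        case 13: return "CUBLAS_STATUS_EXECUTION_FAILED";
        case 14: return "CUBLAS_STATUS_INTERNAL_ERROR";
        case 15: return "CUBLAS_STATUS_NOT_SUPPORTED";
        case 16: return "CUBLAS_STATUS_LICENSE_ERROR";
        default: return "UNKNOWN_CUBLAS_STATUS";
    }
}
```

In code that links against cuBLAS, recent releases also provide cublasGetStatusName and cublasGetStatusString, which make a table like this unnecessary if your version has them.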

Bing1002 commented on June 12, 2024

I got the same issue at the end.
"""
[CUDA ERROR] at file train_gpt2.cu:1238:
an illegal memory access was encountered
"""

tnorman42 commented on June 12, 2024

Also seeing this on NVIDIA GeForce RTX 2080 Ti.

The exact error message is non-deterministic, but it always seems to fail after step 19.

[cuBLAS ERROR]: 13 train_gpt2.cu 509

or

[CUDA ERROR] at file train_gpt2.cu:1238:
an illegal memory access was encountered

chanwaijye commented on June 12, 2024

Using a GTX 1050 on WSL:

Device 0: NVIDIA GeForce GTX 1050
enable_tf32: 0
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
train dataset num_batches: 74
val dataset num_batches: 8
batch size: 4
sequence length: 1024
val_num_batches: 10
num_activations: 2456637440
val loss 4.517293
step 0: train loss 4.367856 (took 15930.875039 ms)
step 1: train loss 4.406481 (took 15807.003816 ms)
step 2: train loss 4.484837 (took 15777.008516 ms)
step 3: train loss 4.345325 (took 16503.159295 ms)
step 4: train loss 4.043289 (took 15918.101903 ms)
step 5: train loss 4.229304 (took 15822.736054 ms)
step 6: train loss 4.174998 (took 15762.027166 ms)
step 7: train loss 4.207466 (took 15933.593094 ms)
step 8: train loss 4.127153 (took 16174.404151 ms)
step 9: train loss 4.220581 (took 18367.605210 ms)
val loss 4.517293
step 10: train loss 4.345143 (took 15999.325256 ms)
step 11: train loss 4.245717 (took 15862.750973 ms)
step 12: train loss 4.160386 (took 15866.520096 ms)
step 13: train loss 3.989360 (took 15898.156320 ms)
step 14: train loss 4.305948 (took 15954.338483 ms)
step 15: train loss 4.340422 (took 16023.643041 ms)
step 16: train loss 4.304480 (took 15899.189344 ms)
step 17: train loss 4.424025 (took 15839.624771 ms)
step 18: train loss 4.314513 (took 15854.631572 ms)
step 19: train loss 4.287313 (took 15870.214153 ms)
val loss 4.517293
[CUDA ERROR] at file train_gpt2.cu:1238:
an illegal memory access was encountered

karpathy commented on June 12, 2024

There's definitely a bug. I'll have to stare at this a bit longer tomorrow and bisect to find the commit where it was introduced. It must have been pretty recent, probably sometime today.

tnorman42 commented on June 12, 2024

Bisecting, it starts with 6b49ed1 for me.

msharmavikram commented on June 12, 2024

Alright. This is a weird issue. For this kind of error, one should use compute-sanitizer. Surprisingly, I do not see the error when I run under compute-sanitizer! The most likely reason is that we have placed our synchronization boundaries poorly. Recall that CUDA kernel launches are asynchronous.

Anyway, I cannot debug further without a codebase expert (I've spent ~20 minutes in this codebase), but here is what I have found so far. You should be able to reproduce these debugging steps.

# no bug observed with compute-sanitizer
compute-sanitizer --racecheck-detect-level info --racecheck-indirect-barrier-dependency --racecheck-report all --track-stream-ordered-races all  --report-api-errors  all  ./train_gpt2cu 
========= COMPUTE-SANITIZER
[System]
Device 0: NVIDIA A100-PCIE-40GB
enable_tf32: 1
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
train dataset num_batches: 74
val dataset num_batches: 8
batch size: 4
sequence length: 1024
val_num_batches: 10
num_activations: 2456637440
val loss 4.517179
step 0: train loss 4.367631 (took 2848.275642 ms)
step 1: train loss 4.406341 (took 2846.808552 ms)
step 2: train loss 4.484756 (took 2848.409444 ms)
step 3: train loss 4.345182 (took 2846.061281 ms)
step 4: train loss 4.043440 (took 2845.134784 ms)
step 5: train loss 4.229531 (took 2850.073134 ms)
step 6: train loss 4.175078 (took 2848.740164 ms)
step 7: train loss 4.207684 (took 2851.417664 ms)
step 8: train loss 4.127494 (took 2848.887400 ms)
step 9: train loss 4.220500 (took 2849.752472 ms)
val loss 4.517179
step 10: train loss 4.345251 (took 2846.314426 ms)
step 11: train loss 4.245913 (took 2848.427818 ms)
step 12: train loss 4.160710 (took 2850.893041 ms)
step 13: train loss 3.989707 (took 2846.624978 ms)
step 14: train loss 4.305970 (took 2848.236049 ms)
step 15: train loss 4.340496 (took 2847.727786 ms)
step 16: train loss 4.304414 (took 2849.400782 ms)
step 17: train loss 4.424054 (took 2849.699372 ms)
step 18: train loss 4.314544 (took 2845.855294 ms)
step 19: train loss 4.287184 (took 2848.324695 ms)
val loss 4.517179
generated: 50256 45 522 1743 18681 5908 661 338 3348 390 2185 983 785 9124 398 276 379 262 2695 262 14249 352 257 14249 338 11 618 340 1049 267 1049 264 31899 198 198 198 198 11 290 625 481 4966 1863 351 6821 13 2647 198 198 198 198 198 198 13 28323 477 2107 290 11 314 1111 532 14678 284 
step 20: train loss 5.026133 (took 2848.520823 ms)
step 21: train loss 4.860029 (took 2847.845306 ms)
step 22: train loss 4.924436 (took 2849.745809 ms)
step 23: train loss 4.795475 (took 2844.809975 ms)
step 24: train loss 4.962370 (took 2848.215110 ms)
step 25: train loss 5.044362 (took 2847.104718 ms)
step 26: train loss 5.073789 (took 2848.870728 ms)
step 27: train loss 4.999546 (took 2849.887515 ms)
step 28: train loss 4.965069 (took 2848.875928 ms)
step 29: train loss 4.977201 (took 2847.829586 ms)

So I had to explicitly force synchronization points to determine which kernel was causing this error:

506 void encoder_forward(float* out,
507              int* inp, float* wte, float* wpe,
508              int B, int T, int C) {
509  const int N = B * T * C;
510  const int block_size = 256;
511  const int grid_size = CEIL_DIV(N, block_size);
512  cudaCheck(cudaDeviceSynchronize());
513  encoder_forward_kernel2<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);
514
515  cudaCheck(cudaDeviceSynchronize());  ///<<-- forcing sync boundaries
516  cudaCheck(cudaGetLastError());
}
...
val loss 4.517179
[CUDA ERROR] at file train_gpt2.cu:515:   //<<<----sync boundaries helping to find the line number where this is happening.
an illegal memory access was encountered

Now, I have no idea what inside encoder_forward_kernel2 is causing this issue. The error can only happen if we try to access memory that is not available (we are observing a GPU memory READ fault).

How do I know? Run sudo dmesg -wH and you will see something like this:

[  +2.321212] NVRM: Xid (PCI:0000:01:00): 31, pid=32177, name=train_gpt2cu, Ch 00000008, intr 00000000. MMU Fault: ENGINE GRAPHICS GPCCLIENT_T1_4 faulted @ 0x7f64_6dfff000. Fault is of type FAULT_PDE ACCESS_TYPE_VIRT_READ

An important observation here is ACCESS_TYPE_VIRT_READ.

Given this, there are three possible culprits in encoder_forward_kernel2:

        int ix = inp[b * T + t];
...
        float* wte_ix = wte + ix * C + c;
        float* wpe_tc = wpe + t * C + c;

Forcing each of these lines to fixed index values in turn, I found that the error happens when we read *wte_ix to compute out_btc, and it occurs when N = 6144. I am not an expert on this code, so I am not sure what the boundary conditions should be to ensure correctness.

The correct way to write this code requires a boundary check before the output write, and also when reading *wte_ix and *wpe_tc. Since this boundary check is not present, of course we will hit this error.
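The three reads above can be mirrored in a plain C reference to show where a guard would go. The sketch below is hypothetical and not the llm.c kernel: the decomposition of the flat index into (b, t, c) and the extra vocabulary-size parameter V are assumptions. It computes out[b,t,c] = wte[ix,c] + wpe[t,c] and refuses to use a token id outside [0, V) as a row index into wte, which is the read reported as faulting. As an arithmetic aside, N = B × T × C with C = 768, so N = 6144 means B × T = 8, far smaller than a training step here (B × T = 4 × 1024 = 4096).

```c
#include <assert.h>
#include <stddef.h>

// CPU sketch of the per-element work in an encoder forward pass:
// out[b,t,c] = wte[ix,c] + wpe[t,c], with ix = inp[b*T + t].
// The flat-index decomposition and the vocabulary-size parameter V are
// assumptions for illustration. Token ids outside [0, V) are not used to
// index wte; the output element is zeroed and the position is counted.
// Returns the number of (b, t) positions with an invalid token id.
int encoder_forward_checked(float* out, const int* inp,
                            const float* wte, const float* wpe,
                            int B, int T, int C, int V) {
    int bad = 0;
    const int N = B * T * C;
    for (int idx = 0; idx < N; idx++) {  // on the GPU: one thread per idx, guarded by idx < N
        int bt = idx / C;
        int b = bt / T;
        int t = bt % T;
        int c = idx % C;
        int ix = inp[b * T + t];
        float* out_btc = out + (size_t)b * T * C + (size_t)t * C + c;
        if (ix < 0 || ix >= V) {         // a garbage id would read far outside wte
            *out_btc = 0.0f;
            if (c == 0) bad++;           // count each (b, t) position once
            continue;
        }
        *out_btc = wte[(size_t)ix * C + c] + wpe[(size_t)t * C + c];
    }
    return bad;
}

// Tiny demo: B=1, T=2, C=2, V=3, tokens {1, second_token}.
// Returns the invalid-position count, or -1 if a valid output is wrong.
int encoder_demo(int second_token) {
    const float wte[3 * 2] = {1, 2, 3, 4, 5, 6};  // V=3 rows of C=2
    const float wpe[2 * 2] = {10, 20, 30, 40};    // T=2 rows of C=2
    int inp[2] = {1, second_token};
    float out[2 * 2];
    int bad = encoder_forward_checked(out, inp, wte, wpe, 1, 2, 2, 3);
    if (out[0] != 13.0f || out[1] != 24.0f) return -1;  // wte[1,:] + wpe[0,:]
    return bad;
}
```

On the GPU the loop body would become one thread per idx behind an idx < N check; in this sketch it is the extra ix check that protects the wte read.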

This also raises a question: if we are observing a perf benefit over PyTorch, the two codes may not be functionally equivalent if the boundary-condition checks are not implemented correctly. We may be getting correct results, but it might all be a fluke. So I would be cautious about the perf claim against PyTorch at the current stage (I'd request adding a disclaimer).

karpathy commented on June 12, 2024

The dynamic resizing happens in this code:

// once in a while do model inference to print generated text
if (step > 0 && step % 20 == 0) {
    gen_tokens[0] = GPT2_EOT; // the GPT-2 EOT token kicks off the generation
    for (int t = 1; t < gen_max_length; t++) {
        // note that inference is wasteful here because
        // for each t, we re-compute all activations between 0 and t
        // leaving this alone because you want separate code for inference anyway
        // the inference here is just for sanity checking purposes
        int t4 = (t + 3) & ~3; // clever way to round up to multiple of 4
        gpt2_forward(&model, gen_tokens, NULL, 1, t4);
        float* probs = model.acts.probs + (t-1) * model.config.vocab_size;
        float coin = random_f32(&rng_state);
        // move probs back to CPU and sample
        cudaCheck(cudaMemcpy(cpu_probs, probs, model.config.vocab_size * sizeof(float), cudaMemcpyDeviceToHost));
        int next_token = sample_mult(cpu_probs, model.config.vocab_size, coin);
        gen_tokens[t] = next_token;
    }
    printf("generated: ");
    for (int t = 0; t < gen_max_length; t++) {
        printf("%d ", gen_tokens[t]);
    }
    printf("\n");
}

which only runs every 20 iterations, which is exactly where we see the crash.
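The rounding line above means that at generation step t the forward pass is run over t4 tokens, where t4 is t rounded up to a multiple of 4, so gen_tokens[t] through gen_tokens[t4 - 1] are passed to the model before this loop has written them. If those slots were never initialized, their contents would be used as token ids. A small C sketch of the arithmetic (function names are hypothetical):

```c
#include <assert.h>

// Round t up to the next multiple of 4, as in `int t4 = (t + 3) & ~3;`.
// Adding 3 and then clearing the two low bits with & ~3 rounds up for t >= 0.
int round_up4(int t) {
    assert(t >= 0);
    return (t + 3) & ~3;
}

// At generation step t the loop has written gen_tokens[0..t-1], but the
// forward pass runs over round_up4(t) positions. This returns how many of
// those positions (gen_tokens[t..t4-1]) have not been written by the loop yet.
int unwritten_positions_read(int t) {
    return round_up4(t) - t;
}
```

For example, at t = 5 the pass covers t4 = 8 tokens, three of which (positions 5, 6 and 7) have not been written yet.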

msharmavikram commented on June 12, 2024

Interesting. Then I agree with you that the issue is somewhere in dynamic resizing and how it interacts with the new softmax_forward_kernel5.

ngc92 commented on June 12, 2024

Could you test #122?

zocterminal commented on June 12, 2024

I'm seeing both these errors.

Either error 13:

[cuBLAS ERROR]: 13 train_gpt2.cu (some line)
or

   Device 0: Tesla V100-SXM2-16GB
   ....
   step 19: train loss 4.287314 (took 102.180264 ms)
   val loss 4.517294
   [CUDA ERROR] at file train_gpt2.cu:1238:
   an illegal memory access was encountered

on NVIDIA V100 or T10 (on Amazon AWS machines running Debian 11).

g8392 commented on June 12, 2024

I'm also seeing a similar error

Error Message

./train_gpt2cu
[System]
Device 0: NVIDIA GeForce RTX 3090
enable_tf32: 1
[GPT-2]
step 17: train loss 4.424054 (took 62.469384 ms)
step 18: train loss 4.314544 (took 61.996948 ms)
step 19: train loss 4.287184 (took 66.194802 ms)
val loss 4.517179
[CUDA ERROR] at file train_gpt2.cu:1238:
an illegal memory access was encountered

sys-info

nvidia-smi
Sun Apr 14 14:41:08 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:27:00.0  On |                  N/A |
| 36%   51C    P3             78W /  370W |    3169MiB /  24576MiB |     13%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:19:38_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0
Pop-22.04

It's also giving me some compilation warnings about the result of an fread call not being used.

ngc92 commented on June 12, 2024

@zocterminal @g8392 Is that with #122? Or the master branch?

zocterminal commented on June 12, 2024

@ngc92 on master. What do I need to do to check #122 (if that is of interest)?

ngc92 commented on June 12, 2024

@zocterminal just check out that branch (it's on my fork, so you might need to add it as a remote), recompile, and see if it works.

g8392 commented on June 12, 2024

#122 seems to fix it for me, @ngc92. It also looks like each step takes longer, 62 -> 108 ms:

step 396: train loss 4.406373 (took 127.762113 ms)
step 397: train loss 4.334729 (took 125.671292 ms)
step 398: train loss 4.368728 (took 123.486678 ms)
step 399: train loss 4.287299 (took 122.625955 ms)
val loss 4.517293
generated: 50256 1671 29 198 198 4023 1378 73 14822 4364 18855 8952 13 40346 13 785 14 10531 14 2919 14 1129 14 3642 46927 12 22540 498 12 325 16025 14 198 198 40 2911 345 1100 428 13 314 2911 345 1839 470 475 836 470 1028 502 13 314 2911 345 389 4684 284 2245 1642 262 976 10135 13 314 
step 400: train loss 4.155364 (took 108.857453 ms)

zocterminal commented on June 12, 2024

@ngc92 on a quick check, #122 seems to fix it (I will have to double-check, so take this result with a grain of salt).

ngc92 commented on June 12, 2024

I'm a bit surprised that this would slow down anything. Can you check what happens if you compile with -DNDEBUG?
Maybe it is the CPU-side validation of the indices, in which case we could move that inside the kernel, which is memory-bound anyway.
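If the index validation mentioned here is an assert-style host-side check (the -DNDEBUG question suggests so, but that is an assumption), it might look roughly like the sketch below: a linear scan of the token ids on the CPU that is compiled out entirely under -DNDEBUG. The function names are hypothetical and not taken from #122.

```c
#include <assert.h>
#include <stdio.h>

// Hypothetical CPU-side check: returns the index of the first token id
// outside [0, vocab_size), or -1 if all n ids are valid.
int first_invalid_token(const int* tokens, int n, int vocab_size) {
    for (int i = 0; i < n; i++) {
        if (tokens[i] < 0 || tokens[i] >= vocab_size) return i;
    }
    return -1;
}

// assert-style wrapper: under -DNDEBUG the whole check is compiled out,
// so the O(n) host-side scan costs nothing in release builds.
void debug_validate_tokens(const int* tokens, int n, int vocab_size) {
#ifndef NDEBUG
    int bad = first_invalid_token(tokens, n, vocab_size);
    if (bad >= 0) {
        fprintf(stderr, "invalid token id %d at position %d\n", tokens[bad], bad);
    }
    assert(bad < 0);
#else
    (void)tokens; (void)n; (void)vocab_size;
#endif
}
```

For instance, calling debug_validate_tokens(gen_tokens, t4, 50257) before each generation-time forward pass would abort with a message on the first out-of-range id in a debug build. Doing the same comparison inside the kernel instead avoids the separate host-side pass.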

zocterminal commented on June 12, 2024

Tentatively: I think this commit is the one that breaks it.
6b49ed1

@ngc92 @karpathy

Maybe someone can verify:

Crashes:

git checkout 6b49ed1c0bace91295bc7c3ad9d6c33c6552bdfd
make train_gpt2cu
./train_gpt2cu

Works:

git checkout 4f75e645730fc424dd75e252fb21413894ab78f7
make train_gpt2cu
./train_gpt2cu

zocterminal commented on June 12, 2024

@ngc92 now that I have (maybe) narrowed it down, let me play with your fix to get a less tentative answer. BRB.

zocterminal commented on June 12, 2024

@ngc92 @karpathy

I could reproduce this here rather consistently (see also #114 (comment))

I would say #122 fixes it.

Initial download after cloning master:

admin@ip-10-0-12-239:~$ cd llm/
admin@ip-10-0-12-239:~/llm$ python prepro_tinyshakespeare.py
python train_gpt2.py
Downloading https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt to data/tiny_shakespeare.txt...
data/tiny_shakespeare.txt: 1.06MiB [00:00, 56.0MiB/s]
Saved 32768 tokens to data/tiny_shakespeare_val.bin
Saved 305260 tokens to data/tiny_shakespeare_train.bin
using device: cuda
loading weights from pretrained gpt: gpt2
loading cached tokens in data/tiny_shakespeare_val.bin
wrote gpt2_124M.bin
wrote gpt2_124M_debug_state.bin
iteration 0, loss: 5.270007610321045, time: 2258.069ms
iteration 1, loss: 4.059666156768799, time: 45.970ms
iteration 2, loss: 3.3750417232513428, time: 44.514ms
iteration 3, loss: 2.8006718158721924, time: 44.466ms
iteration 4, loss: 2.315293073654175, time: 44.393ms
iteration 5, loss: 1.8489176034927368, time: 44.503ms
iteration 6, loss: 1.3945326805114746, time: 44.388ms
iteration 7, loss: 0.9989923238754272, time: 44.444ms
iteration 8, loss: 0.6240252256393433, time: 44.486ms
iteration 9, loss: 0.3764869272708893, time: 44.448ms
final 20 iters avg: 265.968ms
<|endoftext|>One year ago today:
This is the first week since we last spoke.
---------------

Build master and run:

admin@ip-10-0-12-239:~/llm$ make train_gpt2cu
NICE Compiling with OpenMP support
nvcc -O3 --use_fast_math train_gpt2.cu -lcublas -lcublasLt -o train_gpt2cu
admin@ip-10-0-12-239:~/llm$ ./train_gpt2cu
[System]
Device 0: Tesla V100-SXM2-16GB
enable_tf32: 0
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
train dataset num_batches: 74
val dataset num_batches: 8
batch size: 4
sequence length: 1024
val_num_batches: 10
num_activations: 2456637440
val loss 4.517294
step 0: train loss 4.367857 (took 101.732135 ms)
step 1: train loss 4.406480 (took 101.166434 ms)
step 2: train loss 4.484838 (took 101.761969 ms)
step 3: train loss 4.345327 (took 101.626840 ms)
step 4: train loss 4.043288 (took 101.825034 ms)
step 5: train loss 4.229304 (took 101.839962 ms)
step 6: train loss 4.174997 (took 101.894603 ms)
step 7: train loss 4.207467 (took 101.432225 ms)
step 8: train loss 4.127151 (took 101.388704 ms)
step 9: train loss 4.220581 (took 101.145741 ms)
val loss 4.517294
step 10: train loss 4.345143 (took 101.384587 ms)
step 11: train loss 4.245718 (took 101.350847 ms)
step 12: train loss 4.160385 (took 101.298673 ms)
step 13: train loss 3.989359 (took 101.644972 ms)
step 14: train loss 4.305947 (took 101.964139 ms)
step 15: train loss 4.340423 (took 102.008248 ms)
step 16: train loss 4.304481 (took 102.241103 ms)
step 17: train loss 4.424024 (took 102.095094 ms)
step 18: train loss 4.314511 (took 101.715553 ms)
step 19: train loss 4.287314 (took 101.871424 ms)
val loss 4.517294
**[CUDA ERROR] at file train_gpt2.cu:1238:**
an illegal memory access was encountered

Patch #122 and run again:

admin@ip-10-0-12-239:~/llm$ patch -p 1 < ~/122.diff
patching file train_gpt2.cu
admin@ip-10-0-12-239:~/llm$ make train_gpt2cu
NICE Compiling with OpenMP support
nvcc -O3 --use_fast_math train_gpt2.cu -lcublas -lcublasLt -o train_gpt2cu
./train_gpt2cu
admin@ip-10-0-12-239:~/llm$ ./train_gpt2cu
[System]
Device 0: Tesla V100-SXM2-16GB
enable_tf32: 0
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
train dataset num_batches: 74
val dataset num_batches: 8
batch size: 4
sequence length: 1024
val_num_batches: 10
num_activations: 2456637440
val loss 4.517294
step 0: train loss 4.367857 (took 101.881251 ms)
step 1: train loss 4.406480 (took 102.162624 ms)
step 2: train loss 4.484838 (took 101.812695 ms)
step 3: train loss 4.345327 (took 101.290578 ms)
step 4: train loss 4.043288 (took 101.676153 ms)
step 5: train loss 4.229304 (took 101.762146 ms)
step 6: train loss 4.174997 (took 101.498859 ms)
step 7: train loss 4.207467 (took 101.694050 ms)
step 8: train loss 4.127151 (took 102.217656 ms)
step 9: train loss 4.220581 (took 102.076835 ms)
val loss 4.517294
step 10: train loss 4.345143 (took 101.823438 ms)
step 11: train loss 4.245718 (took 102.059243 ms)
step 12: train loss 4.160385 (took 102.194263 ms)
step 13: train loss 3.989359 (took 102.070963 ms)
step 14: train loss 4.305947 (took 102.134566 ms)
step 15: train loss 4.340423 (took 101.949639 ms)
step 16: train loss 4.304481 (took 102.192408 ms)
step 17: train loss 4.424024 (took 102.201196 ms)
step 18: train loss 4.314511 (took 102.247464 ms)
step 19: train loss 4.287314 (took 102.193246 ms)
val loss 4.517294
**generated: 50256 50 695 379 26549 532 772 517 621 749 517 621 477 621 673 468 587 290 1657 290 1714 351 198 198 198 198 647 423 2219 286 351 257 27737 69 275 2040 319 1 290 484 588 4046 2219 2160 2250 13 6275 290 4423 510 284 477 1744 25 35501 318 2199 645 1 314 716 477 20030 389**
step 20: train loss 4.976575 (took 101.099377 ms)
step 21: train loss 4.859861 (took 101.296487 ms)
step 22: train loss 4.928136 (took 101.406619 ms)
step 23: train loss 4.791417 (took 101.825221 ms)
step 24: train loss 4.955164 (took 102.122429 ms)
step 25: train loss 5.047175 (took 101.991752 ms)
step 26: train loss 5.072573 (took 101.687222 ms)
step 27: train loss 4.979808 (took 101.338288 ms)
step 28: train loss 4.955441 (took 101.770186 ms)
step 29: train loss 4.981244 (took 101.637839 ms)
val loss 5.117234
step 30: train loss 4.852224 (took 101.474148 ms)
step 31: train loss 4.839732 (took 102.083349 ms)
step 32: train loss 4.867337 (took 102.310648 ms)
step 33: train loss 5.023737 (took 102.078247 ms)
step 34: train loss 4.912455 (took 101.996438 ms)
step 35: train loss 4.958746 (took 102.094551 ms)
step 36: train loss 5.013899 (took 101.647376 ms)
step 37: train loss 4.949145 (took 101.965694 ms)
step 38: train loss 4.959915 (took 101.653358 ms)
step 39: train loss 4.851017 (took 101.897501 ms)
val loss 5.117234
generated: 50256 21448 27709 7911 284 262 14089 13 17762 13 1318 318 407 691 779 340 783 4961 670 513 286 1222 1103 3797 284 477 477 355 4844 11 352 11 355 257 649 5860 11 717 510 1597 1303 526 389 19979 318 477 82 290 13 329 1598 477 13 410 11 362 11 284 262 19974 477 11 477 11
step 40: train loss 5.007130 (took 101.367506 ms)
admin@ip-10-0-12-239:~/llm$
admin@ip-10-0-12-239:~/llm$

Also, no noticeable speed changes.

Good job @ngc92 !

g8392 commented on June 12, 2024

Compiling with -DNDEBUG doesn't seem to affect it, @ngc92. This is what happens when I do the same as @zocterminal:

./train_gpt2cu
.....
step 0: train loss 4.367631 (took 59.288083 ms)
step 1: train loss 4.406341 (took 60.256468 ms)
step 2: train loss 4.484756 (took 57.820963 ms)
.....
[CUDA ERROR] at file train_gpt2.cu:1238:
an illegal memory access was encountered
patch train_gpt2.cu -p1 < 122.diff

patching file train_gpt2.cu
make train_gpt2cu ; ./train_gpt2cu

NICE Compiling with OpenMP support
nvcc -O3 --use_fast_math train_gpt2.cu -lcublas -lcublasLt -o train_gpt2cu
train_gpt2.cu(640): warning #1650-D: result of call is not used
      fread(model_header, sizeof(int), 256, model_file);
      ^
Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"
train_gpt2.cu(689): warning #1650-D: result of call is not used
      fread(params_memory_cpu, sizeof(float), num_parameters, model_file);
........
train_gpt2.cu: In function 'void dataloader_next_batch(DataLoader*)'
train_gpt2.cu:937:6: warning: ignoring return value of 'size_t fread(void*, size_t, size_t, FILE*) declared with attribute 'warn_unused_result' [-Wunused-result]
  937 |     fread(loader->batch, sizeof(int), B*T+1, loader->tokens_file);
      |     ~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 
[GPT-2]
.....
step 0: train loss 4.367857 (took 109.131925 ms)
step 1: train loss 4.406483 (took 108.166702 ms)
step 2: train loss 4.484839 (took 108.053300 ms)
step 3: train loss 4.345326 (took 108.359766 ms)
step 4: train loss 4.043288 (took 109.111375 ms)
step 5: train loss 4.229303 (took 108.113061 ms)

zocterminal commented on June 12, 2024

@g8392 @ngc92

What if you change line 1205 in master like this? (This also fixed it for me.)

int gen_tokens[gen_max_length] = { 0, };

g8392 commented on June 12, 2024

I haven't been able to pinpoint the source of the difference, but after setting everything up again and using train_gpt2.cu from #122, everything seems to work, and the speed is the same too.


step 37: train loss 4.950854 (took 58.518316 ms)
step 38: train loss 4.960339 (took 58.977451 ms)
step 39: train loss 4.833032 (took 59.156414 ms)
val loss 5.123018
generated: 50256 21129 16932 29406 499 13 11023 261 24514 13 572 13 572 1092 532 477 572 3635 393 597 393 597 597 1092 477 505 597 393 4844 286 357 270 468 450 298 23461 485 366 645 2081 705 1222 477 31812 597 584 290 510 11 477 2227 374 11 290 11 287 477 262 308 16298 477 11 477 287 
step 40: train loss 5.008294 (took 58.417834 ms)
