Epoch [1/3]
KeyError Traceback (most recent call last)
File <string>:21, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)
KeyError: ('2-.-0-.-0--d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-14de7de5c4da5794c8ca14e7e41a122d-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('matrix', False, 64, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (False, False), (True, False), (True, False), (True, False), (False, False), (False, False), (False, False), (True, False), (True, False), (False, False), (False, False)))
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:937, in build_triton_ir(fn, signature, specialization, constants)
936 try:
--> 937 generator.visit(fn.parse())
938 except Exception as e:
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:183, in CodeGenerator.visit_Module(self, node)
182 def visit_Module(self, node):
--> 183 ast.NodeVisitor.generic_visit(self, node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:379, in NodeVisitor.generic_visit(self, node)
378 if isinstance(item, AST):
--> 379 self.visit(item)
380 elif isinstance(value, AST):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:252, in CodeGenerator.visit_FunctionDef(self, node)
251 # visit function body
--> 252 has_ret = self.visit_compound_statement(node.body)
253 # finalize function
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:177, in CodeGenerator.visit_compound_statement(self, stmts)
176 for stmt in stmts:
--> 177 self.last_ret_type = self.visit(stmt)
178 if isinstance(stmt, ast.Return):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:678, in CodeGenerator.visit_For(self, node)
677 self.scf_stack.append(node)
--> 678 self.visit_compound_statement(node.body)
679 self.scf_stack.pop()
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:177, in CodeGenerator.visit_compound_statement(self, stmts)
176 for stmt in stmts:
--> 177 self.last_ret_type = self.visit(stmt)
178 if isinstance(stmt, ast.Return):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:319, in CodeGenerator.visit_AugAssign(self, node)
318 assign = ast.Assign(targets=[node.target], value=rhs)
--> 319 self.visit(assign)
320 return self.get_value(name)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:301, in CodeGenerator.visit_Assign(self, node)
300 names = _names[0]
--> 301 values = self.visit(node.value)
302 if not isinstance(names, tuple):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:339, in CodeGenerator.visit_BinOp(self, node)
338 lhs = self.visit(node.left)
--> 339 rhs = self.visit(node.right)
340 fn = {
341 ast.Add: 'add',
342 ast.Sub: 'sub',
(...)
352 ast.BitXor: 'xor',
353 }[type(node.op)]
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/ast.py:371, in NodeVisitor.visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:797, in CodeGenerator.visit_Call(self, node)
795 if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__))
796 or impl.is_builtin(fn):
--> 797 return fn(*args, _builder=self.builder, **kws)
798 if fn in self.builtins.values():
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/impl/base.py:22, in builtin.<locals>.wrapper(*args, **kwargs)
18 raise ValueError(
19 "Did you forget to add @triton.jit ? "
20 "(_builder
argument must be provided outside of JIT functions.)"
21 )
---> 22 return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'
The above exception was the direct cause of the following exception:
CompilationError Traceback (most recent call last)
Cell In[15], line 1
----> 1 teacher_train(T_model, cfg, train_loader, test_loader)
Cell In[14], line 39, in teacher_train(model, config, train_loader, test_loader)
37 mask = mask.to(config.device)
38 labels = labels.to(config.device)
---> 39 outputs = model(ids, mask)
40 model.zero_grad()
41 loss = F.cross_entropy(outputs, labels)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
Cell In[12], line 12, in BERT_Model.forward(self, context, mask)
11 def forward(self, context, mask):
---> 12 outputs = self.bert(context, attention_mask=mask)
13 pooled = outputs[1]
14 out = self.fc(pooled)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:608, in BertModel.forward(self, input_ids, token_type_ids, attention_mask, position_ids, output_all_encoded_layers, masked_tokens_mask, **kwargs)
605 first_col_mask[:, 0] = True
606 subset_mask = masked_tokens_mask | first_col_mask
--> 608 encoder_outputs = self.encoder(
609 embedding_output,
610 attention_mask,
611 output_all_encoded_layers=output_all_encoded_layers,
612 subset_mask=subset_mask)
614 if masked_tokens_mask is None:
615 sequence_output = encoder_outputs[-1]
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:446, in BertEncoder.forward(self, hidden_states, attention_mask, output_all_encoded_layers, subset_mask)
444 if subset_mask is None:
445 for layer_module in self.layer:
--> 446 hidden_states = layer_module(hidden_states,
447 cu_seqlens,
448 seqlen,
449 None,
450 indices,
451 attn_mask=attention_mask,
452 bias=alibi_attn_mask)
453 if output_all_encoded_layers:
454 all_encoder_layers.append(hidden_states)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:327, in BertLayer.forward(self, hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias)
305 def forward(
306 self,
307 hidden_states: torch.Tensor,
(...)
313 bias: Optional[torch.Tensor] = None,
314 ) -> torch.Tensor:
315 """Forward pass for a BERT layer, including both attention and MLP.
316
317 Args:
(...)
325 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
326 """
--> 327 attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
328 subset_idx, indices, attn_mask, bias)
329 layer_output = self.mlp(attention_output)
330 return layer_output
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:240, in BertUnpadAttention.forward(self, input_tensor, cu_seqlens, max_s, subset_idx, indices, attn_mask, bias)
218 def forward(
219 self,
220 input_tensor: torch.Tensor,
(...)
226 bias: Optional[torch.Tensor] = None,
227 ) -> torch.Tensor:
228 """Forward pass for scaled self-attention without padding.
229
230 Arguments:
(...)
238 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
239 """
--> 240 self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
241 attn_mask, bias)
242 if subset_idx is not None:
243 return self.output(index_first_axis(self_output, subset_idx),
244 index_first_axis(input_tensor, subset_idx))
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/bert_layers.py:181, in BertUnpadSelfAttention.forward(self, hidden_states, cu_seqlens, max_seqlen_in_batch, indices, attn_mask, bias)
179 bias_dtype = bias.dtype
180 bias = bias.to(torch.float16)
--> 181 attention = flash_attn_qkvpacked_func(qkv, bias)
182 attention = attention.to(orig_dtype)
183 bias = bias.to(bias_dtype)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
503 if not torch._C._are_functorch_transforms_active():
504 # See NOTE: [functorch vjp and autograd interaction]
505 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506 return super().apply(*args, **kwargs) # type: ignore[misc]
508 if cls.setup_context == _SingleLevelFunction.setup_context:
509 raise RuntimeError(
510 'In order to use an autograd.Function with functorch transforms '
511 '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
512 'staticmethod. For more details, please see '
513 'https://pytorch.org/docs/master/notes/extending.func.html')
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/flash_attn_triton.py:1021, in _FlashAttnQKVPackedFunc.forward(ctx, qkv, bias, causal, softmax_scale)
1019 if qkv.stride(-1) != 1:
1020 qkv = qkv.contiguous()
-> 1021 o, lse, ctx.softmax_scale = _flash_attn_forward(
1022 qkv[:, :, 0],
1023 qkv[:, :, 1],
1024 qkv[:, :, 2],
1025 bias=bias,
1026 causal=causal,
1027 softmax_scale=softmax_scale)
1028 ctx.save_for_backward(qkv, o, lse, bias)
1029 ctx.causal = causal
File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3ef4a608677312175eb6f8143d/flash_attn_triton.py:826, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale)
823 # BLOCK = 128
824 # num_warps = 4 if d <= 64 else 8
825 grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
--> 826 _fwd_kernel[grid]( # type: ignore
827 q,
828 k,
829 v,
830 bias,
831 o,
832 lse,
833 tmp,
834 softmax_scale,
835 q.stride(0),
836 q.stride(2),
837 q.stride(1),
838 k.stride(0),
839 k.stride(2),
840 k.stride(1),
841 v.stride(0),
842 v.stride(2),
843 v.stride(1),
844 *bias_strides,
845 o.stride(0),
846 o.stride(2),
847 o.stride(1),
848 nheads,
849 seqlen_q,
850 seqlen_k,
851 seqlen_q_rounded,
852 d,
853 seqlen_q // 32,
854 seqlen_k // 32, # key for triton cache (limit number of compilations)
855 # Can't use kwargs here because triton autotune expects key to be args, not kwargs
856 # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
857 bias_type,
858 causal,
859 BLOCK_HEADDIM,
860 # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
861 # num_warps=num_warps,
862 # num_stages=1,
863 )
864 return o, lse, softmax_scale
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/runtime/autotuner.py:90, in Autotuner.run(self, *args, **kwargs)
88 if config.pre_hook is not None:
89 config.pre_hook(self.nargs)
---> 90 return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/runtime/autotuner.py:199, in Heuristics.run(self, *args, **kwargs)
197 for v, heur in self.values.items():
198 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 199 return self.fn.run(*args, **kwargs)
File <string>:41, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:1621, in compile(fn, **kwargs)
1619 next_module = parse(path)
1620 else:
-> 1621 next_module = compile(module)
1622 fn_cache_manager.put(next_module, f"{name}.{ir}")
1623 if os.path.exists(path):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:1550, in compile.<locals>.<lambda>(src)
1545 extern_libs = kwargs.get("extern_libs", dict())
1546 # build compilation stages
1547 stages = {
1548 "ast": (lambda path: fn, None),
1549 "ttir": (lambda path: parse_mlir_module(path, context),
-> 1550 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
1551 "ttgir": (lambda path: parse_mlir_module(path, context),
1552 lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
1553 "llir": (lambda path: Path(path).read_text(),
1554 lambda src: ttgir_to_llir(src, extern_libs, capability)),
1555 "ptx": (lambda path: Path(path).read_text(),
1556 lambda src: llir_to_ptx(src, capability)),
1557 "cubin": (lambda path: Path(path).read_bytes(),
1558 lambda src: ptx_to_cubin(src, capability))
1559 }
1560 # find out the signature of the function
1561 if isinstance(fn, triton.runtime.JITFunction):
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:962, in ast_to_ttir(fn, signature, specialization, constants)
961 def ast_to_ttir(fn, signature, specialization, constants):
--> 962 mod, _ = build_triton_ir(fn, signature, specialization, constants)
963 return optimize_triton_ir(mod)
File ~/anaconda3/envs/pytorch_python38/lib/python3.8/site-packages/triton/compiler.py:942, in build_triton_ir(fn, signature, specialization, constants)
940 if node is None or isinstance(e, (NotImplementedError, CompilationError)):
941 raise e
--> 942 raise CompilationError(fn.src, node) from e
943 ret = generator.module
944 # module takes ownership of the context
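The root failure is the TypeError above: DNABERT-2's bundled flash_attn_triton.py kernel calls tl.dot(..., trans_b=True), and the Triton build installed in this environment no longer accepts the trans_a/trans_b keywords, so the JIT compile fails and is re-raised as a CompilationError. That points to a Triton version mismatch rather than a problem in the training code itself. Below is a minimal diagnostic sketch (my assumption about the cause, not something taken from the log) to check whether the installed Triton still exposes the old tl.dot signature:

```python
# Minimal check, assuming the failure is a Triton API mismatch:
# DNABERT-2's flash_attn_triton.py passes trans_b=True to tl.dot,
# which only older Triton builds accept.
import inspect

import triton
import triton.language as tl

print("triton version:", triton.__version__)

# Newer Triton releases dropped the trans_a / trans_b keywords from tl.dot;
# if this prints False, the bundled kernel cannot compile against the
# installed Triton build.
print("tl.dot accepts trans_b:",
      "trans_b" in inspect.signature(tl.dot).parameters)
```

If the check prints False, the commonly reported workarounds are either to install the Triton release the kernel was written against, or to remove Triton from the environment so DNABERT-2 falls back to its non-flash attention path; which option is appropriate depends on the DNABERT-2 revision in use, so check the model card or repository issues before changing the environment.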