Thanks for your help! I ran it successfully, so I'm closing the issue :)
Hi @GisellWu,
Could you share a minimal reproducible example of the error?
import torch
from transformers import T5Tokenizer, T5Model
import hidet
model_name = 't5-base'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Model.from_pretrained(model_name).to(device="cuda:0")
model = torch.compile(model, backend='hidet')
input_text = ["translate English to French: Hello, how are you?"]
# tokenizer(...) handles a list (batch) of strings; encode_plus() expects a single string
tokens = tokenizer(input_text, add_special_tokens=True, return_tensors='pt',
                   padding='max_length', truncation=True, max_length=128).to(device="cuda:0")
outputs = model(input_ids=tokens.input_ids, decoder_input_ids=tokens.input_ids)
logits = outputs.last_hidden_state
print("Logits Shape:", logits.shape)
I also added the following functions to register_functions.py so that T5Model could run end to end:
# These registrations live in hidet's register_functions.py and rely on the imports
# already present there (torch, Tensor, ops, register_function, warnings, ...).
@register_function(torch.abs)
def abs(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.abs(..., out=...)")
    return ops.abs(x)


@register_function(torch.log)
def log(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.log(..., out=...)")
    return ops.log(x)


@register_function(torch.full_like)
def full_like(input, fill_value, *, dtype=None, layout=torch.strided, device=None,
              requires_grad=False, memory_format=torch.preserve_format):
    if layout not in [None, torch.strided]:
        raise NotImplementedError("hidet: does not support torch.full_like(..., layout=..., ...)")
    if requires_grad and torch.is_grad_enabled():
        warnings.warn_once("hidet: requires_grad=True when torch.is_grad_enabled(), treating as requires_grad=False")
    # Like torch.full_like, default to the dtype and device of `input`
    hidet_dtype: DataType = dtype_from_torch(torch_dtype=dtype) if dtype is not None else input.dtype
    hidet_device: Device = device_from_torch(torch_device=device) if device is not None else input.device
    return ops.full(input.shape, fill_value, dtype=hidet_dtype, device=hidet_device)


@register_function(torch.zeros_like)
def zeros_like(input, *, dtype=None, layout=None, device=None,
               requires_grad=False, memory_format=torch.preserve_format):
    import hidet

    if layout not in [None, torch.strided]:
        raise NotImplementedError("hidet: does not support torch.zeros_like(..., layout=..., ...)")
    _ = requires_grad, memory_format  # accepted for signature compatibility, otherwise ignored
    shape = [int(v) for v in input.shape]
    # Default to the input's dtype/device rather than torch.get_default_dtype(),
    # which would silently turn float16 inputs into float32 zeros
    hidet_dtype = dtype_from_torch(dtype) if dtype is not None else input.dtype
    hidet_device = device_from_torch(device) if device is not None else input.device
    return hidet.zeros(shape, dtype=hidet_dtype, device=hidet_device)
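One detail worth getting right in `*_like` registrations: in PyTorch, `torch.zeros_like` and `torch.full_like` take `dtype` and `device` from `input` when they are not given. Falling back to `torch.get_default_dtype()` (normally float32) would quietly produce float32 tensors inside a float16 model. The toy, pure-Python sketch below shows only that defaulting rule; `FakeTensor` is a hypothetical stand-in for a real tensor, and none of this is hidet's API.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: metadata plus a constant fill value."""
    shape: Tuple[int, ...]
    dtype: str
    device: str
    fill_value: float = 0.0


def full_like(input: FakeTensor, fill_value: float, *,
              dtype: Optional[str] = None, device: Optional[str] = None) -> FakeTensor:
    # torch.*_like semantics: an unspecified dtype/device is inherited from `input`
    out_dtype = dtype if dtype is not None else input.dtype
    out_device = device if device is not None else input.device
    return FakeTensor(tuple(input.shape), out_dtype, out_device, fill_value)


def zeros_like(input: FakeTensor, *, dtype: Optional[str] = None,
               device: Optional[str] = None) -> FakeTensor:
    # zeros_like is just full_like with a fill value of zero
    return full_like(input, 0.0, dtype=dtype, device=device)
```

For example, with `x = FakeTensor((1, 128), "float16", "cuda:0")`, `zeros_like(x)` stays float16 on `cuda:0`, while `full_like(x, 1.0, dtype="float32")` overrides only the dtype.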
Hi @GisellWu,
I added the missing operators and fixed some bugs for the T5 model in #322. Could you give it another try?
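Conceptually, adding a missing operator means registering a backend implementation for a PyTorch callable, as the `@register_function(...)` snippets above do; if a traced graph contains a callable with no registration, compilation stops with a "not supported by hidet yet" error. The pure-Python sketch below models that idea only. It is not hidet's actual interpreter: `REGISTRY`, `missing_functions`, `check_supported`, and `run` are hypothetical names, and `math.fabs`, `math.log`, and `math.isinf` stand in for `torch.abs`, `torch.log`, and `torch.isinf`.

```python
import math
from typing import Callable, Dict, Iterable, List

# Hypothetical registry mapping a framework callable to its backend implementation
REGISTRY: Dict[Callable, Callable] = {}


def register_function(framework_func: Callable):
    """Decorator: record the decorated function as the backend version of framework_func."""
    def decorator(impl: Callable) -> Callable:
        REGISTRY[framework_func] = impl
        return impl
    return decorator


@register_function(math.fabs)  # stand-in for torch.abs
def backend_abs(x: float) -> float:
    return abs(x)


@register_function(math.log)  # stand-in for torch.log
def backend_log(x: float) -> float:
    return math.log(x)


def missing_functions(calls: Iterable[Callable]) -> List[str]:
    """Sorted, de-duplicated names of traced callables that have no registration."""
    return sorted({f.__name__ for f in calls if f not in REGISTRY})


def check_supported(calls: Iterable[Callable]) -> None:
    """Raise before compiling if any traced callable lacks a backend implementation."""
    missing = missing_functions(calls)
    if missing:
        raise NotImplementedError(
            "The following modules/functions are not supported by hidet yet:\n"
            + "\n".join(missing)
        )


def run(func: Callable, *args):
    """Dispatch a framework call to its registered backend implementation."""
    return REGISTRY[func](*args)
```

For example, `missing_functions([math.fabs, math.isinf])` returns `['isinf']`, so `check_supported` on that list raises, while `run(math.fabs, -2.0)` dispatches to `backend_abs` and returns `2.0`.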
Hi @yaoyaoding,
Sorry to bother you again. I'm trying to run T5Model with float16, and it hits some new unsupported functions. Could you please help me fix them?
The error is:
NotImplementedError: The following modules/functions are not supported by hidet yet:
torch.clamp
torch.isinf
Here is the example code:
import torch
from transformers import T5Tokenizer, T5Model
import hidet
model_name = 't5-base'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Model.from_pretrained(model_name, torch_dtype=torch.float16).to(device="cuda:0")
model = torch.compile(model, backend='hidet')
input_text = ["translate English to French: Hello, how are you?"]
# tokenizer(...) handles a list (batch) of strings; encode_plus() expects a single string
tokens = tokenizer(input_text, add_special_tokens=True, return_tensors='pt',
                   padding='max_length', truncation=True, max_length=128).to(device="cuda:0")
outputs = model(input_ids=tokens.input_ids, decoder_input_ids=tokens.input_ids)
logits = outputs.last_hidden_state
print("Logits Shape:", logits.shape)
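The two new functions likely come from a float16-only code path. In the Hugging Face `transformers` T5 implementation (the exact code varies by version), when hidden states are float16 the model checks `torch.isinf(...)` and clamps them with `torch.clamp` to just below the float16 maximum (`torch.finfo(torch.float16).max - 1000`) to avoid overflow to inf. The float32 run never takes that branch, which would explain why the error only appears with `torch_dtype=torch.float16`. The pure-Python sketch below mimics that logic under those assumptions; `to_fp16` emulates float16 rounding with the `struct` module's `'e'` format, and all names are hypothetical.

```python
import math
import struct
from typing import List

FP16_MAX = 65504.0  # largest finite float16 value, i.e. torch.finfo(torch.float16).max


def to_fp16(x: float) -> float:
    """Round a Python float to float16, mapping overflow to +/-inf like a real fp16 cast."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def clamp_fp16_hidden_states(values: List[float]) -> List[float]:
    """Sketch of the float16 branch: isinf check, then clamp to a finite range."""
    has_inf = any(math.isinf(v) for v in values)                # ~ torch.isinf(h).any()
    clamp_value = FP16_MAX - 1000 if has_inf else FP16_MAX      # ~ finfo(float16).max - 1000
    return [min(max(v, -clamp_value), clamp_value) for v in values]  # ~ torch.clamp
```

For example, `[to_fp16(v) for v in (1.5, 70000.0, -3.0)]` gives `[1.5, inf, -3.0]`, and `clamp_fp16_hidden_states` maps that to `[1.5, 64504.0, -3.0]`.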
Hi @GisellWu,
I added the missing operators in #343. Could you give it a try? Thanks!
Hi @yaoyaoding,
That works. Thanks again for your help!