ruofan-wu commented on July 27, 2024

Thanks for your help! I ran it successfully. Closing the issue :)

from hidet.

yaoyaoding commented on July 27, 2024

Hi @GisellWu,

Could you share a minimal example that reproduces the error?

ruofan-wu commented on July 27, 2024
import torch
from transformers import T5Tokenizer, T5Model
import hidet  # provides the 'hidet' backend used by torch.compile below

model_name = 't5-base'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Model.from_pretrained(model_name).to(device="cuda:0")

model = torch.compile(model, backend='hidet')

input_text = ["translate English to French: Hello, how are you?"]
# Call the tokenizer directly: encode_plus() expects a single string and treats
# a list of strings as already-split tokens instead of tokenizing the sentence.
tokens = tokenizer(input_text, add_special_tokens=True, return_tensors='pt',
                   padding='max_length', truncation=True, max_length=128).to(device="cuda:0")

# T5Model needs decoder inputs; reusing input_ids is enough to trigger compilation.
outputs = model(input_ids=tokens.input_ids, decoder_input_ids=tokens.input_ids)
logits = outputs.last_hidden_state
print("Logits Shape:", logits.shape)

ruofan-wu commented on July 27, 2024

In addition, I added the following functions to register_functions.py to get T5Model running:

# These functions live in hidet's register_functions.py and rely on the imports
# already at the top of that file (torch, Optional, Tensor, ops, Device, DataType,
# device_from_torch, dtype_from_torch, register_function, warnings).

@register_function(torch.abs)
def abs(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.abs(..., out=...)")
    return ops.abs(x)

@register_function(torch.log)
def log(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.log(..., out=...)")
    return ops.log(x)

@register_function(torch.full_like)
def full_like(input: Tensor, fill_value, *, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format):
    if layout not in [None, torch.strided]:
        raise NotImplementedError("hidet: does not support torch.full_like(..., layout=..., ...)")
    if requires_grad and torch.is_grad_enabled():
        warnings.warn_once("hidet: requires_grad=True when torch.is_grad_enabled(), treating as requires_grad=False")
    # Like torch.full_like, inherit dtype and device from `input` unless overridden.
    hidet_device: Device = input.device if device is None else device_from_torch(torch_device=device)
    hidet_dtype: DataType = input.dtype if dtype is None else dtype_from_torch(torch_dtype=dtype)
    return ops.full(input.shape, fill_value, dtype=hidet_dtype, device=hidet_device)

@register_function(torch.zeros_like)
def zeros_like(input: Tensor, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format):
    import hidet

    if layout not in [None, torch.strided]:
        raise NotImplementedError("hidet: does not support torch.zeros_like(..., layout=..., ...)")

    # `input` already carries its shape, so there is no varargs size to unpack.
    shape = [int(v) for v in input.shape]

    # Like torch.zeros_like, inherit dtype and device from `input` unless overridden;
    # falling back to torch.get_default_dtype() would turn float16 inputs into float32.
    hidet_dtype: DataType = input.dtype if dtype is None else dtype_from_torch(dtype)
    hidet_device: Device = input.device if device is None else device_from_torch(device)

    _ = requires_grad  # requires_grad is ignored here

    return hidet.zeros(shape, dtype=hidet_dtype, device=hidet_device)
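
The registration pattern above can be modelled in a few lines of plain Python, which helps explain why an unregistered function makes compilation fail with a list of unsupported functions. This is a hypothetical, simplified sketch, not hidet's actual interpreter: `REGISTRY` and `check_supported` are illustrative names, the toy `register_function` only shares its name with hidet's decorator, and `math.fabs`, `math.log`, and `math.isinf` stand in for `torch.abs`, `torch.log`, and `torch.isinf`.

```python
import math
from typing import Callable, Dict, Iterable, List

# Hypothetical, simplified stand-in for a backend's operator table (not hidet's
# real implementation). Keys are framework callables, values are the backend
# implementations that replace them.
REGISTRY: Dict[Callable, Callable] = {}


def register_function(framework_func: Callable) -> Callable:
    """Toy decorator: record the decorated function as the translation of framework_func."""
    def decorator(impl: Callable) -> Callable:
        REGISTRY[framework_func] = impl
        return impl
    return decorator


def check_supported(graph_calls: Iterable[Callable]) -> None:
    """Raise NotImplementedError listing every called function that has no translation."""
    missing: List[str] = sorted(
        {getattr(f, "__name__", repr(f)) for f in graph_calls if f not in REGISTRY}
    )
    if missing:
        listing = "\n".join(f"  {name}" for name in missing)
        raise NotImplementedError(
            "The following modules/functions are not supported yet:\n" + listing
        )


# math.fabs and math.log play the role of torch.abs and torch.log.
@register_function(math.fabs)
def backend_abs(x: float) -> float:
    return abs(x)


@register_function(math.log)
def backend_log(x: float) -> float:
    return math.log(x)


check_supported([math.fabs, math.log])  # everything is registered: no error
try:
    check_supported([math.fabs, math.isinf])  # isinf was never registered
except NotImplementedError as err:
    print(err)
```

Registering one more function with the decorator is all it takes for the second check to pass, which mirrors how adding entries to register_functions.py removes names from the error list.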

yaoyaoding commented on July 27, 2024

Hi @GisellWu,

I added the missing operators for the T5 model and fixed some bugs in #322. Could you give it another try?

ruofan-wu commented on July 27, 2024

Hi @yaoyaoding,

Sorry to bother you again. I'm trying T5Model with float16 and have run into some new unsupported functions. Could you please help me fix them?

The error is:

NotImplementedError: The following modules/functions are not supported by hidet yet:
  torch.clamp
  torch.isinf

And here is the example code:

import torch
from transformers import T5Tokenizer, T5Model
import hidet  # provides the 'hidet' backend used by torch.compile below

model_name = 't5-base'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Model.from_pretrained(model_name, torch_dtype=torch.float16).to(device="cuda:0")

model = torch.compile(model, backend='hidet')

input_text = ["translate English to French: Hello, how are you?"]
# Call the tokenizer directly: encode_plus() expects a single string and treats
# a list of strings as already-split tokens instead of tokenizing the sentence.
tokens = tokenizer(input_text, add_special_tokens=True, return_tensors='pt',
                   padding='max_length', truncation=True, max_length=128).to(device="cuda:0")

# T5Model needs decoder inputs; reusing input_ids is enough to trigger compilation.
outputs = model(input_ids=tokens.input_ids, decoder_input_ids=tokens.input_ids)
logits = outputs.last_hidden_state
print("Logits Shape:", logits.shape)
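
A plausible reason torch.clamp and torch.isinf only appear once the model runs in float16: my understanding (not stated in this thread, so treat it as an assumption) is that the Hugging Face T5 code adds an overflow guard for float16 hidden states, clamping activations to just below the float16 maximum (65504) and leaving extra headroom when an infinity has been detected. The pure-Python sketch below models that guard on a list of floats; `clamp_fp16_activations`, its `is_fp16` flag, and the 1000 margin are illustrative, not the library's code.

```python
import math
from typing import List

# Largest finite IEEE 754 half-precision (float16) value.
FLOAT16_MAX = 65504.0


def clamp_fp16_activations(values: List[float], is_fp16: bool) -> List[float]:
    """Toy model of a float16 overflow guard built from isinf + clamp.

    Only the float16 path inspects the values, so a float32 model never
    reaches the isinf/clamp calls at all.
    """
    if not is_fp16:
        return list(values)
    # Leave extra headroom below the float16 maximum once an inf has appeared
    # (the 1000 margin is an assumption for illustration).
    has_inf = any(math.isinf(v) for v in values)
    clamp_value = FLOAT16_MAX - 1000 if has_inf else FLOAT16_MAX
    # Equivalent of torch.clamp(x, min=-clamp_value, max=clamp_value).
    return [min(max(v, -clamp_value), clamp_value) for v in values]


print(clamp_fp16_activations([1.0, math.inf, -math.inf], is_fp16=True))
print(clamp_fp16_activations([1.0, math.inf], is_fp16=False))
```

Because the guard sits behind a dtype check, the float32 run traced earlier never calls these two functions, so a backend can support float32 T5 while still missing them for float16.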

yaoyaoding commented on July 27, 2024

Hi @GisellWu,

I added the missing operators in #343. Could you give it a try? Thanks!

ruofan-wu commented on July 27, 2024

Hi @yaoyaoding,

That's OK. Thanks again for your help!
