Comments (4)
If you are just asking for a catchable exception, then #1077 should close this. We would like to eventually allow int64 and other 8-byte types to work with scatter, but that is more involved.
from mlx.
Thank you, Awni. Some observations:
- The crash message is confusing: it does not say whether the problem is with the array, the indices, or the values. Can we improve it by mentioning a workaround in the error message?
- Can scatter ops use the CPU device for int64 and uint64?
- Adding a note about supported dtypes and devices (CPU, GPU) to the mlx.core.array.at docs would be helpful.
- Are there any other ops that are not supported on the GPU and run on the CPU?
zeros = mx.zeros(shape, dtype=values.dtype)
zeros = zeros.at[indices].add(values)
I tried this, and it does not work because add does not take a device kwarg:
if zeros.dtype in [mx.int64, mx.uint64] and mx.default_device() == mx.gpu:
    device = mx.Device(type=mx.DeviceType.cpu)
    # This call fails: add() does not accept a device kwarg
    zeros = zeros.at[indices].add(values, device=device)
else:
    zeros = zeros.at[indices].add(values)
It would be helpful if mlx could fall back to the CPU for scatter ops that are not supported on the GPU, or allow a device kwarg for all scatter ops.
Additional ops which are impacted by this bug:
- mx.cumsum
- mx.cumprod
- mx.diag
from mlx.
> The crash message is confusing: it does not say whether the problem is with the array, the indices, or the values. Can we improve it by mentioning a workaround in the error message?
I improved the message in #1077. The problem is with the values.
> Can scatter ops use the CPU device for int64 and uint64?
We prefer not to silently route to the CPU for ops without a GPU back-end. You can do this in the API by changing the default stream to the CPU before calling the scatter when the dtype is int64/uint64.
> Are there any other ops that are not supported on the GPU and run on the CPU?
Just a few: FFT and some of the LAPACK ops (QR / Inverse). Metal support for FFT is coming soon in #981.
> I tried this, and it does not work because add does not take a device kwarg:
You can use a context manager. For most free ops, the stream kwarg also works. E.g.:
import mlx.core as mx

v = mx.array([1, 2, 3])
u = mx.array([1, 2])
idx = mx.array([0, 1])
with mx.stream(mx.cpu):
    out = v.at[idx].add(u)
from mlx.
Thank you @awni for the fix.
from mlx.